--- /dev/null
+import logging
+import os
+import signal
+
+
+log = logging.getLogger(__name__)
+
+
+class Exiter(object):
+ """
+ A helper to manage any signal handlers we need to call upon receiving a
+ given signal
+ """
+ def __init__(self):
+ self.handlers = list()
+
+ def add_handler(self, signals, func):
+ """
+ Adds a handler function to be called when any of the given signals are
+ received.
+
+ The handler function should have a signature like::
+
+ my_handler(signal, frame)
+ """
+ if type(signals) is int:
+ signals = [signals]
+
+ for signal_ in signals:
+ signal.signal(signal_, self.default_handler)
+
+ handler = Handler(self, func, signals)
+ log.debug(
+ "Installing handler: %s",
+ repr(handler),
+ )
+ self.handlers.append(handler)
+ return handler
+
+ def default_handler(self, signal_, frame):
+ log.debug(
+ "Got signal %s; running %s handler%s...",
+ signal_,
+ len(self.handlers),
+ '' if len(self.handlers) == 1 else 's',
+ )
+ for handler in self.handlers:
+ handler.func(signal_, frame)
+ log.debug("Finished running handlers")
+ # Restore the default handler
+ signal.signal(signal_, 0)
+ # Re-send the signal to our main process
+ os.kill(os.getpid(), signal_)
+
+
+class Handler(object):
+ def __init__(self, exiter, func, signals):
+ self.exiter = exiter
+ self.func = func
+ self.signals = signals
+
+ def remove(self):
+ try:
+ log.debug("Removing handler: %s", self)
+ self.exiter.handlers.remove(self)
+ except ValueError:
+ pass
+
+ def __repr__(self):
+ return "{c}(exiter={e}, func={f}, signals={s})".format(
+ c=self.__class__.__name__,
+ e=self.exiter,
+ f=self.func,
+ s=self.signals,
+ )
+
+
+exiter = Exiter()
--- /dev/null
+import os
+import random
+import signal
+
+from mock import patch, Mock
+
+from teuthology import exit
+
+
+class TestExiter(object):
+ klass = exit.Exiter
+
+ def setup(self):
+ self.pid = os.getpid()
+
+ # Below, we patch os.kill() in such a way that the first time it is
+ # invoked it does actually send the signal. Any subsequent invocation
+ # won't send any signal - this is so we don't kill the process running
+ # our unit tests!
+ self.patcher_kill = patch(
+ 'teuthology.exit.os.kill',
+ wraps=os.kill,
+ )
+
+ self.m_kill = self.patcher_kill.start()
+
+ def m_kill_unwrap(pid, sig):
+ # Setting return_value of a mocked object disables the wrapping
+ if self.m_kill.call_count > 1:
+ self.m_kill.return_value = None
+
+ self.m_kill.side_effect = m_kill_unwrap
+
+ def teardown(self):
+ self.patcher_kill.stop()
+ del self.m_kill
+
+ def test_noop(self):
+ sig = 15
+ obj = self.klass()
+ assert len(obj.handlers) == 0
+ assert signal.getsignal(sig) == 0
+
+ def test_basic(self):
+ sig = 15
+ obj = self.klass()
+ m_func = Mock()
+ obj.add_handler(sig, m_func)
+ assert len(obj.handlers) == 1
+ os.kill(self.pid, sig)
+ assert m_func.call_count == 1
+ assert self.m_kill.call_count == 2
+ for arg_list in self.m_kill.call_args_list:
+ assert arg_list[0] == (self.pid, sig)
+
+ def test_remove_handlers(self):
+ sig = [1, 15]
+ send_sig = random.choice(sig)
+ n = 3
+ obj = self.klass()
+ handlers = list()
+ for i in range(n):
+ m_func = Mock(name="handler %s" % i)
+ handlers.append(obj.add_handler(sig, m_func))
+ assert obj.handlers == handlers
+ for handler in handlers:
+ handler.remove()
+ assert obj.handlers == list()
+ os.kill(self.pid, send_sig)
+ assert self.m_kill.call_count == 2
+ for handler in handlers:
+ assert handler.func.call_count == 0
+
+ def test_n_handlers(self, n=10, sig=11):
+ if type(sig) is int:
+ send_sig = sig
+ else:
+ send_sig = random.choice(sig)
+ obj = self.klass()
+ handlers = list()
+ for i in range(n):
+ m_func = Mock(name="handler %s" % i)
+ handlers.append(obj.add_handler(sig, m_func))
+ assert obj.handlers == handlers
+ os.kill(self.pid, send_sig)
+ for i in range(n):
+ assert handlers[i].func.call_count == 1
+ assert self.m_kill.call_count == 2
+ for arg_list in self.m_kill.call_args_list:
+ assert arg_list[0] == (self.pid, send_sig)
+
+ def test_multiple_signals(self):
+ self.test_n_handlers(n=3, sig=[1, 6, 11, 15])