mgr/ssh: Fix Promises
author     Sebastian Wagner <sebastian.wagner@suse.com>
           Mon, 11 Nov 2019 09:50:21 +0000 (10:50 +0100)
committer  Sebastian Wagner <sebastian.wagner@suse.com>
           Wed, 27 Nov 2019 12:39:11 +0000 (13:39 +0100)
Signed-off-by: Sebastian Wagner <sebastian.wagner@suse.com>
src/pybind/mgr/orchestrator.py
src/pybind/mgr/rook/module.py
src/pybind/mgr/selftest/module.py
src/pybind/mgr/ssh/module.py
src/pybind/mgr/ssh/tests/fixtures.py [new file with mode: 0644]
src/pybind/mgr/ssh/tests/test_completion.py [new file with mode: 0644]
src/pybind/mgr/ssh/tests/test_ssh.py
src/pybind/mgr/test_orchestrator/module.py
src/pybind/mgr/tests/__init__.py
src/pybind/mgr/tests/test_orchestrator.py
src/pybind/mgr/tox.ini

diff --git a/src/pybind/mgr/orchestrator.py b/src/pybind/mgr/orchestrator.py
index 2ba4f1f538f5e8115e03647ac237538fa03e0382..aa0f6c2a9d9443e72b1976601406aceed9c6a1b0 100644
@@ -165,16 +165,20 @@ class _Promise(object):
     helper to build orchestrator modules.
     """
     INITIALIZED = 1  # We have a parent completion and a next completion
-    FINISHED = 2  # we have a final result
+    RUNNING = 2
+    FINISHED = 3  # we have a final result
 
     NO_RESULT = _no_result()  # type: None
+    ASYNC_RESULT = object()
 
     def __init__(self,
                  _first_promise=None,  # type: Optional["_Promise"]
                  value=NO_RESULT,  # type: Optional
-                 on_complete=None    # type: Optional[Callable]
+                 on_complete=None,    # type: Optional[Callable]
+                 name=None,  # type: Optional[str]
                  ):
         self._on_complete = on_complete
+        self._name = name
         self._next_promise = None  # type: Optional[_Promise]
 
         self._state = self.INITIALIZED
@@ -188,10 +192,10 @@ class _Promise(object):
         self._first_promise = _first_promise or self  # type: 'Completion'
 
     def __repr__(self):
-        name = getattr(self._on_complete, '__name__', '??') if self._on_complete else 'None'
+        name = self._name or (getattr(self._on_complete, '__name__', '??') if self._on_complete else 'None')
         val = repr(self._value) if self._value is not self.NO_RESULT else 'NA'
-        return '{}(_s={}, val={}, id={}, name={}, pr={}, _next={})'.format(
-            self.__class__, self._state, val, id(self), name, getattr(next, '_progress_reference', 'NA'), repr(self._next_promise)
+        return '{}(_s={}, val={}, _on_c={}, id={}, name={}, pr={}, _next={})'.format(
+            self.__class__, self._state, val, self._on_complete, id(self), name, getattr(next, '_progress_reference', 'NA'), repr(self._next_promise)
         )
 
     def then(self, on_complete):
@@ -199,7 +203,7 @@ class _Promise(object):
         """
         Call ``on_complete`` as soon as this promise is finalized.
         """
-        assert self._state is self.INITIALIZED
+        assert self._state in (self.INITIALIZED, self.RUNNING)
         if self._on_complete is not None:
             assert self._next_promise is None
             self._set_next_promise(self.__class__(
@@ -216,14 +220,14 @@ class _Promise(object):
     def _set_next_promise(self, next):
         # type: (_Promise) -> None
         assert self is not next
-        assert self._state is self.INITIALIZED
+        assert self._state in (self.INITIALIZED, self.RUNNING)
 
         self._next_promise = next
         assert self._next_promise is not None
         for p in iter(self._next_promise):
             p._first_promise = self._first_promise
 
-    def finalize(self, value=NO_RESULT):
+    def _finalize(self, value=NO_RESULT):
         """
         Sets this promise to complete.
 
@@ -231,11 +235,13 @@ class _Promise(object):
 
         :param value: new value.
         """
-        assert self._state is self.INITIALIZED
+        assert self._state in (self.INITIALIZED, self.RUNNING)
+
+        self._state = self.RUNNING
 
         if value is not self.NO_RESULT:
             self._value = value
-        assert self._value is not self.NO_RESULT
+        assert self._value is not self.NO_RESULT, repr(self)
 
         if self._on_complete:
             try:
@@ -255,21 +261,25 @@ class _Promise(object):
             self._set_next_promise(next_result)
             if self._next_promise._value is self.NO_RESULT:
                 self._next_promise._value = self._value
-        else:
+            self.propagate_to_next()
+        elif next_result is not self.ASYNC_RESULT:
             # simple map. simply forward
             if self._next_promise:
                 self._next_promise._value = next_result
             else:
                 # Hack: next_result is of type U, _value is of type T
                 self._value = next_result  # type: ignore
-        self._state = self.FINISHED
-        logger.debug('finalized {}'.format(repr(self)))
-        self.propagate_to_next()
+            self.propagate_to_next()
+        else:
+            # asynchronous promise
+            pass
+
 
     def propagate_to_next(self):
-        assert self._state is self.FINISHED
+        self._state = self.FINISHED
+        logger.debug('finalized {}'.format(repr(self)))
         if self._next_promise:
-            self._next_promise.finalize()
+            self._next_promise._finalize()
 
     def fail(self, e):
         # type: (Exception) -> None
@@ -277,7 +287,7 @@ class _Promise(object):
         Sets the whole completion to be failed with this exception and ends the
         evaluation.
         """
-        assert self._state is self.INITIALIZED
+        assert self._state in (self.INITIALIZED, self.RUNNING)
         logger.exception('_Promise failed')
         self._exception = e
         self._value = 'exception'
@@ -379,10 +389,10 @@ class ProgressReference(object):
         return self.progress == 1 and self._completion_has_result
 
     def update(self):
-        def run(progress):
+        def progress_run(progress):
             self.progress = progress
         if self.completion:
-            c = self.completion().then(run)
+            c = self.completion().then(progress_run)
             self.mgr.process([c._first_promise])
         else:
             self.progress = 1
@@ -424,9 +434,10 @@ class Completion(_Promise):
     def __init__(self,
                  _first_promise=None,  # type: Optional["Completion"]
                  value=_Promise.NO_RESULT,  # type: Any
-                 on_complete=None  # type: Optional[Callable]
+                 on_complete=None,  # type: Optional[Callable]
+                 name=None,  # type: Optional[str]
                  ):
-        super(Completion, self).__init__(_first_promise, value, on_complete)
+        super(Completion, self).__init__(_first_promise, value, on_complete, name)
 
     @property
     def _progress_reference(self):
@@ -478,6 +489,10 @@ class Completion(_Promise):
         if self._progress_reference:
             self._progress_reference.fail()
 
+    def finalize(self, result=_Promise.NO_RESULT):
+        if self._first_promise._state == self.INITIALIZED:
+            self._first_promise._finalize(result)
+
     @property
     def result(self):
         """
@@ -846,7 +861,7 @@ class Orchestrator(object):
         raise NotImplementedError()
 
     def update_mons(self, num, hosts):
-        # type: (int, List[Tuple[str,str]]) -> Completion
+        # type: (int, List[Tuple[str,str,str]]) -> Completion
         """
         Update the number of cluster monitors.
 
@@ -874,17 +889,17 @@ class Orchestrator(object):
         raise NotImplementedError()
 
     def add_rbd_mirror(self, spec):
-        # type: (StatelessServiceSpec) -> WriteCompletion
+        # type: (StatelessServiceSpec) -> Completion
         """Create rbd-mirror cluster"""
         raise NotImplementedError()
 
-    def remove_rbd_mirror(self):
-        # type: (str) -> WriteCompletion
+    def remove_rbd_mirror(self, name):
+        # type: (str) -> Completion
         """Remove rbd-mirror cluster"""
         raise NotImplementedError()
 
     def update_rbd_mirror(self, spec):
-        # type: (StatelessServiceSpec) -> WriteCompletion
+        # type: (StatelessServiceSpec) -> Completion
         """
         Update / redeploy rbd-mirror cluster
         Like for example changing the number of service instances.
@@ -1287,6 +1302,9 @@ class InventoryNode(object):
         # type: (List[InventoryNode]) -> List[str]
         return [node.name for node in nodes]
 
+    def __eq__(self, other):
+        return self.name == other.name and self.devices == other.devices
+
 
 class DeviceLightLoc(namedtuple('DeviceLightLoc', ['host', 'dev'])):
     """
diff --git a/src/pybind/mgr/rook/module.py b/src/pybind/mgr/rook/module.py
index c0384437f0a8fd62f3bfc21c0e20c229b817ec37..a0accd69c607e76bae499c0f4bb3a25037158694 100644
@@ -35,7 +35,7 @@ from .rook_cluster import RookCluster
 
 class RookCompletion(orchestrator.Completion):
     def evaluate(self):
-        self._first_promise.finalize(None)
+        self.finalize(None)
 
 
 def deferred_read(f):
diff --git a/src/pybind/mgr/selftest/module.py b/src/pybind/mgr/selftest/module.py
index 444a92d63aa5c864de310ea14b170952fffdb61f..7cf49f79efc2522f9477838cc0bb9380bea89c4d 100644
@@ -448,11 +448,11 @@ class Module(MgrModule):
         import orchestrator
         if what == 'OrchestratorError':
             c = orchestrator.TrivialReadCompletion(result=None)
-            c.exception = orchestrator.OrchestratorError('hello', 'world')
+            c.fail(orchestrator.OrchestratorError('hello', 'world'))
             return c
         elif what == "ZeroDivisionError":
             c = orchestrator.TrivialReadCompletion(result=None)
-            c.exception = ZeroDivisionError('hello', 'world')
+            c.fail(ZeroDivisionError('hello', 'world'))
             return c
         assert False, repr(what)
 
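
The selftest change replaces direct assignment to c.exception with the fail() API, which records the exception on the promise so that raise_if_exception() can re-raise it later. Intended usage, mirroring the hunk above:

    import orchestrator

    c = orchestrator.TrivialReadCompletion(result=None)
    c.fail(ZeroDivisionError('hello', 'world'))
    orchestrator.raise_if_exception(c)  # re-raises the ZeroDivisionError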
diff --git a/src/pybind/mgr/ssh/module.py b/src/pybind/mgr/ssh/module.py
index 12f73aa52a170f75840d36cf0f5c1948d5a20fc4..a458a4a7bc0b4a09cefb667995bf254eea4c3b7e 100644
@@ -71,14 +71,19 @@ except ImportError:
 
 
 class AsyncCompletion(orchestrator.Completion):
-    def __init__(self, *args, **kwargs):
-        self.__on_complete = None  # type: Callable
-        self.many = kwargs.pop('many', False)
-        super(AsyncCompletion, self).__init__(*args, **kwargs)
-
-    def propagate_to_next(self):
-        # We don't have a synchronous result.
-        pass
+    def __init__(self,
+                 _first_promise=None,  # type: Optional["Completion"]
+                 value=orchestrator._Promise.NO_RESULT,  # type: Any
+                 on_complete=None,  # type: Optional[Callable]
+                 name=None,  # type: Optional[str]
+                 many=False,  # type: bool
+                 ):
+
+        assert SSHOrchestrator.instance is not None
+        self.many = many
+        if name is None and on_complete is not None:
+            name = on_complete.__name__
+        super(AsyncCompletion, self).__init__(_first_promise, value, on_complete, name)
 
     @property
     def _progress_reference(self):
@@ -93,19 +98,35 @@ class AsyncCompletion(orchestrator.Completion):
             return None
 
         def callback(result):
-            if self._next_promise:
-                self._next_promise._value = result
-            else:
-                self._value = result
-            super(AsyncCompletion, self).propagate_to_next()
+            try:
+                self.__on_complete = None
+                self._finalize(result)
+            except Exception as e:
+                self.fail(e)
+
+        def error_callback(e):
+            self.fail(e)
 
         def run(value):
             if self.many:
-                SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
-                                                                callback=callback)
+                if not value:
+                    logger.info('calling map_async without values')
+                    callback([])
+                    return self.ASYNC_RESULT
+                if six.PY3:
+                    SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
+                                                                    callback=callback,
+                                                                    error_callback=error_callback)
+                else:
+                    SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
+                                                                    callback=callback)
             else:
-                SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
-                                                                  callback=callback)
+                if six.PY3:
+                    SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+                                                                      callback=callback, error_callback=error_callback)
+                else:
+                    SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+                                                                      callback=callback)
+            return self.ASYNC_RESULT
 
         return run
 
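
The run closure above hands the actual work to the module's worker pool. On Python 3 the pool reports worker exceptions through error_callback, which is routed into fail(); Python 2's map_async/apply_async lack that parameter, hence the six.PY3 branches. A standalone sketch of the same wiring, with ThreadPool standing in for _worker_pool:

    from multiprocessing.pool import ThreadPool

    pool = ThreadPool(4)

    def on_done(results):
        print('finished:', results)  # the module calls self._finalize(results)

    def on_error(exc):
        print('failed:', exc)        # the module calls self.fail(exc)

    res = pool.map_async(lambda x: x + 1, [1, 2, 3],
                         callback=on_done, error_callback=on_error)
    res.wait()
    pool.close()
    pool.join()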
@@ -122,8 +143,31 @@ def ssh_completion(cls=AsyncCompletion, **c_kwargs):
     """
     def decorator(f):
         @wraps(f)
-        def wrapper(*args, **kwargs):
-            return cls(on_complete=lambda _: f(*args, **kwargs), **c_kwargs)
+        def wrapper(*args):
+
+            name = f.__name__
+            many = c_kwargs.get('many', False)
+
+            # Some weird logic to make calling functions with multiple arguments work.
+            if len(args) == 1:
+                [value] = args
+                if many and value and isinstance(value[0], tuple):
+                    return cls(on_complete=lambda x: f(*x), value=value, name=name, **c_kwargs)
+                else:
+                    return cls(on_complete=f, value=value, name=name, **c_kwargs)
+            else:
+                if many:
+                    self, value = args
+
+                    def call_self(inner_args):
+                        if not isinstance(inner_args, tuple):
+                            inner_args = (inner_args, )
+                        return f(self, *inner_args)
+
+                    return cls(on_complete=call_self, value=value, name=name, **c_kwargs)
+                else:
+                    return cls(on_complete=lambda x: f(*x), value=args, name=name, **c_kwargs)
+
 
         return wrapper
     return decorator
@@ -136,6 +180,18 @@ def async_completion(f):
 
 def async_map_completion(f):
     # type: (Callable) -> Callable[..., AsyncCompletion]
+    """
+    kind of similar to
+
+    >>> def sync_map(f):
+    ...     return lambda x: map(f, x)
+
+    Limitation: This does not work, as you cannot return completions from `f`
+
+    >>> @async_map_completion
+    ... def run(x):
+    ...     return async_completion(str)(x)
+    """
     return ssh_completion(many=True)(f)
 
 
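
In practice the two decorators are used as in this sketch (run_one and run_many are illustrative names; a live module instance, e.g. the ssh_module test fixture below, is needed to process the completions):

    @async_completion
    def run_one(x):
        return x + 1

    @async_map_completion
    def run_many(x):
        return x + 1

    c1 = run_one(1)        # one pool task; result is 2 once processed
    c2 = run_many([1, 2])  # one task per element; result is [2, 3]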
@@ -144,6 +200,10 @@ def trivial_completion(f):
     return ssh_completion(cls=orchestrator.Completion)(f)
 
 
+def trivial_result(val):
+    return AsyncCompletion(value=val, name='trivial_result')
+
+
 class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
 
     _STORE_HOST_PREFIX = "host"
@@ -189,7 +249,7 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         try:
             with open(path, 'r') as f:
                 self._ceph_daemon = f.read()
-        except IOError as e:
+        except (IOError, TypeError) as e:
             raise RuntimeError("unable to read ceph-daemon at '%s': %s" % (
                 path, str(e)))
 
@@ -233,6 +293,12 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             if h not in self.inventory:
                 del self.service_cache[h]
 
+    def shutdown(self):
+        self.log.error('ssh: shutdown')
+        self._worker_pool.close()
+        self._worker_pool.join()
+        self._worker_pool = None
+
     def config_notify(self):
         """
         This method is called whenever one of our config options is changed.
@@ -342,6 +408,11 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         """
         Starts evaluation of any outstanding completions; results are processed in another thread.
         """
+        if completions:
+            self.log.info("wait: promises={0}".format(completions))
+
+            for p in completions:
+                p.finalize()
 
     def _require_hosts(self, hosts):
         """
@@ -566,7 +637,8 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         TODO:
           - InventoryNode probably needs to be able to report labels
         """
-        nodes = [orchestrator.InventoryNode(host_name, []) for host_name in self.inventory_cache]
+        return [orchestrator.InventoryNode(host_name) for host_name in self.inventory_cache]
+
 """
     def add_host_label(self, host, label):
         if host not in self.inventory:
@@ -601,11 +673,12 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             self._worker_pool.apply_async(run, (host, label)))
 """
 
+    @async_map_completion
     def _refresh_host_services(self, host):
         out, code = self._run_ceph_daemon(
             host, 'mon', 'ls', [], no_fsid=True)
         data = json.loads(''.join(out))
-        self.log.debug('refreshed host %s services: %s' % (host, data))
+        self.log.error('refreshed host %s services: %s' % (host, data))
         self.service_cache[host] = orchestrator.OutdatableData(data)
         return data
 
@@ -628,10 +701,10 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
                     host, host_info.data))
                 in_cache.append(host_info.data)
 
-        def _get_services_result(self, results):
+        def _get_services_result(results):
             services = {}
-            for host, c in zip(hosts, results + in_cache):
-                services[host] = c.result[0]
+            for host, data in zip(hosts, results + in_cache):
+                services[host] = data
 
             result = []
             for host, ls in services.items():
@@ -670,7 +743,7 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
                     result.append(sd)
             return result
 
-        return async_map_completion(self._refresh_host_services)(wait_for_args).then(
+        return self._refresh_host_services(wait_for_args).then(
             _get_services_result)
 
 
@@ -683,7 +756,7 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
                                     service_id=service_id,
                                     node_name=node_name,
                                     refresh=refresh)
-        return orchestrator.TrivialReadCompletion(result)
+        return result
 
     def service_action(self, action, service_type,
                        service_name=None,
@@ -691,8 +764,7 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         self.log.debug('service_action action %s type %s name %s id %s' % (
             action, service_type, service_name, service_id))
         if action == 'reload':
-            return orchestrator.TrivialReadCompletion(
-                ["Reload is a no-op"])
+            return trivial_result(["Reload is a no-op"])
         daemons = self._get_services(
             service_type,
             service_name=service_name,
@@ -702,13 +774,16 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             args.append((d.service_type, d.service_instance,
                                        d.nodename, action))
         if not args:
-            n = service_name
-            if n:
-                n += '-*'
-            raise orchestrator.OrchestratorError('Unable to find %s.%s daemon(s)' % (
-                service_type, n))
-        return async_map_completion(self._service_action)(args)
+            if service_name:
+                n = service_name + '-*'
+            else:
+                n = service_id
+            raise orchestrator.OrchestratorError(
+                'Unable to find %s.%s daemon(s)' % (
+                    service_type, n))
+        return self._service_action(args)
 
+    @async_map_completion
     def _service_action(self, service_type, service_id, host, action):
         if action == 'redeploy':
             # recreate the systemd unit and then restart
@@ -758,9 +833,9 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             # this implies the returned hosts are registered
             hosts = self._get_hosts()
 
-        def run(host_info):
-            # type: (orchestrator.OutdatableData) -> orchestrator.InventoryNode
-            host = host_info.data['name']
+        @async_map_completion
+        def _get_inventory(host, host_info):
+            # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
 
             if host_info.outdated(self.inventory_cache_timeout) or refresh:
                 self.log.info("refresh stale inventory for '{}'".format(host))
@@ -777,17 +852,16 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             devices = inventory.Devices.from_json(host_info.data)
             return orchestrator.InventoryNode(host, devices)
 
-        return async_map_completion(run)(hosts.values())
+        return _get_inventory(hosts)
 
     def blink_device_light(self, ident_fault, on, locs):
-        # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> SSHWriteCompletion
-        def blink(host, dev, ident_fault_, on_):
-            # type: (str, str, str, bool) -> str
+        @async_map_completion
+        def blink(host, dev):
             cmd = [
                 'lsmcli',
                 'local-disk-%s-led-%s' % (
-                    ident_fault_,
-                    'on' if on_ else 'off'),
+                    ident_fault,
+                    'on' if on else 'off'),
                 '--path', '/dev/' + dev,
             ]
             out, code = self._run_ceph_daemon(host, 'osd', 'shell', ['--'] + cmd,
@@ -795,11 +869,11 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             if code:
                 raise RuntimeError(
                     'Unable to affect %s light for %s:%s. Command: %s' % (
-                        ident_fault_, host, dev, ' '.join(cmd)))
+                        ident_fault, host, dev, ' '.join(cmd)))
             return "Set %s light for %s:%s %s" % (
-                ident_fault_, host, dev, 'on' if on_ else 'off')
+                ident_fault, host, dev, 'on' if on else 'off')
 
-        return async_map_completion(blink)(locs)
+        return blink(locs)
 
     @async_completion
     def _create_osd(self, all_hosts_, drive_group):
@@ -891,12 +965,14 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
           - support batch creation
         """
 
-        return self.get_hosts().then(self._create_osd)
+        return self.get_hosts().then(lambda hosts: self._create_osd(hosts, drive_group))
 
     def remove_osds(self, name):
         daemons = self._get_services('osd', service_id=name)
         args = [('osd.%s' % d.service_instance, d.nodename) for d in daemons]
-        return async_map_completion(self._remove_daemon)(args)
+        if not args:
+            raise orchestrator.OrchestratorError('Unable to find osd.%s' % name)
+        return self._remove_daemon(args)
 
     def _create_daemon(self, daemon_type, daemon_id, host, keyring,
                        extra_args=[]):
@@ -941,6 +1017,7 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             self.log.info("create_daemon({}): finished".format(host))
             conn.exit()
 
+    @async_map_completion
     def _remove_daemon(self, name, host):
         """
         Remove a daemon
@@ -953,22 +1030,24 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         return "Removed {} from host '{}'".format(name, host)
 
     def _update_service(self, daemon_type, add_func, spec):
-        daemons = self._get_services(daemon_type, service_name=spec.name)
-        results = []
-        if len(daemons) > spec.count:
-            # remove some
-            to_remove = len(daemons) - spec.count
-            for d in daemons[0:to_remove]:
-                results.append(self._worker_pool.apply_async(
-                    self._remove_daemon,
-                    ('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename)))
-        elif len(daemons) < spec.count:
-            # add some
-            spec.count -= len(daemons)
-            return add_func(spec)
-        return SSHWriteCompletion(results)
-
+        def ___update_service(daemons):
+            if len(daemons) > spec.count:
+                # remove some
+                to_remove = len(daemons) - spec.count
+                args = []
+                for d in daemons[0:to_remove]:
+                    args.append(
+                        ('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+                    )
+                return self._remove_daemon(args)
+            elif len(daemons) < spec.count:
+                # add some
+                spec.count -= len(daemons)
+                return add_func(spec)
+            return []
+        return self._get_services(daemon_type, service_name=spec.name).then(___update_service)
+
+    @async_map_completion
     def _create_mon(self, host, network, name):
         """
         Create a new monitor on the given host.
@@ -1031,8 +1110,9 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
 
         # TODO: we may want to chain the creation of the monitors so they join
         # the quorum one at a time.
-        return async_map_completion(self._create_mon)(host_specs)
+        return self._create_mon(host_specs)
 
+    @async_map_completion
     def _create_mgr(self, host, name):
         """
         Create a new manager instance on a host.
@@ -1054,7 +1134,9 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         """
         Adjust the number of cluster managers.
         """
-        daemons = self._get_services('mgr')
+        return self._get_services('mgr').then(lambda daemons: self._update_mgrs(num, host_specs, daemons))
+
+    def _update_mgrs(self, num, host_specs, daemons):
         num_mgrs = len(daemons)
         if num == num_mgrs:
             return orchestrator.Completion(value="The requested number of managers exist.")
@@ -1076,7 +1158,6 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             for standby in mgr_map.get('standbys', []):
                 connected.append(standby.get('name', ''))
             to_remove_damons = []
-            to_remove_mgr = []
             for d in daemons:
                 if d.service_instance not in connected:
                     to_remove_damons.append(('%s.%s' % (d.service_type, d.service_instance),
@@ -1087,16 +1168,12 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
 
             # otherwise, remove *any* mgr
             if num_to_remove > 0:
-                for daemon in daemons:
-                    to_remove_mgr.append((('%s.%s' % (d.service_type, d.service_instance), daemon.nodename))
+                for d in daemons:
+                    to_remove_damons.append(('%s.%s' % (d.service_type, d.service_instance), d.nodename))
                     num_to_remove -= 1
                     if num_to_remove == 0:
                         break
-            return async_map_completion(self._remove_daemon)(to_remove_damons).then(
-                lambda remove_daemon_result: async_map_completion(self._remove_daemon)(to_remove_mgr).then(
-                    lambda remove_mgr_result: remove_daemon_result + remove_mgr_result
-                )
-            )
+            return self._remove_daemon(to_remove_damons)
 
         else:
             # we assume explicit placement by which there are the same number of
@@ -1118,26 +1195,19 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             self.log.info("creating {} managers on hosts: '{}'".format(
                 num_new_mgrs, ",".join([spec.hostname for spec in host_specs])))
 
-            for host_spec in host_specs:
-                name = host_spec.name or self.get_unique_name(daemons)
-                result = self._worker_pool.apply_async(self._create_mgr,
-                                                       (host_spec.hostname, name))
-                results.append(result)
-                               
             args = []
             for host_spec in host_specs:
-                           name = host_spec.name or self.get_unique_name(daemons)
-                               host = host_spec.hostname
+                name = host_spec.name or self.get_unique_name(daemons)
+                host = host_spec.hostname
                 args.append((host, name))
-        return async_map_completion(self._create_mgr)(args)
-                       
-
-        return SSHWriteCompletion(results)
+        return self._create_mgr(args)
 
     def add_mds(self, spec):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
             raise RuntimeError("must specify at least %d hosts" % spec.count)
-        daemons = self._get_services('mds')
+        return self._get_services('mds').then(lambda ds: self._add_mds(ds, spec))
+
+    def _add_mds(self, daemons, spec):
         args = []
         num_added = 0
         for host, _, name in spec.placement.nodes:
@@ -1153,11 +1223,12 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             sd.nodename = host
             daemons.append(sd)
             num_added += 1
-        return async_map_completion(self._create_mds)(args)
+        return self._create_mds(args)
 
     def update_mds(self, spec):
         return self._update_service('mds', self.add_mds, spec)
 
+    @async_map_completion
     def _create_mds(self, mds_id, host):
         # get mgr. key
         ret, keyring, err = self.mon_command({
@@ -1170,13 +1241,18 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         return self._create_daemon('mds', mds_id, host, keyring)
 
     def remove_mds(self, name):
-        daemons = self._get_services('mds')
         self.log.debug("Attempting to remove volume: {}".format(name))
-        if daemons:
-            args = [('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename) for d in daemons]
-            return async_map_completion(self._remove_daemon)(args)
-        raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
+        def _remove_mds(daemons):
+            args = []
+            for d in daemons:
+                if d.service_instance == name or d.service_instance.startswith(name + '.'):
+                    args.append(
+                        ('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+                    )
+            if not args:
+                raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
+            return self._remove_daemon(args)
+        return self._get_services('mds').then(_remove_mds)
 
     def add_rgw(self, spec):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
@@ -1188,24 +1264,28 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
             'name': 'rgw_zone',
             'value': spec.name,
         })
-        daemons = self._get_services('rgw')
-        args = []
-        num_added = 0
-        for host, _, name in spec.placement.nodes:
-            if num_added >= spec.count:
-                break
-            rgw_id = self.get_unique_name(daemons, spec.name, name)
-            self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
-            args.append((rgw_id, host))
-            # add to daemon list so next name(s) will also be unique
-            sd = orchestrator.ServiceDescription()
-            sd.service_instance = rgw_id
-            sd.service_type = 'rgw'
-            sd.nodename = host
-            daemons.append(sd)
-            num_added += 1
-        return async_map_completion(self._create_rgw)(args)
 
+        def _add_rgw(daemons):
+            args = []
+            num_added = 0
+            for host, _, name in spec.placement.nodes:
+                if num_added >= spec.count:
+                    break
+                rgw_id = self.get_unique_name(daemons, spec.name, name)
+                self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
+                args.append((rgw_id, host))
+                # add to daemon list so next name(s) will also be unique
+                sd = orchestrator.ServiceDescription()
+                sd.service_instance = rgw_id
+                sd.service_type = 'rgw'
+                sd.nodename = host
+                daemons.append(sd)
+                num_added += 1
+            return self._create_rgw(args)
+
+        return self._get_services('rgw').then(_add_rgw)
+
+    @async_map_completion
     def _create_rgw(self, rgw_id, host):
         ret, keyring, err = self.mon_command({
             'prefix': 'auth get-or-create',
@@ -1217,57 +1297,50 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         return self._create_daemon('rgw', rgw_id, host, keyring)
 
     def remove_rgw(self, name):
-        daemons = self._get_services('rgw')
 
-        args = []
-        for d in daemons:
-            if d.service_instance == name or d.service_instance.startswith(name + '.'):
-                args.append(('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename))
-        if args:
-            return async_map_completion(self._remove_daemon)(args)
-        raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
+        def _remove_rgw(daemons):
+            args = []
+            for d in daemons:
+                if d.service_instance == name or d.service_instance.startswith(name + '.'):
+                    args.append(('%s.%s' % (d.service_type, d.service_instance),
+                         d.nodename))
+            if args:
+                return self._remove_daemon(args)
+            raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
+
+        return self._get_services('rgw').then(_remove_rgw)
 
     def update_rgw(self, spec):
-        daemons = self._get_services('rgw', service_name=spec.name)
-        if len(daemons) > spec.count:
-            # remove some
-            to_remove = len(daemons) - spec.count
-            args = []
-            for d in daemons[0:to_remove]:
-                args.append((d.service_instance, d.nodename))
-            return async_map_completion(self._remove_rgw)(args)
-        elif len(daemons) < spec.count:
-            # add some
-            spec.count -= len(daemons)
-            return self.add_rgw(spec)
+        return self._update_service('rgw', self.add_rgw, spec)
 
     def add_rbd_mirror(self, spec):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
             raise RuntimeError("must specify at least %d hosts" % spec.count)
         self.log.debug('nodes %s' % spec.placement.nodes)
-        daemons = self._get_services('rbd-mirror')
-        results = []
-        num_added = 0
-        for host, _, name in spec.placement.nodes:
-            if num_added >= spec.count:
-                break
-            daemon_id = self.get_unique_name(daemons, None, name)
-            self.log.debug('placing rbd-mirror.%s on host %s' % (daemon_id,
-                                                                 host))
-            results.append(
-                self._worker_pool.apply_async(self._create_rbd_mirror,
-                                              (daemon_id, host))
-            )
-            # add to daemon list so next name(s) will also be unique
-            sd = orchestrator.ServiceDescription()
-            sd.service_instance = daemon_id
-            sd.service_type = 'rbd-mirror'
-            sd.nodename = host
-            daemons.append(sd)
-            num_added += 1
-        return SSHWriteCompletion(results)
 
+        def _add_rbd_mirror(daemons):
+            args = []
+            num_added = 0
+            for host, _, name in spec.placement.nodes:
+                if num_added >= spec.count:
+                    break
+                daemon_id = self.get_unique_name(daemons, None, name)
+                self.log.debug('placing rbd-mirror.%s on host %s' % (daemon_id,
+                                                                     host))
+                args.append((daemon_id, host))
+
+                # add to daemon list so next name(s) will also be unique
+                sd = orchestrator.ServiceDescription()
+                sd.service_instance = daemon_id
+                sd.service_type = 'rbd-mirror'
+                sd.nodename = host
+                daemons.append(sd)
+                num_added += 1
+            return self._create_rbd_mirror(args)
+
+        return self._get_services('rbd-mirror').then(_add_rbd_mirror)
+
+    @async_map_completion
     def _create_rbd_mirror(self, daemon_id, host):
         ret, keyring, err = self.mon_command({
             'prefix': 'auth get-or-create',
@@ -1278,17 +1351,19 @@ class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
         return self._create_daemon('rbd-mirror', daemon_id, host, keyring)
 
     def remove_rbd_mirror(self, name):
-        daemons = self._get_services('rbd-mirror')
-        results = []
-        for d in daemons:
-            if not name or d.service_instance == name:
-                results.append(self._worker_pool.apply_async(
-                    self._remove_daemon,
-                    ('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename)))
-        if not results and name:
-            raise RuntimeError('Unable to find rbd-mirror.%s daemon' % name)
-        return SSHWriteCompletion(results)
+        def _remove_rbd_mirror(daemons):
+            args = []
+            for d in daemons:
+                if not name or d.service_instance == name:
+                    args.append(
+                        ('%s.%s' % (d.service_type, d.service_instance),
+                         d.nodename)
+                    )
+            if not args and name:
+                raise RuntimeError('Unable to find rbd-mirror.%s daemon' % name)
+            return self._remove_daemon(args)
+
+        return self._get_services('rbd-mirror').then(_remove_rbd_mirror)
 
     def update_rbd_mirror(self, spec):
         return self._update_service('rbd-mirror', self.add_rbd_mirror, spec)
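
The recurring pattern in the module.py hunks above: methods no longer call self._get_services() eagerly and assemble SSHWriteCompletion(results) by hand. Instead they return a single completion chain in which a nested helper receives the daemon list via .then() and delegates the real work to an @async_map_completion method. Schematically (a sketch with a made-up 'foo' service type):

    def remove_foo(self, name):
        def _remove_foo(daemons):
            args = [('%s.%s' % (d.service_type, d.service_instance), d.nodename)
                    for d in daemons if d.service_instance == name]
            if not args:
                raise orchestrator.OrchestratorError('Unable to find foo.%s' % name)
            return self._remove_daemon(args)  # an @async_map_completion method
        return self._get_services('foo').then(_remove_foo)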
diff --git a/src/pybind/mgr/ssh/tests/fixtures.py b/src/pybind/mgr/ssh/tests/fixtures.py
new file mode 100644
index 0000000..12caf4d
--- /dev/null
+++ b/src/pybind/mgr/ssh/tests/fixtures.py
@@ -0,0 +1,40 @@
+from contextlib import contextmanager
+
+import pytest
+
+from ssh import SSHOrchestrator
+from tests import mock
+
+
+def set_store(self, k, v):
+    if v is None:
+        del self._store[k]
+    else:
+        self._store[k] = v
+
+
+def get_store(self, k):
+    return self._store[k]
+
+
+def get_store_prefix(self, prefix):
+    return {
+        k: v for k, v in self._store.items()
+        if k.startswith(prefix)
+    }
+
+
+@pytest.yield_fixture()
+def ssh_module():
+    with mock.patch("ssh.module.SSHOrchestrator.get_ceph_option", lambda _, key: __file__),\
+            mock.patch("ssh.module.SSHOrchestrator.set_store", set_store),\
+            mock.patch("ssh.module.SSHOrchestrator.get_store", get_store),\
+            mock.patch("ssh.module.SSHOrchestrator.get_store_prefix", get_store_prefix):
+        m = SSHOrchestrator.__new__(SSHOrchestrator)
+        m._store = {
+            'ssh_config': '',
+            'ssh_identity_key': '',
+            'ssh_identity_pub': '',
+        }
+        m.__init__('ssh', 0, 0)
+        yield m
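
The fixture swaps the mgr key/value store accessors for a plain dict, so tests can drive a real SSHOrchestrator instance without a ceph-mgr behind it. A hypothetical test using it (the key names are illustrative only):

    def test_store_roundtrip(ssh_module):
        ssh_module.set_store('host.node1', '{}')
        assert ssh_module.get_store('host.node1') == '{}'
        assert ssh_module.get_store_prefix('host.') == {'host.node1': '{}'}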
diff --git a/src/pybind/mgr/ssh/tests/test_completion.py b/src/pybind/mgr/ssh/tests/test_completion.py
new file mode 100644
index 0000000..f076de8
--- /dev/null
+++ b/src/pybind/mgr/ssh/tests/test_completion.py
@@ -0,0 +1,171 @@
+import sys
+import time
+
+
+try:
+    from typing import Any
+except ImportError:
+    pass
+
+import pytest
+
+
+from orchestrator import raise_if_exception, Completion
+from .fixtures import ssh_module
+from ..module import trivial_completion, async_completion, async_map_completion, SSHOrchestrator
+
+
+class TestCompletion(object):
+    def _wait(self, m, c):
+        # type: (SSHOrchestrator, Completion) -> Any
+        m.process([c])
+        m.process([c])
+
+        for _ in range(30):
+            if c.is_finished:
+                raise_if_exception(c)
+                return c.result
+            time.sleep(0.1)
+        assert False, "timeout" + str(c._state)
+
+    def test_trivial(self, ssh_module):
+        @trivial_completion
+        def run(x):
+            return x+1
+        assert self._wait(ssh_module, run(1)) == 2
+
+    @pytest.mark.parametrize("input", [
+        ((1, ), ),
+        ((1, 2), ),
+        (("hallo", ), ),
+        (("hallo", "foo"), ),
+    ])
+    def test_async(self, input, ssh_module):
+        @async_completion
+        def run(*args):
+            return str(args)
+
+        assert self._wait(ssh_module, run(*input)) == str(input)
+
+    @pytest.mark.parametrize("input,expected", [
+        ([], []),
+        ([1], ["(1,)"]),
+        (["hallo"], ["('hallo',)"]),
+        ("hi", ["('h',)", "('i',)"]),
+        (list(range(5)), [str((x, )) for x in range(5)]),
+        ([(1, 2), (3, 4)], ["(1, 2)", "(3, 4)"]),
+    ])
+    def test_async_map(self, input, expected, ssh_module):
+        @async_map_completion
+        def run(*args):
+            return str(args)
+
+        c = run(input)
+        self._wait(ssh_module, c)
+        assert c.result == expected
+
+    def test_async_self(self, ssh_module):
+        class Run(object):
+            def __init__(self):
+                self.attr = 1
+
+            @async_completion
+            def run(self, x):
+                assert self.attr == 1
+                return x + 1
+
+        assert self._wait(ssh_module, Run().run(1)) == 2
+
+    @pytest.mark.parametrize("input,expected", [
+        ([], []),
+        ([1], ["(1,)"]),
+        (["hallo"], ["('hallo',)"]),
+        ("hi", ["('h',)", "('i',)"]),
+        (list(range(5)), [str((x, )) for x in range(5)]),
+        ([(1, 2), (3, 4)], ["(1, 2)", "(3, 4)"]),
+    ])
+    def test_async_map_self(self, input, expected, ssh_module):
+        class Run(object):
+            def __init__(self):
+                self.attr = 1
+
+            @async_map_completion
+            def run(self, *args):
+                assert self.attr == 1
+                return str(args)
+
+        c = Run().run(input)
+        self._wait(ssh_module, c)
+        assert c.result == expected
+
+    def test_then1(self, ssh_module):
+        @async_map_completion
+        def run(x):
+            return x+1
+
+        assert self._wait(ssh_module, run([1,2]).then(str)) == '[2, 3]'
+
+    def test_then2(self, ssh_module):
+        @async_map_completion
+        def run(x):
+            time.sleep(0.1)
+            return x+1
+
+        @async_completion
+        def async_str(results):
+            return str(results)
+
+        c = run([1,2]).then(async_str)
+
+        self._wait(ssh_module, c)
+        assert c.result == '[2, 3]'
+
+    def test_then3(self, ssh_module):
+        @async_map_completion
+        def run(x):
+            time.sleep(0.1)
+            return x+1
+
+        def async_str(results):
+            return async_completion(str)(results)
+
+        c = run([1,2]).then(async_str)
+
+        self._wait(ssh_module, c)
+        assert c.result == '[2, 3]'
+
+    def test_then4(self, ssh_module):
+        @async_map_completion
+        def run(x):
+            time.sleep(0.1)
+            return x+1
+
+        def async_str(results):
+            return async_completion(str)(results).then(lambda x: x + "hello")
+
+        c = run([1,2]).then(async_str)
+
+        self._wait(ssh_module, c)
+        assert c.result == '[2, 3]hello'
+
+    @pytest.mark.skip(reason="see limitation of async_map_completion")
+    def test_then5(self, ssh_module):
+        @async_map_completion
+        def run(x):
+            time.sleep(0.1)
+            return async_completion(str)(x+1)
+
+        c = run([1,2])
+
+        self._wait(ssh_module, c)
+        assert c.result == "['2', '3']"
+
+    @pytest.mark.skipif(sys.version_info < (3,0), reason="requires python3")
+    def test_raise(self, ssh_module):
+        @async_completion
+        def run(x):
+            raise ZeroDivisionError()
+
+        with pytest.raises(ZeroDivisionError):
+            self._wait(ssh_module, run(1))
+
diff --git a/src/pybind/mgr/ssh/tests/test_ssh.py b/src/pybind/mgr/ssh/tests/test_ssh.py
index 639f08c92434dc8ccef41fd7428ae9e18c731724..d1e8803048c92ecda2f75b809ae2e33ff36bfd42 100644
-from orchestrator import ServiceDescription
+import json
+import time
+from contextlib import contextmanager
+
+from ceph.deployment.drive_group import DriveGroupSpec, DeviceSelection
+
+try:
+    from typing import Any
+except ImportError:
+    pass
+
+from orchestrator import ServiceDescription, raise_if_exception, Completion, InventoryNode, \
+    StatelessServiceSpec, PlacementSpec, RGWSpec, parse_host_specs
 from ..module import SSHOrchestrator
 from tests import mock
+from .fixtures import ssh_module
+
+
+"""
+TODOs:
+    There is real room for improvement here. I just quickly assembled these tests.
+    In general, everything should be tested in Teuthology as well. The reason for
+    also testing here is the shorter development round-trip time.
+"""
+
+
+
+def _run_ceph_daemon(ret):
+    def foo(*args, **kwargs):
+        return ret, 0
+    return foo
+
+def mon_command(*args, **kwargs):
+    return 0, '', ''
+
+
+class TestSSH(object):
+    def _wait(self, m, c):
+        # type: (SSHOrchestrator, Completion) -> Any
+        m.process([c])
+        m.process([c])
+
+        for _ in range(30):
+            if c.is_finished:
+                raise_if_exception(c)
+                return c.result
+            time.sleep(0.1)
+        assert False, "timeout" + str(c._state)
+
+    @contextmanager
+    def _with_host(self, m, name):
+        self._wait(m, m.add_host(name))
+        yield
+        self._wait(m, m.remove_host(name))
+
+    def test_get_unique_name(self, ssh_module):
+        existing = [
+            ServiceDescription(service_instance='mon.a')
+        ]
+        new_mon = ssh_module.get_unique_name(existing, 'mon')
+        assert new_mon.startswith('mon.')
+        assert new_mon != 'mon.a'
+
+    def test_host(self, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            assert self._wait(ssh_module, ssh_module.get_hosts()) == [InventoryNode('test')]
+        c = ssh_module.get_hosts()
+        assert self._wait(ssh_module, c) == []
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+    def test_service_ls(self, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.describe_service()
+            assert self._wait(ssh_module, c) == []
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+    def test_device_ls(self, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.get_inventory()
+            assert self._wait(ssh_module, c) == [InventoryNode('test')]
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_mon_update(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.update_mons(1, [parse_host_specs('test:0.0.0.0')])
+            assert self._wait(ssh_module, c) == ["(Re)deployed mon.test on host 'test'"]
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_mgr_update(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.update_mgrs(1, [parse_host_specs('test:0.0.0.0')])
+            [out] = self._wait(ssh_module, c)
+            assert "(Re)deployed mgr." in out
+            assert " on host 'test'" in out
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_create_osds(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            dg = DriveGroupSpec('test', DeviceSelection(paths=['']))
+            c = ssh_module.create_osds(dg)
+            assert self._wait(ssh_module, c) == "Created osd(s) on host 'test'"
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_mds(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            ps = PlacementSpec(nodes=['test'])
+            c = ssh_module.add_mds(StatelessServiceSpec('name', ps))
+            [out] = self._wait(ssh_module, c)
+            assert "(Re)deployed mds.name." in out
+            assert " on host 'test'" in out
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_rgw(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            ps = PlacementSpec(nodes=['test'])
+            c = ssh_module.add_rgw(RGWSpec('name', ps))
+            [out] = self._wait(ssh_module, c)
+            assert "(Re)deployed rgw.name." in out
+            assert " on host 'test'" in out
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon(
+        json.dumps([
+            dict(
+                name='rgw.myrgw.foobar',
+                style='ceph-daemon',
+                fsid='fsid',
+                container_id='container_id',
+                version='version',
+                state='running',
+            )
+        ])
+    ))
+    def test_remove_rgw(self, ssh_module):
+        ssh_module._cluster_fsid = "fsid"
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.remove_rgw('myrgw')
+            out = self._wait(ssh_module, c)
+            assert out == ["Removed rgw.myrgw.foobar from host 'test'"]
+
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_rbd_mirror(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            ps = PlacementSpec(nodes=['test'])
+            c = ssh_module.add_rbd_mirror(StatelessServiceSpec('name', ps))
+            [out] = self._wait(ssh_module, c)
+            assert "(Re)deployed rbd-mirror." in out
+            assert " on host 'test'" in out
 
+    @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+    @mock.patch("ssh.module.SSHOrchestrator.send_command")
+    @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+    @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+    def test_blink_device_light(self, _send_command, _get_connection, ssh_module):
+        with self._with_host(ssh_module, 'test'):
+            c = ssh_module.blink_device_light('ident', True, [('test', '')])
+            assert self._wait(ssh_module, c) == ['Set ident light for test: on']
 
-@mock.patch("ssh.module.SSHOrchestrator.get_ceph_option", lambda _,key: __file__)
-def test_get_unique_name():
-    o = SSHOrchestrator('module_name', 0, 0)
-    existing = [
-        ServiceDescription(service_instance='mon.a')
-    ]
-    new_mon = o.get_unique_name(existing, 'mon')
-    assert new_mon.startswith('mon.')
-    assert new_mon != 'mon.a'
 
diff --git a/src/pybind/mgr/test_orchestrator/module.py b/src/pybind/mgr/test_orchestrator/module.py
index e742969119b6514f7a2bf11a4f514cdcb56fd004..6067276a2be0b7f01210bb0a9b141ed653fbb93e 100644
@@ -21,7 +21,7 @@ import orchestrator
 
 class TestCompletion(orchestrator.Completion):
     def evaluate(self):
-        self._first_promise.finalize(None)
+        self.finalize(None)
 
 
 def deferred_read(f):
diff --git a/src/pybind/mgr/tests/__init__.py b/src/pybind/mgr/tests/__init__.py
index 666875d0aecebfa6fa6baa89d92c289ff5246d6a..af260d9853a7117db8a7dbe59381ad88943109a1 100644
@@ -22,7 +22,6 @@ if 'UNITTEST' in os.environ:
             self._ceph_get = mock.MagicMock()
             self._ceph_get_module_option = mock.MagicMock()
             self._ceph_log = mock.MagicMock()
-            self._ceph_get_option = mock.MagicMock()
             self._ceph_get_store = lambda _: ''
             self._ceph_get_store_prefix = lambda _: {}
 
diff --git a/src/pybind/mgr/tests/test_orchestrator.py b/src/pybind/mgr/tests/test_orchestrator.py
index d24d1a3d96e89a171e2477726f5e658e4cc27b07..a0095db172402b7d6613a51e14173d845fb41937 100644
@@ -1,11 +1,7 @@
 from __future__ import absolute_import
 import json
 
-try:
-    from unittest.mock import MagicMock
-except ImportError:
-    # py2
-    from mock import MagicMock
+from tests import mock
 
 import pytest
 
@@ -123,14 +119,14 @@ def test_promise():
 
 def test_promise_then():
     p = Completion(value=3).then(lambda three: three + 1)
-    p._first_promise.finalize()
+    p.finalize()
     assert p.result == 4
 
 
 def test_promise_mondatic_then():
     p = Completion(value=3)
     p.then(lambda three: Completion(value=three + 1))
-    p._first_promise.finalize()
+    p.finalize()
     assert p.result == 4
 
 
@@ -138,11 +134,11 @@ def some_complex_completion():
     c = Completion(value=3).then(
         lambda three: Completion(value=three + 1).then(
             lambda four: four + 1))
-    return c._first_promise
+    return c
 
 def test_promise_mondatic_then_combined():
     p = some_complex_completion()
-    p._first_promise.finalize()
+    p.finalize()
     assert p.result == 5
 
 
@@ -161,13 +157,13 @@ def test_side_effect():
         foo['x'] = x
 
     foo['x'] = 1
-    Completion(value=3).then(run)._first_promise.finalize()
+    Completion(value=3).then(run).finalize()
     assert foo['x'] == 3
 
 
 def test_progress():
     c = some_complex_completion()
-    mgr = MagicMock()
+    mgr = mock.MagicMock()
     mgr.process = lambda cs: [c.finalize(None) for c in cs]
 
     progress_val = 0.75
@@ -193,7 +189,7 @@ def test_progress():
 
 
 def test_with_progress():
-    mgr = MagicMock()
+    mgr = mock.MagicMock()
     mgr.process = lambda cs: [c.finalize(None) for c in cs]
 
     def execute(y):
@@ -221,10 +217,10 @@ def test_exception():
     def run(x):
         raise KeyError(x)
 
-    c = Completion(value=3).then(run)._first_promise
+    c = Completion(value=3).then(run)
     c.finalize()
-
-    assert isinstance(c.exception, KeyError)
+    with pytest.raises(KeyError):
+        raise_if_exception(c)
 
 
 def test_fail():
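
These test changes exercise the new public Completion.finalize() added in orchestrator.py above: .then() returns the tail of the chain, and finalize() now walks back to _first_promise before evaluating, so callers no longer reach into the private attribute. Roughly:

    from orchestrator import Completion

    p = Completion(value=3).then(lambda three: three + 1)  # p is the tail
    p.finalize()  # internally finalizes p._first_promise
    assert p.result == 4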
diff --git a/src/pybind/mgr/tox.ini b/src/pybind/mgr/tox.ini
index 7a516f8cae28022eb1b07a51131c3687b855339a..c97800f33b2d783288d4ed99a3b5f643cfdc0ded 100644
@@ -5,4 +5,4 @@ skipsdist = true
 [testenv]
 setenv = UNITTEST = true
 deps = -rrequirements.txt
-commands = pytest --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
\ No newline at end of file
+commands = pytest -v --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
\ No newline at end of file