git.apps.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
mgr/ssh: Adapt ssh orch to new Completions interface
author     Sebastian Wagner <sebastian.wagner@suse.com>
           Fri, 30 Aug 2019 16:08:46 +0000 (18:08 +0200)
committer  Sebastian Wagner <sebastian.wagner@suse.com>
           Wed, 27 Nov 2019 12:38:20 +0000 (13:38 +0100)
Signed-off-by: Sebastian Wagner <sebastian.wagner@suse.com>
src/pybind/mgr/orchestrator.py
src/pybind/mgr/ssh/module.py

src/pybind/mgr/orchestrator.py
index 1ec3959c11cb0551cca0b70f969aa3092c745c82..a88bf7f09fe9d28e149fb53bbba33dc7b4991158 100644
@@ -894,7 +894,7 @@ class Orchestrator(object):
         raise NotImplementedError()
 
     def update_rgw(self, spec):
-        # type: (StatelessServiceSpec) -> Completion
+        # type: (RGWSpec) -> Completion
         """
         Update / redeploy existing RGW zone
         Like for example changing the number of service instances.
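Note: the `# type:` comment is the Python-2-compatible annotation form checked by mypy. For illustration, the corrected comment is equivalent to a Python 3 annotation (a sketch; method bodies elided):

    def update_rgw(self, spec):
        # type: (RGWSpec) -> Completion
        ...

    # mypy treats the above the same as:
    def update_rgw(self, spec: 'RGWSpec') -> 'Completion':
        ...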
src/pybind/mgr/ssh/module.py
index dbd33e7d6333b01006da59ea715a8334ebf2d586..af42e778ad38de59b8275755f3b462c609802665 100644
@@ -4,6 +4,13 @@ import logging
 from functools import wraps
 
 import string
+try:
+    from typing import List, Dict, Optional, Callable, TypeVar, Type, Any
+except ImportError:
+    pass  # just for type checking
+
+T = TypeVar('T')
+
 import six
 import os
 import random
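The guarded `typing` import keeps the module importable on interpreters that lack the backport; the names are only needed by the type checker. Where `typing` is known to be available, the same intent is commonly written with `TYPE_CHECKING` (illustrative import path):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # resolved by mypy only; never executed at runtime
        from ssh.module import SSHOrchestrator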
@@ -64,87 +71,86 @@ except ImportError:
 #    multiple bootstrapping / initialization
 
 
-class SSHCompletionmMixin(object):
-    def __init__(self, result):
-        if isinstance(result, multiprocessing.pool.AsyncResult):
-            self._result = [result]
-        else:
-            self._result = result
-        assert isinstance(self._result, list)
-
-    @property
-    def result(self):
-        return list(map(lambda r: r.get(), self._result))
+class AsyncCompletion(orchestrator.Completion[T]):
+    def __init__(self, *args, many=False, **kwargs):
+        self.__on_complete = None  # type: Optional[Callable[[T], Any]]
+        self.many = many
+        super(AsyncCompletion, self).__init__(*args, **kwargs)
 
+    def propagate_to_next(self):
+        # We don't have a synchronous result.
+        pass
 
-class SSHReadCompletion(SSHCompletionmMixin, orchestrator.ReadCompletion):
     @property
-    def has_result(self):
-        return all(map(lambda r: r.ready(), self._result))
+    def _progress_reference(self):
+        if hasattr(self.__on_complete, 'progress_id'):
+            return self.__on_complete
+        return None
 
+    @property
+    def _on_complete(self):
+        # type: () -> Optional[Callable[[T], Any]]
+        if self.__on_complete is None:
+            return None
+
+        def callback(result):
+            if self._next_promise:
+                self._next_promise._value = result
+            else:
+                self._value = result
+            super(AsyncCompletion, self).propagate_to_next()
 
-class SSHWriteCompletion(SSHCompletionmMixin, orchestrator.WriteCompletion):
+        def run(value):
+            if self.many:
+                SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
+                                                                callback=callback)
+            else:
+                SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+                                                                  callback=callback)
 
-    @property
-    def has_result(self):
-        return all(map(lambda r: r.ready(), self._result))
+        return run
 
-    @property
-    def is_effective(self):
-        return all(map(lambda r: r.ready(), self._result))
+    @_on_complete.setter
+    def _on_complete(self, inner):
+        # type: (Callable[[T], Any]) -> None
+        self.__on_complete = inner
 
-    @property
-    def is_errored(self):
-        for r in self._result:
-            if not r.ready():
-                return False
-            if not r.successful():
-                return True
-        return False
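For orientation: `AsyncCompletion._on_complete` below defers the wrapped callable to the worker pool and resumes promise propagation from the pool callback. A minimal sketch of that mechanism, with a `ThreadPool` standing in for `SSHOrchestrator.instance._worker_pool`:

    from multiprocessing.pool import ThreadPool

    pool = ThreadPool(2)  # stand-in for the module's worker pool

    def work(value):
        return value * 2

    def callback(result):
        # AsyncCompletion stores `result` on the next promise here,
        # then resumes propagation along the chain
        print('completed with', result)

    pool.apply_async(work, (21,), callback=callback)  # prints 'completed with 42'
    pool.close()
    pool.join()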
 
+def ssh_completion(cls=AsyncCompletion, **c_kwargs):
+    # type: (Type[orchestrator.Completion], Any) -> Callable
+    """
+    Run the given function through `apply_async()` or `map_async()`.
+    """
+    def decorator(f):
+        @wraps(f)
+        def wrapper(*args, **kwargs):
+            return cls(on_complete=lambda _: f(*args, **kwargs), **c_kwargs)
 
-class SSHWriteCompletionReady(SSHWriteCompletion):
-    def __init__(self, result):
-        orchestrator.WriteCompletion.__init__(self)
-        self._result = result
+        return wrapper
+    return decorator
 
-    @property
-    def result(self):
-        return self._result
 
-    @property
-    def has_result(self):
-        return True
+def async_completion(f):
+    # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]]
+    return ssh_completion()(f)
 
-    @property
-    def is_effective(self):
-        return True
 
-    @property
-    def is_errored(self):
-        return False
+def async_map_completion(f):
+    # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]]
+    return ssh_completion(many=True)(f)
 
 
-def log_exceptions(f):
-    if six.PY3:
-        return f
-    else:
-        # Python 2 does no exception chaining, thus the
-        # real exception is lost
-        @wraps(f)
-        def wrapper(*args, **kwargs):
-            try:
-                return f(*args, **kwargs)
-            except Exception:
-                logger.exception('something went wrong.')
-                raise
-        return wrapper
+def trivial_completion(f):
+    # type: (Callable[..., T]) -> Callable[..., orchestrator.Completion[T]]
+    return ssh_completion(cls=orchestrator.Completion)(f)
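A sketch of what the three decorators yield (toy functions; assumes the mgr module has been constructed, since `AsyncCompletion` reaches the pool through `SSHOrchestrator.instance`):

    @async_completion
    def add_one(x):
        return x + 1        # one apply_async() call when driven

    @async_map_completion
    def square(x):
        return x * x        # one map_async() call over all inputs

    @trivial_completion
    def identity(x):
        return x            # plain Completion, no worker pool

    c = add_one(41)         # AsyncCompletion; work is deferred until driven
    m = square([1, 2, 3])   # AsyncCompletion mapping over three inputs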
 
 
-class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
+class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
 
     _STORE_HOST_PREFIX = "host"
 
+    instance = None  # type: Optional["SSHOrchestrator"]
     NATIVE_OPTIONS = []
     MODULE_OPTIONS = [
         {
@@ -192,6 +198,9 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
         self._reconfig_ssh()
 
+        SSHOrchestrator.instance = self
+        self.all_progress_references = list()  # type: List[orchestrator.ProgressReference]
+
         # load inventory
         i = self.get_store('inventory')
         if i:
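`SSHOrchestrator.instance` is a module singleton: completions built by the decorators reach the worker pool through this class attribute instead of carrying a reference to the mgr module. The idiom in isolation (a sketch, not the module's code):

    from multiprocessing.pool import ThreadPool

    class Module(object):
        instance = None  # set once, when the module is constructed

        def __init__(self):
            Module.instance = self
            self._worker_pool = ThreadPool(2)

    def submit(fn, *args):
        # helpers reach the singleton without a back-reference
        return Module.instance._worker_pool.apply_async(fn, args)

    Module()
    print(submit(lambda x: x + 1, 41).get())  # 42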
@@ -521,54 +530,45 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
     def _get_hosts(self, wanted=None):
         return self.inventory_cache.items_filtered(wanted)
 
+    @async_completion
     def add_host(self, host):
         """
         Add a host to be managed by the orchestrator.
 
         :param host: host name
         """
-        @log_exceptions
-        def run(host):
-            self.inventory[host] = {}
-            self._save_inventory()
-            self.inventory_cache[host] = orchestrator.OutdatableData()
-            self.service_cache[host] = orchestrator.OutdatableData()
-            return "Added host '{}'".format(host)
-
-        return SSHWriteCompletion(
-            self._worker_pool.apply_async(run, (host,)))
+        self.inventory[host] = {}
+        self._save_inventory()
+        self.inventory_cache[host] = orchestrator.OutdatableData()
+        self.service_cache[host] = orchestrator.OutdatableData()
+        return "Added host '{}'".format(host)
 
+    @async_completion
     def remove_host(self, host):
         """
         Remove a host from orchestrator management.
 
         :param host: host name
         """
-        @log_exceptions
-        def run(host):
-            del self.inventory[host]
-            self._save_inventory()
-            del self.inventory_cache[host]
-            del self.service_cache[host]
-            return "Removed host '{}'".format(host)
-
-        return SSHWriteCompletion(
-            self._worker_pool.apply_async(run, (host,)))
+        del self.inventory[host]
+        self._save_inventory()
+        del self.inventory_cache[host]
+        del self.service_cache[host]
+        return "Removed host '{}'".format(host)
 
+    @trivial_completion
     def get_hosts(self):
         """
         Return a list of hosts managed by the orchestrator.
 
         Notes:
           - skip async: manager reads from cache.
-        """
-        nodes = [
-            orchestrator.InventoryNode(h,
-                                       inventory.Devices([]),
-                                       i.get('labels', []))
-            for h, i in self.inventory.items()]
-        return orchestrator.TrivialReadCompletion(nodes)
 
+        TODO:
+          - InventoryNode probably needs to be able to report labels
+        """
+        nodes = [orchestrator.InventoryNode(host_name, []) for host_name in self.inventory_cache]
+        return nodes
+
+    """
     def add_host_label(self, host, label):
         if host not in self.inventory:
             raise OrchestratorError('host %s does not exist' % host)
@@ -600,6 +600,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
         return SSHWriteCompletion(
             self._worker_pool.apply_async(run, (host, label)))
+"""
 
     def _refresh_host_services(self, host):
         out, code = self._run_ceph_daemon(
@@ -616,61 +617,63 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
                       node_name=None,
                       refresh=False):
         hosts = []
-        wait_for = []
+        wait_for_args = []
+        in_cache = []
         for host, host_info in self.service_cache.items_filtered():
             hosts.append(host)
             if host_info.outdated(self.service_cache_timeout) or refresh:
                 self.log.info("refresing stale services for '{}'".format(host))
-                wait_for.append(
-                    SSHReadCompletion(self._worker_pool.apply_async(
-                        self._refresh_host_services, (host,))))
+                wait_for_args.append((host,))
             else:
                 self.log.debug('have recent services for %s: %s' % (
                     host, host_info.data))
-                wait_for.append(
-                    orchestrator.TrivialReadCompletion([host_info.data]))
-        self._orchestrator_wait(wait_for)
-
-        services = {}
-        for host, c in zip(hosts, wait_for):
-            services[host] = c.result[0]
-
-        result = []
-        for host, ls in services.items():
-            for d in ls:
-                if not d['style'].startswith('ceph-daemon'):
-                    self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d))
-                    continue
-                if d['fsid'] != self._cluster_fsid:
-                    self.log.debug('ignoring foreign daemon on %s: %s' % (host, d))
-                    continue
-                self.log.debug('including %s' % d)
-                sd = orchestrator.ServiceDescription()
-                sd.service_type = d['name'].split('.')[0]
-                if service_type and service_type != sd.service_type:
-                    continue
-                if '.' in d['name']:
-                    sd.service_instance = '.'.join(d['name'].split('.')[1:])
-                else:
-                    sd.service_instance = host  # e.g., crash
-                if service_id and service_id != sd.service_instance:
-                    continue
-                if service_name and not sd.service_instance.startswith(service_name + '.'):
-                    continue
-                sd.nodename = host
-                sd.container_id = d.get('container_id')
-                sd.container_image_name = d.get('container_image_name')
-                sd.container_image_id = d.get('container_image_id')
-                sd.version = d.get('version')
-                sd.status_desc = d['state']
-                sd.status = {
-                    'running': 1,
-                    'stopped': 0,
-                    'error': -1,
-                    'unknown': -1,
-                }[d['state']]
-                result.append(sd)
-        return result
+                in_cache.append(host_info.data)
+
+        def _get_services_result(results):
+            services = {}
+            for host, data in zip(hosts, results + in_cache):
+                services[host] = data
+
+            result = []
+            for host, ls in services.items():
+                for d in ls:
+                    if not d['style'].startswith('ceph-daemon'):
+                        self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d))
+                        continue
+                    if d['fsid'] != self._cluster_fsid:
+                        self.log.debug('ignoring foreign daemon on %s: %s' % (host, d))
+                        continue
+                    self.log.debug('including %s' % d)
+                    sd = orchestrator.ServiceDescription()
+                    sd.service_type = d['name'].split('.')[0]
+                    if service_type and service_type != sd.service_type:
+                        continue
+                    if '.' in d['name']:
+                        sd.service_instance = '.'.join(d['name'].split('.')[1:])
+                    else:
+                        sd.service_instance = host  # e.g., crash
+                    if service_id and service_id != sd.service_instance:
+                        continue
+                    if service_name and not sd.service_instance.startswith(service_name + '.'):
+                        continue
+                    sd.nodename = host
+                    sd.container_id = d.get('container_id')
+                    sd.container_image_name = d.get('container_image_name')
+                    sd.container_image_id = d.get('container_image_id')
+                    sd.version = d.get('version')
+                    sd.status_desc = d['state']
+                    sd.status = {
+                        'running': 1,
+                        'stopped': 0,
+                        'error': -1,
+                        'unknown': -1,
+                    }[d['state']]
+                    result.append(sd)
+            return result
+
+        return async_map_completion(self._refresh_host_services)(wait_for_args).then(
+            _get_services_result)
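The rewritten `_get_services` is a fan-out plus one post-processing step: `async_map_completion` schedules `_refresh_host_services` once per stale host, and `.then()` receives the gathered results. The same shape with toy functions (names hypothetical):

    def fetch(host):
        return [{'name': 'mon.' + host, 'state': 'running'}]

    def summarize(results):
        return [d['name'] for per_host in results for d in per_host]

    completion = async_map_completion(fetch)(['a', 'b']).then(summarize)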
+
 
     def describe_service(self, service_type=None, service_id=None,
                          node_name=None, refresh=False):
@@ -695,20 +698,17 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             service_type,
             service_name=service_name,
             service_id=service_id)
-        results = []
+        args = []
         for d in daemons:
-            results.append(self._worker_pool.apply_async(
-                self._service_action, (d.service_type, d.service_instance,
-                                       d.nodename, action)))
-        if not results:
-            if service_name:
-                n = service_name + '-*'
-            else:
-                n = service_id
-            raise OrchestratorError(
-                'Unable to find %s.%s daemon(s)' % (
-                    service_type, n))
-        return SSHWriteCompletion(results)
+            args.append((d.service_type, d.service_instance,
+                         d.nodename, action))
+        if not args:
+            n = service_name + '-*' if service_name else service_id
+            raise orchestrator.OrchestratorError('Unable to find %s.%s daemon(s)' % (
+                service_type, n))
+        return async_map_completion(self._service_action)(args)
 
     def _service_action(self, service_type, service_id, host, action):
         if action == 'redeploy':
@@ -759,9 +759,9 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             # this implies the returned hosts are registered
             hosts = self._get_hosts()
 
-        @log_exceptions
-        def run(host, host_info):
-            # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
+        def run(host, host_info):
+            # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
 
             if host_info.outdated(self.inventory_cache_timeout) or refresh:
                 self.log.info("refresh stale inventory for '{}'".format(host))
@@ -778,17 +778,10 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             devices = inventory.Devices.from_json(host_info.data)
             return orchestrator.InventoryNode(host, devices)
 
-        results = []
-        for key, host_info in hosts:
-            result = self._worker_pool.apply_async(run, (key, host_info))
-            results.append(result)
-
-        return SSHReadCompletion(results)
+        return async_map_completion(run)(hosts)
 
-    @log_exceptions
     def blink_device_light(self, ident_fault, on, locs):
-        # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> SSHWriteCompletion
-
+        # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> AsyncCompletion
         def blink(host, dev, ident_fault_, on_):
             # type: (str, str, str, bool) -> str
             cmd = [
@@ -807,16 +800,20 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             return "Set %s light for %s:%s %s" % (
                 ident_fault_, host, dev, 'on' if on_ else 'off')
 
-        results = []
-        for loc in locs:
-            results.append(
-                self._worker_pool.apply_async(
-                    blink,
-                    (loc.host, loc.dev, ident_fault, on)))
-        return SSHWriteCompletion(results)
+        return async_map_completion(blink)(
+            [(loc.host, loc.dev, ident_fault, on) for loc in locs])
+
+    @async_completion
+    def _create_osd(self, all_hosts_, drive_group):
+        all_hosts = orchestrator.InventoryNode.get_host_names(all_hosts_)
+        assert len(drive_group.hosts(all_hosts)) == 1
+        assert len(drive_group.data_devices.paths) > 0
+        assert all(map(lambda p: isinstance(p, six.string_types),
+                       drive_group.data_devices.paths))
+
+        host = drive_group.hosts(all_hosts)[0]
+        self._require_hosts(host)
 
-    @log_exceptions
-    def _create_osd(self, host, drive_group):
         # get bootstrap key
         ret, keyring, err = self.mon_command({
             'prefix': 'auth get',
@@ -881,7 +878,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
         return "Created osd(s) on host '{}'".format(host)
 
-    def create_osds(self, drive_group, all_hosts=None):
+    def create_osds(self, drive_group):
         """
         Create a new osd.
 
@@ -894,29 +891,13 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
           - support full drive_group specification
           - support batch creation
         """
-        assert len(drive_group.hosts(all_hosts)) == 1
-        assert len(drive_group.data_devices.paths) > 0
-        assert all(map(lambda p: isinstance(p, six.string_types),
-            drive_group.data_devices.paths))
-
-        host = drive_group.hosts(all_hosts)[0]
-        self._require_hosts(host)
 
-        result = self._worker_pool.apply_async(self._create_osd, (host,
-                drive_group))
-
-        return SSHWriteCompletion(result)
+        return self.get_hosts().then(
+            lambda all_hosts: self._create_osd(all_hosts, drive_group))
 
     def remove_osds(self, name):
         daemons = self._get_services('osd', service_id=name)
-        results = []
-        for d in daemons:
-            results.append(self._worker_pool.apply_async(
-                self._remove_daemon,
-                ('osd.%s' % d.service_instance, d.nodename)))
-        if not results:
-            raise OrchestratorError('Unable to find osd.%s' % name)
-        return SSHWriteCompletion(results)
+        args = [('osd.%s' % d.service_instance, d.nodename) for d in daemons]
+        if not args:
+            raise orchestrator.OrchestratorError('Unable to find osd.%s' % name)
+        return async_map_completion(self._remove_daemon)(args)
 
     def _create_daemon(self, daemon_type, daemon_id, host, keyring,
                        extra_args=[]):
@@ -1022,7 +1003,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
         mon_map = self.get("mon_map")
         num_mons = len(mon_map["mons"])
         if num == num_mons:
-            return SSHWriteCompletionReady("The requested number of monitors exist.")
+            return orchestrator.Completion(value="The requested number of monitors exist.")
         if num < num_mons:
             raise NotImplementedError("Removing monitors is not supported.")
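A `Completion` constructed with an initial `value` is finished immediately, replacing the old `SSHWriteCompletionReady` without a trip through the worker pool. Sketch:

    c = orchestrator.Completion(value="The requested number of monitors exist.")
    # c is already complete; a .then() continuation would run inline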
 
@@ -1051,14 +1032,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
         # TODO: we may want to chain the creation of the monitors so they join
         # the quorum one at a time.
-        results = []
-        for host, network, name in host_specs:
-            result = self._worker_pool.apply_async(self._create_mon,
-                                                   (host, network, name))
-
-            results.append(result)
-
-        return SSHWriteCompletion(results)
+        return async_map_completion(self._create_mon)(host_specs)
 
     def _create_mgr(self, host, name):
         """
@@ -1084,7 +1058,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
         daemons = self._get_services('mgr')
         num_mgrs = len(daemons)
         if num == num_mgrs:
-            return SSHWriteCompletionReady("The requested number of managers exist.")
+            return orchestrator.Completion(value="The requested number of managers exist.")
 
         self.log.debug("Trying to update managers on: {}".format(host_specs))
         # check that all the hosts are registered
@@ -1102,13 +1076,12 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
                 connected.append(mgr_map.get('active_name', ''))
             for standby in mgr_map.get('standbys', []):
                 connected.append(standby.get('name', ''))
+            to_remove_daemons = []
+            to_remove_mgr = []
             for d in daemons:
                 if d.service_instance not in connected:
-                    result = self._worker_pool.apply_async(
-                        self._remove_daemon,
-                        ('%s.%s' % (d.service_type, d.service_instance),
+                    to_remove_daemons.append(('%s.%s' % (d.service_type, d.service_instance),
                          d.nodename))
-                    results.append(result)
                     num_to_remove -= 1
                     if num_to_remove == 0:
                         break
@@ -1116,14 +1089,15 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             # otherwise, remove *any* mgr
             if num_to_remove > 0:
                 for daemon in daemons:
-                    result = self._worker_pool.apply_async(
-                        self._remove_daemon,
-                        ('%s.%s' % (d.service_type, d.service_instance),
-                         d.nodename))
-                    results.append(result)
+                    to_remove_mgr.append(('%s.%s' % (daemon.service_type, daemon.service_instance), daemon.nodename))
                     num_to_remove -= 1
                     if num_to_remove == 0:
                         break
+            return async_map_completion(self._remove_daemon)(to_remove_daemons).then(
+                lambda remove_daemon_result: async_map_completion(self._remove_daemon)(to_remove_mgr).then(
+                    lambda remove_mgr_result: remove_daemon_result + remove_mgr_result
+                )
+            )
 
         else:
             # we assume explicit placement by which there are the same number of
@@ -1150,6 +1124,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
                 result = self._worker_pool.apply_async(self._create_mgr,
                                                        (host_spec.hostname, name))
                 results.append(result)
+
+            args = []
+            for host_spec in host_specs:
+                name = host_spec.name or self.get_unique_name(daemons)
+                host = host_spec.hostname
+                args.append((host, name))
+        return async_map_completion(self._create_mgr)(args)
 
         return SSHWriteCompletion(results)
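The removal chain in `update_mgrs` above serializes two pool batches: the second `async_map_completion` is created only inside the first `.then()`, and the innermost lambda concatenates both result lists into the final value. The same shape, with `remove_daemon` standing in for `self._remove_daemon`:

    def remove_batches(first, second):
        # the second batch is scheduled only after the first finishes
        return async_map_completion(remove_daemon)(first).then(
            lambda r1: async_map_completion(remove_daemon)(second).then(
                lambda r2: r1 + r2))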
 
@@ -1157,16 +1139,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
             raise RuntimeError("must specify at least %d hosts" % spec.count)
         daemons = self._get_services('mds')
-        results = []
+        args = []
         num_added = 0
         for host, _, name in spec.placement.nodes:
             if num_added >= spec.count:
                 break
             mds_id = self.get_unique_name(daemons, spec.name, name)
             self.log.debug('placing mds.%s on host %s' % (mds_id, host))
-            results.append(
-                self._worker_pool.apply_async(self._create_mds, (mds_id, host))
-            )
+            args.append((mds_id, host))
             # add to daemon list so next name(s) will also be unique
             sd = orchestrator.ServiceDescription()
             sd.service_instance = mds_id
@@ -1174,7 +1154,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             sd.nodename = host
             daemons.append(sd)
             num_added += 1
-        return SSHWriteCompletion(results)
+        return async_map_completion(self._create_mds)(args)
 
     def update_mds(self, spec):
         return self._update_service('mds', self.add_mds, spec)
@@ -1192,17 +1172,12 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
     def remove_mds(self, name):
         daemons = self._get_services('mds')
-        results = []
         self.log.debug("Attempting to remove volume: {}".format(name))
-        for d in daemons:
-            if d.service_instance == name or d.service_instance.startswith(name + '.'):
-                results.append(self._worker_pool.apply_async(
-                    self._remove_daemon,
-                    ('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename)))
-        if not results:
-            raise OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
-        return SSHWriteCompletion(results)
+        args = [('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+                for d in daemons
+                if d.service_instance == name or d.service_instance.startswith(name + '.')]
+        if args:
+            return async_map_completion(self._remove_daemon)(args)
+        raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
 
     def add_rgw(self, spec):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
@@ -1215,16 +1190,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             'value': spec.name,
         })
         daemons = self._get_services('rgw')
-        results = []
+        args = []
         num_added = 0
         for host, _, name in spec.placement.nodes:
             if num_added >= spec.count:
                 break
             rgw_id = self.get_unique_name(daemons, spec.name, name)
             self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
-            results.append(
-                self._worker_pool.apply_async(self._create_rgw, (rgw_id, host))
-            )
+            args.append((rgw_id, host))
             # add to daemon list so next name(s) will also be unique
             sd = orchestrator.ServiceDescription()
             sd.service_instance = rgw_id
@@ -1232,7 +1205,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
             sd.nodename = host
             daemons.append(sd)
             num_added += 1
-        return SSHWriteCompletion(results)
+        return async_map_completion(self._create_rgw)(args)
 
     def _create_rgw(self, rgw_id, host):
         ret, keyring, err = self.mon_command({
@@ -1246,19 +1219,29 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
 
     def remove_rgw(self, name):
         daemons = self._get_services('rgw')
-        results = []
+
+        args = []
         for d in daemons:
             if d.service_instance == name or d.service_instance.startswith(name + '.'):
-                results.append(self._worker_pool.apply_async(
-                    self._remove_daemon,
-                    ('%s.%s' % (d.service_type, d.service_instance),
-                     d.nodename)))
-        if not results:
-            raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
-        return SSHWriteCompletion(results)
+                args.append(('%s.%s' % (d.service_type, d.service_instance),
+                             d.nodename))
+        if args:
+            return async_map_completion(self._remove_daemon)(args)
+        raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
 
     def update_rgw(self, spec):
-        return self._update_service('rgw', self.add_rgw, spec)
+        daemons = self._get_services('rgw', service_name=spec.name)
+        if len(daemons) > spec.count:
+            # remove some
+            to_remove = len(daemons) - spec.count
+            args = []
+            for d in daemons[0:to_remove]:
+                args.append((d.service_instance, d.nodename))
+            return async_map_completion(self._remove_rgw)(args)
+        elif len(daemons) < spec.count:
+            # add some
+            spec.count -= len(daemons)
+            return self.add_rgw(spec)
+        return orchestrator.Completion(value="The requested number of daemons exist.")
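`update_rgw` now reconciles the live daemon count against `spec.count`: a surplus schedules removals, a deficit adds the difference. The decision logic in isolation (toy sketch):

    def reconcile(current, wanted):
        if current > wanted:
            return ('remove', current - wanted)
        if current < wanted:
            return ('add', wanted - current)
        return ('noop', 0)

    assert reconcile(3, 1) == ('remove', 2)
    assert reconcile(1, 3) == ('add', 2)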
 
     def add_rbd_mirror(self, spec):
         if not spec.placement.nodes or len(spec.placement.nodes) < spec.count: