From f3ea1f18e23f63958d305b70c7864120c5a76132 Mon Sep 17 00:00:00 2001 From: Sebastian Wagner Date: Fri, 30 Aug 2019 18:08:46 +0200 Subject: [PATCH] mgr/ssh: Adapt ssh orch to new Completions interface Signed-off-by: Sebastian Wagner --- src/pybind/mgr/orchestrator.py | 2 +- src/pybind/mgr/ssh/module.py | 455 ++++++++++++++++----------------- 2 files changed, 220 insertions(+), 237 deletions(-) diff --git a/src/pybind/mgr/orchestrator.py b/src/pybind/mgr/orchestrator.py index 1ec3959c11c..a88bf7f09fe 100644 --- a/src/pybind/mgr/orchestrator.py +++ b/src/pybind/mgr/orchestrator.py @@ -894,7 +894,7 @@ class Orchestrator(object): raise NotImplementedError() def update_rgw(self, spec): - # type: (StatelessServiceSpec) -> Completion + # type: (RGWSpec) -> Completion """ Update / redeploy existing RGW zone Like for example changing the number of service instances. diff --git a/src/pybind/mgr/ssh/module.py b/src/pybind/mgr/ssh/module.py index dbd33e7d633..af42e778ad3 100644 --- a/src/pybind/mgr/ssh/module.py +++ b/src/pybind/mgr/ssh/module.py @@ -4,6 +4,13 @@ import logging from functools import wraps import string +try: + from typing import List, Dict, Optional, Callable, TypeVar, Type, Any +except ImportError: + pass # just for type checking + +T = TypeVar('T') + import six import os import random @@ -64,87 +71,86 @@ except ImportError: # multiple bootstrapping / initialization -class SSHCompletionmMixin(object): - def __init__(self, result): - if isinstance(result, multiprocessing.pool.AsyncResult): - self._result = [result] - else: - self._result = result - assert isinstance(self._result, list) - - @property - def result(self): - return list(map(lambda r: r.get(), self._result)) +class AsyncCompletion(orchestrator.Completion[T]): + def __init__(self, *args, many=False, **kwargs): + self.__on_complete = None # type: Callable[[T], Any] + self.many = many + super(AsyncCompletion, self).__init__(*args, **kwargs) + def propagate_to_next(self): + # We don't have a synchronous result. 
+ pass -class SSHReadCompletion(SSHCompletionmMixin, orchestrator.ReadCompletion): @property - def has_result(self): - return all(map(lambda r: r.ready(), self._result)) + def _progress_reference(self): + if hasattr(self.__on_complete, 'progress_id'): + return self.__on_complete + return None + @property + def _on_complete(self): + # type: () -> Optional[Callable[[T], Any]] + if self.__on_complete is None: + return None + + def callback(result): + if self._next_promise: + self._next_promise._value = result + else: + self._value = result + super(AsyncCompletion, self).propagate_to_next() -class SSHWriteCompletion(SSHCompletionmMixin, orchestrator.WriteCompletion): + def run(value): + if self.many: + SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value, + callback=callback) + else: + SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,), + callback=callback) - @property - def has_result(self): - return all(map(lambda r: r.ready(), self._result)) + return run - @property - def is_effective(self): - return all(map(lambda r: r.ready(), self._result)) + @_on_complete.setter + def _on_complete(self, inner): + # type: (Callable[[T], Any]) -> None + self.__on_complete = inner - @property - def is_errored(self): - for r in self._result: - if not r.ready(): - return False - if not r.successful(): - return True - return False +def ssh_completion(cls=AsyncCompletion, **c_kwargs): + # type: (Type[orchestrator.Completion], Any) -> Callable + """ + run the given function through `apply_async()` or `map_asyc()` + """ + def decorator(f): + @wraps(f) + def wrapper(*args, **kwargs): + return cls(on_complete=lambda _: f(*args, **kwargs), **c_kwargs) -class SSHWriteCompletionReady(SSHWriteCompletion): - def __init__(self, result): - orchestrator.WriteCompletion.__init__(self) - self._result = result + return wrapper + return decorator - @property - def result(self): - return self._result - @property - def has_result(self): - return True +def async_completion(f): + # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]] + return ssh_completion()(f) - @property - def is_effective(self): - return True - @property - def is_errored(self): - return False +def async_map_completion(f): + # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]] + return ssh_completion(many=True)(f) -def log_exceptions(f): - if six.PY3: - return f - else: - # Python 2 does no exception chaining, thus the - # real exception is lost - @wraps(f) - def wrapper(*args, **kwargs): - try: - return f(*args, **kwargs) - except Exception: - logger.exception('something went wrong.') - raise - return wrapper +def trivial_completion(f): + # type: (Callable[..., T]) -> Callable[..., orchestrator.Completion[T]] + return ssh_completion(cls=orchestrator.Completion)(f) -class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): +class SSHOrchestrator(MgrModule, orchestrator.Orchestrator): _STORE_HOST_PREFIX = "host" + + instance = None NATIVE_OPTIONS = [] MODULE_OPTIONS = [ { @@ -192,6 +198,9 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): self._reconfig_ssh() + SSHOrchestrator.instance = self + self.all_progress_references = list() # type: List[orchestrator.ProgressReference] + # load inventory i = self.get_store('inventory') if i: @@ -521,54 +530,45 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): def _get_hosts(self, wanted=None): return self.inventory_cache.items_filtered(wanted) + @async_completion def add_host(self, host): 
""" Add a host to be managed by the orchestrator. :param host: host name """ - @log_exceptions - def run(host): - self.inventory[host] = {} - self._save_inventory() - self.inventory_cache[host] = orchestrator.OutdatableData() - self.service_cache[host] = orchestrator.OutdatableData() - return "Added host '{}'".format(host) - - return SSHWriteCompletion( - self._worker_pool.apply_async(run, (host,))) + self.inventory[host] = {} + self._save_inventory() + self.inventory_cache[host] = orchestrator.OutdatableData() + self.service_cache[host] = orchestrator.OutdatableData() + return "Added host '{}'".format(host) + @async_completion def remove_host(self, host): """ Remove a host from orchestrator management. :param host: host name """ - @log_exceptions - def run(host): - del self.inventory[host] - self._save_inventory() - del self.inventory_cache[host] - del self.service_cache[host] - return "Removed host '{}'".format(host) - - return SSHWriteCompletion( - self._worker_pool.apply_async(run, (host,))) + del self.inventory[host] + self._save_inventory() + del self.inventory_cache[host] + del self.service_cache[host] + return "Removed host '{}'".format(host) + @trivial_completion def get_hosts(self): """ Return a list of hosts managed by the orchestrator. Notes: - skip async: manager reads from cache. - """ - nodes = [ - orchestrator.InventoryNode(h, - inventory.Devices([]), - i.get('labels', [])) - for h, i in self.inventory.items()] - return orchestrator.TrivialReadCompletion(nodes) + TODO: + - InventoryNode probably needs to be able to report labels + """ + nodes = [orchestrator.InventoryNode(host_name, []) for host_name in self.inventory_cache] +""" def add_host_label(self, host, label): if host not in self.inventory: raise OrchestratorError('host %s does not exist' % host) @@ -600,6 +600,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): return SSHWriteCompletion( self._worker_pool.apply_async(run, (host, label))) +""" def _refresh_host_services(self, host): out, code = self._run_ceph_daemon( @@ -616,61 +617,63 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): node_name=None, refresh=False): hosts = [] - wait_for = [] + wait_for_args = [] + in_cache = [] for host, host_info in self.service_cache.items_filtered(): hosts.append(host) if host_info.outdated(self.service_cache_timeout) or refresh: self.log.info("refresing stale services for '{}'".format(host)) - wait_for.append( - SSHReadCompletion(self._worker_pool.apply_async( - self._refresh_host_services, (host,)))) + wait_for_args.append((host,)) else: self.log.debug('have recent services for %s: %s' % ( host, host_info.data)) - wait_for.append( - orchestrator.TrivialReadCompletion([host_info.data])) - self._orchestrator_wait(wait_for) - - services = {} - for host, c in zip(hosts, wait_for): - services[host] = c.result[0] - - result = [] - for host, ls in services.items(): - for d in ls: - if not d['style'].startswith('ceph-daemon'): - self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d)) - continue - if d['fsid'] != self._cluster_fsid: - self.log.debug('ignoring foreign daemon on %s: %s' % (host, d)) - continue - self.log.debug('including %s' % d) - sd = orchestrator.ServiceDescription() - sd.service_type = d['name'].split('.')[0] - if service_type and service_type != sd.service_type: - continue - if '.' 
in d['name']: - sd.service_instance = '.'.join(d['name'].split('.')[1:]) - else: - sd.service_instance = host # e.g., crash - if service_id and service_id != sd.service_instance: - continue - if service_name and not sd.service_instance.startswith(service_name + '.'): - continue - sd.nodename = host - sd.container_id = d.get('container_id') - sd.container_image_name = d.get('container_image_name') - sd.container_image_id = d.get('container_image_id') - sd.version = d.get('version') - sd.status_desc = d['state'] - sd.status = { - 'running': 1, - 'stopped': 0, - 'error': -1, - 'unknown': -1, - }[d['state']] - result.append(sd) - return result + in_cache.append(host_info.data) + + def _get_services_result(self, results): + services = {} + for host, c in zip(hosts, results + in_cache): + services[host] = c.result[0] + + result = [] + for host, ls in services.items(): + for d in ls: + if not d['style'].startswith('ceph-daemon'): + self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d)) + continue + if d['fsid'] != self._cluster_fsid: + self.log.debug('ignoring foreign daemon on %s: %s' % (host, d)) + continue + self.log.debug('including %s' % d) + sd = orchestrator.ServiceDescription() + sd.service_type = d['name'].split('.')[0] + if service_type and service_type != sd.service_type: + continue + if '.' in d['name']: + sd.service_instance = '.'.join(d['name'].split('.')[1:]) + else: + sd.service_instance = host # e.g., crash + if service_id and service_id != sd.service_instance: + continue + if service_name and not sd.service_instance.startswith(service_name + '.'): + continue + sd.nodename = host + sd.container_id = d.get('container_id') + sd.container_image_name = d.get('container_image_name') + sd.container_image_id = d.get('container_image_id') + sd.version = d.get('version') + sd.status_desc = d['state'] + sd.status = { + 'running': 1, + 'stopped': 0, + 'error': -1, + 'unknown': -1, + }[d['state']] + result.append(sd) + return result + + return async_map_completion(self._refresh_host_services)(wait_for_args).then( + _get_services_result) + def describe_service(self, service_type=None, service_id=None, node_name=None, refresh=False): @@ -695,20 +698,17 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): service_type, service_name=service_name, service_id=service_id) - results = [] + args = [] for d in daemons: - results.append(self._worker_pool.apply_async( - self._service_action, (d.service_type, d.service_instance, - d.nodename, action))) - if not results: - if service_name: - n = service_name + '-*' - else: - n = service_id - raise OrchestratorError( - 'Unable to find %s.%s daemon(s)' % ( - service_type, n)) - return SSHWriteCompletion(results) + args.append((d.service_type, d.service_instance, + d.nodename, action)) + if not args: + n = service_name + if n: + n += '-*' + raise orchestrator.OrchestratorError('Unable to find %s.%s daemon(s)' % ( + service_type, n)) + return async_map_completion(self._service_action)(args) def _service_action(self, service_type, service_id, host, action): if action == 'redeploy': @@ -759,9 +759,9 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): # this implies the returned hosts are registered hosts = self._get_hosts() - @log_exceptions - def run(host, host_info): - # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode + def run(host_info): + # type: (orchestrator.OutdatableData) -> orchestrator.InventoryNode + host = host_info.data['name'] if 
host_info.outdated(self.inventory_cache_timeout) or refresh: self.log.info("refresh stale inventory for '{}'".format(host)) @@ -778,17 +778,10 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): devices = inventory.Devices.from_json(host_info.data) return orchestrator.InventoryNode(host, devices) - results = [] - for key, host_info in hosts: - result = self._worker_pool.apply_async(run, (key, host_info)) - results.append(result) - - return SSHReadCompletion(results) + return async_map_completion(run)(hosts.values()) - @log_exceptions def blink_device_light(self, ident_fault, on, locs): # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> SSHWriteCompletion - def blink(host, dev, ident_fault_, on_): # type: (str, str, str, bool) -> str cmd = [ @@ -807,16 +800,20 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): return "Set %s light for %s:%s %s" % ( ident_fault_, host, dev, 'on' if on_ else 'off') - results = [] - for loc in locs: - results.append( - self._worker_pool.apply_async( - blink, - (loc.host, loc.dev, ident_fault, on))) - return SSHWriteCompletion(results) + return async_map_completion(blink)(locs) + + @async_completion + def _create_osd(self, all_hosts_, drive_group): + all_hosts = orchestrator.InventoryNode.get_host_names(all_hosts_) + assert len(drive_group.hosts(all_hosts)) == 1 + assert len(drive_group.data_devices.paths) > 0 + assert all(map(lambda p: isinstance(p, six.string_types), + drive_group.data_devices.paths)) + + host = drive_group.hosts(all_hosts)[0] + self._require_hosts(host) + - @log_exceptions - def _create_osd(self, host, drive_group): # get bootstrap key ret, keyring, err = self.mon_command({ 'prefix': 'auth get', @@ -881,7 +878,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): return "Created osd(s) on host '{}'".format(host) - def create_osds(self, drive_group, all_hosts=None): + def create_osds(self, drive_group): """ Create a new osd. 
@@ -894,29 +891,13 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): - support full drive_group specification - support batch creation """ - assert len(drive_group.hosts(all_hosts)) == 1 - assert len(drive_group.data_devices.paths) > 0 - assert all(map(lambda p: isinstance(p, six.string_types), - drive_group.data_devices.paths)) - - host = drive_group.hosts(all_hosts)[0] - self._require_hosts(host) - result = self._worker_pool.apply_async(self._create_osd, (host, - drive_group)) - - return SSHWriteCompletion(result) + return self.get_hosts().then(self._create_osd) def remove_osds(self, name): daemons = self._get_services('osd', service_id=name) - results = [] - for d in daemons: - results.append(self._worker_pool.apply_async( - self._remove_daemon, - ('osd.%s' % d.service_instance, d.nodename))) - if not results: - raise OrchestratorError('Unable to find osd.%s' % name) - return SSHWriteCompletion(results) + args = [('osd.%s' % d.service_instance, d.nodename) for d in daemons] + return async_map_completion(self._remove_daemon)(args) def _create_daemon(self, daemon_type, daemon_id, host, keyring, extra_args=[]): @@ -1022,7 +1003,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): mon_map = self.get("mon_map") num_mons = len(mon_map["mons"]) if num == num_mons: - return SSHWriteCompletionReady("The requested number of monitors exist.") + return orchestrator.Completion(value="The requested number of monitors exist.") if num < num_mons: raise NotImplementedError("Removing monitors is not supported.") @@ -1051,14 +1032,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): # TODO: we may want to chain the creation of the monitors so they join # the quorum one at a time. - results = [] - for host, network, name in host_specs: - result = self._worker_pool.apply_async(self._create_mon, - (host, network, name)) - - results.append(result) - - return SSHWriteCompletion(results) + return async_map_completion(self._create_mon)(host_specs) def _create_mgr(self, host, name): """ @@ -1084,7 +1058,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): daemons = self._get_services('mgr') num_mgrs = len(daemons) if num == num_mgrs: - return SSHWriteCompletionReady("The requested number of managers exist.") + return orchestrator.Completion(value="The requested number of managers exist.") self.log.debug("Trying to update managers on: {}".format(host_specs)) # check that all the hosts are registered @@ -1102,13 +1076,12 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): connected.append(mgr_map.get('active_name', '')) for standby in mgr_map.get('standbys', []): connected.append(standby.get('name', '')) + to_remove_damons = [] + to_remove_mgr = [] for d in daemons: if d.service_instance not in connected: - result = self._worker_pool.apply_async( - self._remove_daemon, - ('%s.%s' % (d.service_type, d.service_instance), + to_remove_damons.append(('%s.%s' % (d.service_type, d.service_instance), d.nodename)) - results.append(result) num_to_remove -= 1 if num_to_remove == 0: break @@ -1116,14 +1089,15 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): # otherwise, remove *any* mgr if num_to_remove > 0: for daemon in daemons: - result = self._worker_pool.apply_async( - self._remove_daemon, - ('%s.%s' % (d.service_type, d.service_instance), - d.nodename)) - results.append(result) + to_remove_mgr.append((('%s.%s' % (d.service_type, d.service_instance), daemon.nodename)) 
num_to_remove -= 1 if num_to_remove == 0: break + return async_map_completion(self._remove_daemon)(to_remove_damons).then( + lambda remove_daemon_result: async_map_completion(self._remove_daemon)(to_remove_mgr).then( + lambda remove_mgr_result: remove_daemon_result + remove_mgr_result + ) + ) else: # we assume explicit placement by which there are the same number of @@ -1150,6 +1124,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): result = self._worker_pool.apply_async(self._create_mgr, (host_spec.hostname, name)) results.append(result) + + args = [] + for host_spec in host_specs: + name = host_spec.name or self.get_unique_name(daemons) + host = host_spec.hostname + args.append((host, name)) + return async_map_completion(self._create_mgr)(args) + return SSHWriteCompletion(results) @@ -1157,16 +1139,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): if not spec.placement.nodes or len(spec.placement.nodes) < spec.count: raise RuntimeError("must specify at least %d hosts" % spec.count) daemons = self._get_services('mds') - results = [] + args = [] num_added = 0 for host, _, name in spec.placement.nodes: if num_added >= spec.count: break mds_id = self.get_unique_name(daemons, spec.name, name) self.log.debug('placing mds.%s on host %s' % (mds_id, host)) - results.append( - self._worker_pool.apply_async(self._create_mds, (mds_id, host)) - ) + args.append((mds_id, host)) # add to daemon list so next name(s) will also be unique sd = orchestrator.ServiceDescription() sd.service_instance = mds_id @@ -1174,7 +1154,7 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): sd.nodename = host daemons.append(sd) num_added += 1 - return SSHWriteCompletion(results) + return async_map_completion(self._create_mds)(args) def update_mds(self, spec): return self._update_service('mds', self.add_mds, spec) @@ -1192,17 +1172,12 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): def remove_mds(self, name): daemons = self._get_services('mds') - results = [] self.log.debug("Attempting to remove volume: {}".format(name)) - for d in daemons: - if d.service_instance == name or d.service_instance.startswith(name + '.'): - results.append(self._worker_pool.apply_async( - self._remove_daemon, - ('%s.%s' % (d.service_type, d.service_instance), - d.nodename))) - if not results: - raise OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name) - return SSHWriteCompletion(results) + if daemons: + args = [('%s.%s' % (d.service_type, d.service_instance), + d.nodename) for d in daemons] + return async_map_completion(self._remove_daemon)(args) + raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name) def add_rgw(self, spec): if not spec.placement.nodes or len(spec.placement.nodes) < spec.count: @@ -1215,16 +1190,14 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): 'value': spec.name, }) daemons = self._get_services('rgw') - results = [] + args = [] num_added = 0 for host, _, name in spec.placement.nodes: if num_added >= spec.count: break rgw_id = self.get_unique_name(daemons, spec.name, name) self.log.debug('placing rgw.%s on host %s' % (rgw_id, host)) - results.append( - self._worker_pool.apply_async(self._create_rgw, (rgw_id, host)) - ) + args.append((rgw_id, host)) # add to daemon list so next name(s) will also be unique sd = orchestrator.ServiceDescription() sd.service_instance = rgw_id @@ -1232,7 +1205,7 @@ class SSHOrchestrator(MgrModule, 
orchestrator.OrchestratorClientMixin): sd.nodename = host daemons.append(sd) num_added += 1 - return SSHWriteCompletion(results) + return async_map_completion(self._create_rgw)(args) def _create_rgw(self, rgw_id, host): ret, keyring, err = self.mon_command({ @@ -1246,19 +1219,29 @@ class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin): def remove_rgw(self, name): daemons = self._get_services('rgw') - results = [] + + args = [] for d in daemons: if d.service_instance == name or d.service_instance.startswith(name + '.'): - results.append(self._worker_pool.apply_async( - self._remove_daemon, - ('%s.%s' % (d.service_type, d.service_instance), - d.nodename))) - if not results: - raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name) - return SSHWriteCompletion(results) + args.append(('%s.%s' % (d.service_type, d.service_instance), + d.nodename)) + if args: + return async_map_completion(self._remove_daemon)(args) + raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name) def update_rgw(self, spec): - return self._update_service('rgw', self.add_rgw, spec) + daemons = self._get_services('rgw', service_name=spec.name) + if len(daemons) > spec.count: + # remove some + to_remove = len(daemons) - spec.count + args = [] + for d in daemons[0:to_remove]: + args.append((d.service_instance, d.nodename)) + return async_map_completion(self._remove_rgw)(args) + elif len(daemons) < spec.count: + # add some + spec.count -= len(daemons) + return self.add_rgw(spec) def add_rbd_mirror(self, spec): if not spec.placement.nodes or len(spec.placement.nodes) < spec.count: -- 2.39.5
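Note (illustration only, not part of the patch): the commit replaces the
SSHReadCompletion/SSHWriteCompletion wrappers, which blocked on
AsyncResult.get(), with decorators (async_completion, async_map_completion,
trivial_completion) that return chainable Completion objects dispatched
through the mgr worker pool. The snippet below is a minimal, self-contained
sketch of that decorator-plus-then() pattern; every name in it
(SketchCompletion, sketch_async_completion, ...) is a hypothetical stand-in,
and the real behaviour lives in orchestrator.Completion and the decorators
added above.

    # Illustrative sketch only -- NOT the orchestrator.Completion API.
    from functools import wraps
    from multiprocessing.pool import ThreadPool

    _pool = ThreadPool(4)

    class SketchCompletion(object):
        """Tiny stand-in for a chainable completion."""
        def __init__(self, on_complete):
            self._on_complete = on_complete

        def then(self, next_fn):
            # Compose: feed this completion's result into next_fn.
            prev = self._on_complete
            return SketchCompletion(lambda value: next_fn(prev(value)))

        def execute(self, value=None):
            # The real module drives this through the mgr worker pool with
            # apply_async()/map_async() callbacks; here we run it inline.
            return self._on_complete(value)

    def sketch_async_completion(f):
        # Counterpart of async_completion(): calling f returns a completion
        # instead of running the body immediately.
        @wraps(f)
        def wrapper(*args, **kwargs):
            return SketchCompletion(lambda _: f(*args, **kwargs))
        return wrapper

    def sketch_async_map_completion(f):
        # Counterpart of async_map_completion(): apply f to every element of
        # the argument list (the patch uses map_async for this).
        @wraps(f)
        def wrapper(arg_list):
            return SketchCompletion(lambda _: _pool.map(f, arg_list))
        return wrapper

    @sketch_async_completion
    def add_host(host):
        return "Added host '{}'".format(host)

    @sketch_async_map_completion
    def remove_daemon(name_and_host):
        name, host = name_and_host
        return "Removed {} from {}".format(name, host)

    if __name__ == '__main__':
        c = add_host('node1').then(lambda msg: msg.upper())
        print(c.execute())   # ADDED HOST 'NODE1'
        m = remove_daemon([('osd.0', 'node1'), ('osd.1', 'node2')])
        print(m.execute())   # ['Removed osd.0 from node1', 'Removed osd.1 from node2']

The gain over the removed wrappers is that a caller such as create_osds()
(or the chained mgr-removal path) now describes the whole operation up
front, including follow-up steps via then(), and the worker pool runs it,
instead of each call blocking on its own AsyncResult.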