from functools import wraps
import string
+try:
+    from typing import List, Dict, Optional, Callable, TypeVar, Type, Any
+    # TypeVar must be defined inside the try block: it needs the import above
+    T = TypeVar('T')
+except ImportError:
+    pass  # just for type checking
+
import six
import os
import random
# multiple bootstrapping / initialization
-class SSHCompletionmMixin(object):
- def __init__(self, result):
- if isinstance(result, multiprocessing.pool.AsyncResult):
- self._result = [result]
- else:
- self._result = result
- assert isinstance(self._result, list)
-
- @property
- def result(self):
- return list(map(lambda r: r.get(), self._result))
+class AsyncCompletion(orchestrator.Completion[T]):
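+    """
+    A Completion whose ``on_complete`` runs in the module's worker pool;
+    the result is propagated from the pool's callback once it is ready.
+    """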
+ def __init__(self, *args, many=False, **kwargs):
+        self.__on_complete = None  # type: Optional[Callable[[T], Any]]
+ self.many = many
+ super(AsyncCompletion, self).__init__(*args, **kwargs)
+ def propagate_to_next(self):
+ # We don't have a synchronous result.
+ pass
-class SSHReadCompletion(SSHCompletionmMixin, orchestrator.ReadCompletion):
@property
- def has_result(self):
- return all(map(lambda r: r.ready(), self._result))
+ def _progress_reference(self):
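+        # a ProgressReference is a callable carrying a `progress_id`
+        # attribute; report on_complete as the progress reference if so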
+ if hasattr(self.__on_complete, 'progress_id'):
+ return self.__on_complete
+ return None
+ @property
+ def _on_complete(self):
+ # type: () -> Optional[Callable[[T], Any]]
+ if self.__on_complete is None:
+ return None
+
+ def callback(result):
+ if self._next_promise:
+ self._next_promise._value = result
+ else:
+ self._value = result
+ super(AsyncCompletion, self).propagate_to_next()
-class SSHWriteCompletion(SSHCompletionmMixin, orchestrator.WriteCompletion):
+        def run(value):
+            if self.many:
+                # each element of `value` is an argument tuple for
+                # __on_complete; starmap_async unpacks it (Python 3 only)
+                SSHOrchestrator.instance._worker_pool.starmap_async(self.__on_complete, value,
+                                                                    callback=callback)
+            else:
+                SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+                                                                  callback=callback)
- @property
- def has_result(self):
- return all(map(lambda r: r.ready(), self._result))
+ return run
- @property
- def is_effective(self):
- return all(map(lambda r: r.ready(), self._result))
+ @_on_complete.setter
+ def _on_complete(self, inner):
+ # type: (Callable[[T], Any]) -> None
+ self.__on_complete = inner
- @property
- def is_errored(self):
- for r in self._result:
- if not r.ready():
- return False
- if not r.successful():
- return True
- return False
+def ssh_completion(cls=AsyncCompletion, **c_kwargs):
+ # type: (Type[orchestrator.Completion], Any) -> Callable
+ """
+    Run the given function through `apply_async()` or `starmap_async()`.
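+
+    ``cls`` defaults to AsyncCompletion; with ``many=True`` the function is
+    mapped over a list of argument tuples in the worker pool.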
+ """
+ def decorator(f):
+ @wraps(f)
+        def wrapper(*args, **kwargs):
+            if c_kwargs.get('many', False):
+                # called as async_map_completion(f)(list_of_argument_tuples):
+                # seed the completion with the tuples so run() can map f over
+                # them (assumption: the promise machinery passes `value` on)
+                return cls(value=args[0], on_complete=f, **c_kwargs)
+            return cls(on_complete=lambda _: f(*args, **kwargs), **c_kwargs)
-class SSHWriteCompletionReady(SSHWriteCompletion):
- def __init__(self, result):
- orchestrator.WriteCompletion.__init__(self)
- self._result = result
+ return wrapper
+ return decorator
- @property
- def result(self):
- return self._result
- @property
- def has_result(self):
- return True
+def async_completion(f):
+ # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]]
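+    """
+    Run ``f`` once in the worker pool, with the call-time arguments.
+    """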
+ return ssh_completion()(f)
- @property
- def is_effective(self):
- return True
- @property
- def is_errored(self):
- return False
+def async_map_completion(f):
+ # type: (Callable[..., T]) -> Callable[..., AsyncCompletion[T]]
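+    """
+    Usage: ``async_map_completion(f)([(a, b), ...])`` runs ``f(a, b)`` in
+    the worker pool for each argument tuple and completes with the list
+    of results.
+    """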
+ return ssh_completion(many=True)(f)
-def log_exceptions(f):
- if six.PY3:
- return f
- else:
- # Python 2 does no exception chaining, thus the
- # real exception is lost
- @wraps(f)
- def wrapper(*args, **kwargs):
- try:
- return f(*args, **kwargs)
- except Exception:
- logger.exception('something went wrong.')
- raise
- return wrapper
+def trivial_completion(f):
+ # type: (Callable[..., T]) -> Callable[..., orchestrator.Completion[T]]
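+    """
+    Wrap ``f`` in a plain orchestrator.Completion; the worker pool is not
+    involved.
+    """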
+ return ssh_completion(cls=orchestrator.Completion)(f)
-class SSHOrchestrator(MgrModule, orchestrator.OrchestratorClientMixin):
+class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
_STORE_HOST_PREFIX = "host"
+
+    instance = None  # type: Optional["SSHOrchestrator"]
NATIVE_OPTIONS = []
MODULE_OPTIONS = [
{
self._reconfig_ssh()
+ SSHOrchestrator.instance = self
+ self.all_progress_references = list() # type: List[orchestrator.ProgressReference]
+
# load inventory
i = self.get_store('inventory')
if i:
def _get_hosts(self, wanted=None):
return self.inventory_cache.items_filtered(wanted)
+ @async_completion
def add_host(self, host):
"""
Add a host to be managed by the orchestrator.
:param host: host name
"""
- @log_exceptions
- def run(host):
- self.inventory[host] = {}
- self._save_inventory()
- self.inventory_cache[host] = orchestrator.OutdatableData()
- self.service_cache[host] = orchestrator.OutdatableData()
- return "Added host '{}'".format(host)
-
- return SSHWriteCompletion(
- self._worker_pool.apply_async(run, (host,)))
+ self.inventory[host] = {}
+ self._save_inventory()
+ self.inventory_cache[host] = orchestrator.OutdatableData()
+ self.service_cache[host] = orchestrator.OutdatableData()
+ return "Added host '{}'".format(host)
+ @async_completion
def remove_host(self, host):
"""
Remove a host from orchestrator management.
:param host: host name
"""
- @log_exceptions
- def run(host):
- del self.inventory[host]
- self._save_inventory()
- del self.inventory_cache[host]
- del self.service_cache[host]
- return "Removed host '{}'".format(host)
-
- return SSHWriteCompletion(
- self._worker_pool.apply_async(run, (host,)))
+ del self.inventory[host]
+ self._save_inventory()
+ del self.inventory_cache[host]
+ del self.service_cache[host]
+ return "Removed host '{}'".format(host)
+ @trivial_completion
def get_hosts(self):
"""
Return a list of hosts managed by the orchestrator.
Notes:
- skip async: manager reads from cache.
- """
- nodes = [
- orchestrator.InventoryNode(h,
- inventory.Devices([]),
- i.get('labels', []))
- for h, i in self.inventory.items()]
- return orchestrator.TrivialReadCompletion(nodes)
+ TODO:
+ - InventoryNode probably needs to be able to report labels
+ """
+        return [orchestrator.InventoryNode(host_name, inventory.Devices([]))
+                for host_name in self.inventory_cache]
+"""
def add_host_label(self, host, label):
if host not in self.inventory:
raise OrchestratorError('host %s does not exist' % host)
return SSHWriteCompletion(
self._worker_pool.apply_async(run, (host, label)))
+"""
def _refresh_host_services(self, host):
out, code = self._run_ceph_daemon(
node_name=None,
refresh=False):
-        hosts = []
- wait_for = []
+        wait_for_args = []
+        refresh_hosts = []
+        in_cache = []
+        cache_hosts = []
for host, host_info in self.service_cache.items_filtered():
-            hosts.append(host)
if host_info.outdated(self.service_cache_timeout) or refresh:
self.log.info("refresing stale services for '{}'".format(host))
- wait_for.append(
- SSHReadCompletion(self._worker_pool.apply_async(
- self._refresh_host_services, (host,))))
+                refresh_hosts.append(host)
+                wait_for_args.append((host,))
else:
self.log.debug('have recent services for %s: %s' % (
host, host_info.data))
- wait_for.append(
- orchestrator.TrivialReadCompletion([host_info.data]))
- self._orchestrator_wait(wait_for)
-
- services = {}
- for host, c in zip(hosts, wait_for):
- services[host] = c.result[0]
-
- result = []
- for host, ls in services.items():
- for d in ls:
- if not d['style'].startswith('ceph-daemon'):
- self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d))
- continue
- if d['fsid'] != self._cluster_fsid:
- self.log.debug('ignoring foreign daemon on %s: %s' % (host, d))
- continue
- self.log.debug('including %s' % d)
- sd = orchestrator.ServiceDescription()
- sd.service_type = d['name'].split('.')[0]
- if service_type and service_type != sd.service_type:
- continue
- if '.' in d['name']:
- sd.service_instance = '.'.join(d['name'].split('.')[1:])
- else:
- sd.service_instance = host # e.g., crash
- if service_id and service_id != sd.service_instance:
- continue
- if service_name and not sd.service_instance.startswith(service_name + '.'):
- continue
- sd.nodename = host
- sd.container_id = d.get('container_id')
- sd.container_image_name = d.get('container_image_name')
- sd.container_image_id = d.get('container_image_id')
- sd.version = d.get('version')
- sd.status_desc = d['state']
- sd.status = {
- 'running': 1,
- 'stopped': 0,
- 'error': -1,
- 'unknown': -1,
- }[d['state']]
- result.append(sd)
- return result
+                cache_hosts.append(host)
+                in_cache.append(host_info.data)
+
+        def _get_services_result(results):
+            # pair each refreshed result / cached entry with its host; the
+            # two host lists were built in the same order above
+            services = {}
+            for host, data in zip(refresh_hosts + cache_hosts, results + in_cache):
+                services[host] = data
+
+ result = []
+ for host, ls in services.items():
+ for d in ls:
+ if not d['style'].startswith('ceph-daemon'):
+ self.log.debug('ignoring non-ceph-daemon on %s: %s' % (host, d))
+ continue
+ if d['fsid'] != self._cluster_fsid:
+ self.log.debug('ignoring foreign daemon on %s: %s' % (host, d))
+ continue
+ self.log.debug('including %s' % d)
+ sd = orchestrator.ServiceDescription()
+ sd.service_type = d['name'].split('.')[0]
+ if service_type and service_type != sd.service_type:
+ continue
+ if '.' in d['name']:
+ sd.service_instance = '.'.join(d['name'].split('.')[1:])
+ else:
+ sd.service_instance = host # e.g., crash
+ if service_id and service_id != sd.service_instance:
+ continue
+ if service_name and not sd.service_instance.startswith(service_name + '.'):
+ continue
+ sd.nodename = host
+ sd.container_id = d.get('container_id')
+ sd.container_image_name = d.get('container_image_name')
+ sd.container_image_id = d.get('container_image_id')
+ sd.version = d.get('version')
+ sd.status_desc = d['state']
+ sd.status = {
+ 'running': 1,
+ 'stopped': 0,
+ 'error': -1,
+ 'unknown': -1,
+ }[d['state']]
+ result.append(sd)
+ return result
+
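+        # refresh the stale hosts in the worker pool, then merge the results
+        # with the cached data and filter everything into ServiceDescriptions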
+ return async_map_completion(self._refresh_host_services)(wait_for_args).then(
+ _get_services_result)
+
def describe_service(self, service_type=None, service_id=None,
node_name=None, refresh=False):
service_type,
service_name=service_name,
service_id=service_id)
- results = []
+ args = []
for d in daemons:
- results.append(self._worker_pool.apply_async(
- self._service_action, (d.service_type, d.service_instance,
- d.nodename, action)))
- if not results:
- if service_name:
- n = service_name + '-*'
- else:
- n = service_id
- raise OrchestratorError(
- 'Unable to find %s.%s daemon(s)' % (
- service_type, n))
- return SSHWriteCompletion(results)
+ args.append((d.service_type, d.service_instance,
+ d.nodename, action))
+        if not args:
+            n = service_name + '-*' if service_name else service_id
+            raise orchestrator.OrchestratorError(
+                'Unable to find %s.%s daemon(s)' % (service_type, n))
+ return async_map_completion(self._service_action)(args)
def _service_action(self, service_type, service_id, host, action):
if action == 'redeploy':
# this implies the returned hosts are registered
hosts = self._get_hosts()
- @log_exceptions
- def run(host, host_info):
- # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
+        def run(host, host_info):
+            # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
if host_info.outdated(self.inventory_cache_timeout) or refresh:
self.log.info("refresh stale inventory for '{}'".format(host))
devices = inventory.Devices.from_json(host_info.data)
return orchestrator.InventoryNode(host, devices)
- results = []
- for key, host_info in hosts:
- result = self._worker_pool.apply_async(run, (key, host_info))
- results.append(result)
-
- return SSHReadCompletion(results)
+        # hosts is a list of (host, host_info) pairs; the map completion
+        # unpacks each pair into run()'s two arguments
+        return async_map_completion(run)(hosts)
- @log_exceptions
def blink_device_light(self, ident_fault, on, locs):
-        # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> SSHWriteCompletion
+        # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> AsyncCompletion
-
def blink(host, dev, ident_fault_, on_):
# type: (str, str, str, bool) -> str
cmd = [
return "Set %s light for %s:%s %s" % (
ident_fault_, host, dev, 'on' if on_ else 'off')
- results = []
- for loc in locs:
- results.append(
- self._worker_pool.apply_async(
- blink,
- (loc.host, loc.dev, ident_fault, on)))
- return SSHWriteCompletion(results)
+        # build explicit argument tuples: a DeviceLightLoc's fields alone do
+        # not match blink()'s four-argument signature
+        return async_map_completion(blink)(
+            [(loc.host, loc.dev, ident_fault, on) for loc in locs])
+
+ @async_completion
+ def _create_osd(self, all_hosts_, drive_group):
+ all_hosts = orchestrator.InventoryNode.get_host_names(all_hosts_)
+ assert len(drive_group.hosts(all_hosts)) == 1
+ assert len(drive_group.data_devices.paths) > 0
+ assert all(map(lambda p: isinstance(p, six.string_types),
+ drive_group.data_devices.paths))
+
+ host = drive_group.hosts(all_hosts)[0]
+ self._require_hosts(host)
+
- @log_exceptions
- def _create_osd(self, host, drive_group):
# get bootstrap key
ret, keyring, err = self.mon_command({
'prefix': 'auth get',
return "Created osd(s) on host '{}'".format(host)
- def create_osds(self, drive_group, all_hosts=None):
+ def create_osds(self, drive_group):
"""
Create a new osd.
- support full drive_group specification
- support batch creation
"""
- assert len(drive_group.hosts(all_hosts)) == 1
- assert len(drive_group.data_devices.paths) > 0
- assert all(map(lambda p: isinstance(p, six.string_types),
- drive_group.data_devices.paths))
-
- host = drive_group.hosts(all_hosts)[0]
- self._require_hosts(host)
- result = self._worker_pool.apply_async(self._create_osd, (host,
- drive_group))
-
- return SSHWriteCompletion(result)
+        return self.get_hosts().then(
+            lambda hosts: self._create_osd(hosts, drive_group))
def remove_osds(self, name):
daemons = self._get_services('osd', service_id=name)
- results = []
- for d in daemons:
- results.append(self._worker_pool.apply_async(
- self._remove_daemon,
- ('osd.%s' % d.service_instance, d.nodename)))
- if not results:
- raise OrchestratorError('Unable to find osd.%s' % name)
- return SSHWriteCompletion(results)
+        args = [('osd.%s' % d.service_instance, d.nodename) for d in daemons]
+        if not args:
+            raise orchestrator.OrchestratorError('Unable to find osd.%s' % name)
+        return async_map_completion(self._remove_daemon)(args)
def _create_daemon(self, daemon_type, daemon_id, host, keyring,
extra_args=[]):
mon_map = self.get("mon_map")
num_mons = len(mon_map["mons"])
if num == num_mons:
- return SSHWriteCompletionReady("The requested number of monitors exist.")
+ return orchestrator.Completion(value="The requested number of monitors exist.")
if num < num_mons:
raise NotImplementedError("Removing monitors is not supported.")
# TODO: we may want to chain the creation of the monitors so they join
# the quorum one at a time.
- results = []
- for host, network, name in host_specs:
- result = self._worker_pool.apply_async(self._create_mon,
- (host, network, name))
-
- results.append(result)
-
- return SSHWriteCompletion(results)
+ return async_map_completion(self._create_mon)(host_specs)
def _create_mgr(self, host, name):
"""
daemons = self._get_services('mgr')
num_mgrs = len(daemons)
if num == num_mgrs:
- return SSHWriteCompletionReady("The requested number of managers exist.")
+ return orchestrator.Completion(value="The requested number of managers exist.")
self.log.debug("Trying to update managers on: {}".format(host_specs))
# check that all the hosts are registered
connected.append(mgr_map.get('active_name', ''))
for standby in mgr_map.get('standbys', []):
connected.append(standby.get('name', ''))
+        to_remove_daemons = []
+        to_remove_mgr = []
for d in daemons:
if d.service_instance not in connected:
- result = self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
+                to_remove_daemons.append(('%s.%s' % (d.service_type, d.service_instance),
d.nodename))
- results.append(result)
num_to_remove -= 1
if num_to_remove == 0:
break
# otherwise, remove *any* mgr
if num_to_remove > 0:
for daemon in daemons:
- result = self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
- d.nodename))
- results.append(result)
+                to_remove_mgr.append(('%s.%s' % (daemon.service_type, daemon.service_instance),
+                                      daemon.nodename))
num_to_remove -= 1
if num_to_remove == 0:
break
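+        # remove the disconnected daemons first, then any remaining surplus
+        # mgrs; concatenate both removal results into a single completion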
+        return async_map_completion(self._remove_daemon)(to_remove_daemons).then(
+ lambda remove_daemon_result: async_map_completion(self._remove_daemon)(to_remove_mgr).then(
+ lambda remove_mgr_result: remove_daemon_result + remove_mgr_result
+ )
+ )
else:
# we assume explicit placement by which there are the same number of
-            result = self._worker_pool.apply_async(self._create_mgr,
-                                                   (host_spec.hostname, name))
-            results.append(result)
+
+ args = []
+ for host_spec in host_specs:
+ name = host_spec.name or self.get_unique_name(daemons)
+ host = host_spec.hostname
+ args.append((host, name))
+ return async_map_completion(self._create_mgr)(args)
+
-        return SSHWriteCompletion(results)
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
raise RuntimeError("must specify at least %d hosts" % spec.count)
daemons = self._get_services('mds')
- results = []
+ args = []
num_added = 0
for host, _, name in spec.placement.nodes:
if num_added >= spec.count:
break
mds_id = self.get_unique_name(daemons, spec.name, name)
self.log.debug('placing mds.%s on host %s' % (mds_id, host))
- results.append(
- self._worker_pool.apply_async(self._create_mds, (mds_id, host))
- )
+ args.append((mds_id, host))
# add to daemon list so next name(s) will also be unique
sd = orchestrator.ServiceDescription()
sd.service_instance = mds_id
sd.nodename = host
daemons.append(sd)
num_added += 1
- return SSHWriteCompletion(results)
+ return async_map_completion(self._create_mds)(args)
def update_mds(self, spec):
return self._update_service('mds', self.add_mds, spec)
def remove_mds(self, name):
daemons = self._get_services('mds')
- results = []
self.log.debug("Attempting to remove volume: {}".format(name))
- for d in daemons:
- if d.service_instance == name or d.service_instance.startswith(name + '.'):
- results.append(self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
- d.nodename)))
- if not results:
- raise OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
- return SSHWriteCompletion(results)
+        args = [('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+                for d in daemons
+                if d.service_instance == name or d.service_instance.startswith(name + '.')]
+        if args:
+            return async_map_completion(self._remove_daemon)(args)
+        raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
def add_rgw(self, spec):
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
'value': spec.name,
})
daemons = self._get_services('rgw')
- results = []
+ args = []
num_added = 0
for host, _, name in spec.placement.nodes:
if num_added >= spec.count:
break
rgw_id = self.get_unique_name(daemons, spec.name, name)
self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
- results.append(
- self._worker_pool.apply_async(self._create_rgw, (rgw_id, host))
- )
+ args.append((rgw_id, host))
# add to daemon list so next name(s) will also be unique
sd = orchestrator.ServiceDescription()
sd.service_instance = rgw_id
sd.nodename = host
daemons.append(sd)
num_added += 1
- return SSHWriteCompletion(results)
+ return async_map_completion(self._create_rgw)(args)
def _create_rgw(self, rgw_id, host):
ret, keyring, err = self.mon_command({
def remove_rgw(self, name):
daemons = self._get_services('rgw')
- results = []
+
+ args = []
for d in daemons:
if d.service_instance == name or d.service_instance.startswith(name + '.'):
- results.append(self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
- d.nodename)))
- if not results:
- raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
- return SSHWriteCompletion(results)
+ args.append(('%s.%s' % (d.service_type, d.service_instance),
+ d.nodename))
+ if args:
+ return async_map_completion(self._remove_daemon)(args)
+        raise orchestrator.OrchestratorError('Unable to find rgw.%s[-*] daemon(s)' % name)
def update_rgw(self, spec):
- return self._update_service('rgw', self.add_rgw, spec)
+ daemons = self._get_services('rgw', service_name=spec.name)
+ if len(daemons) > spec.count:
+ # remove some
+ to_remove = len(daemons) - spec.count
+            # reuse _remove_daemon; there is no dedicated rgw removal helper
+            args = [('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+                    for d in daemons[0:to_remove]]
+            return async_map_completion(self._remove_daemon)(args)
+        elif len(daemons) < spec.count:
+            # add some
+            spec.count -= len(daemons)
+            return self.add_rgw(spec)
+        # counts already match
+        return orchestrator.Completion(value="The requested number of rgw daemons exist.")
def add_rbd_mirror(self, spec):
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count: