helper to build orchestrator modules.
"""
INITIALIZED = 1 # We have a parent completion and a next completion
- FINISHED = 2 # we have a final result
+ RUNNING = 2
+ FINISHED = 3 # we have a final result
NO_RESULT = _no_result() # type: None
+ ASYNC_RESULT = object()
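+ # Sentinel returned by ``on_complete`` to signal that the result will be
+ # delivered later by a worker-thread callback (see AsyncCompletion); the
+ # promise then stays in RUNNING instead of finishing synchronously.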
def __init__(self,
_first_promise=None, # type: Optional["_Promise"]
value=NO_RESULT, # type: Optional
- on_complete=None # type: Optional[Callable]
+ on_complete=None, # type: Optional[Callable]
+ name=None, # type: Optional[str]
):
self._on_complete = on_complete
+ self._name = name
self._next_promise = None # type: Optional[_Promise]
self._state = self.INITIALIZED
self._first_promise = _first_promise or self # type: 'Completion'
def __repr__(self):
- name = getattr(self._on_complete, '__name__', '??') if self._on_complete else 'None'
+ name = self._name or (getattr(self._on_complete, '__name__', '??') if self._on_complete else 'None')
val = repr(self._value) if self._value is not self.NO_RESULT else 'NA'
- return '{}(_s={}, val={}, id={}, name={}, pr={}, _next={})'.format(
- self.__class__, self._state, val, id(self), name, getattr(next, '_progress_reference', 'NA'), repr(self._next_promise)
+ return '{}(_s={}, val={}, _on_c={}, id={}, name={}, pr={}, _next={})'.format(
+ self.__class__, self._state, val, self._on_complete, id(self), name, getattr(self, '_progress_reference', 'NA'), repr(self._next_promise)
)
def then(self, on_complete):
"""
Call ``on_complete`` as soon as this promise is finalized.
"""
- assert self._state is self.INITIALIZED
+ assert self._state in (self.INITIALIZED, self.RUNNING)
if self._on_complete is not None:
assert self._next_promise is None
self._set_next_promise(self.__class__(
def _set_next_promise(self, next):
# type: (_Promise) -> None
assert self is not next
- assert self._state is self.INITIALIZED
+ assert self._state in (self.INITIALIZED, self.RUNNING)
self._next_promise = next
assert self._next_promise is not None
for p in iter(self._next_promise):
p._first_promise = self._first_promise
- def finalize(self, value=NO_RESULT):
+ def _finalize(self, value=NO_RESULT):
"""
Sets this promise to complete.
:param value: new value.
"""
- assert self._state is self.INITIALIZED
+ assert self._state in (self.INITIALIZED, self.RUNNING)
+
+ self._state = self.RUNNING
if value is not self.NO_RESULT:
self._value = value
- assert self._value is not self.NO_RESULT
+ assert self._value is not self.NO_RESULT, repr(self)
if self._on_complete:
try:
self._set_next_promise(next_result)
if self._next_promise._value is self.NO_RESULT:
self._next_promise._value = self._value
- else:
+ self.propagate_to_next()
+ elif next_result is not self.ASYNC_RESULT:
# simple map. simply forward
if self._next_promise:
self._next_promise._value = next_result
else:
# Hack: next_result is of type U, _value is of type T
self._value = next_result # type: ignore
- self._state = self.FINISHED
- logger.debug('finalized {}'.format(repr(self)))
- self.propagate_to_next()
+ self.propagate_to_next()
+ else:
+ # asynchronous promise: on_complete returned ASYNC_RESULT, so a
+ # worker-thread callback is responsible for finalizing us later
+ pass
+
def propagate_to_next(self):
- assert self._state is self.FINISHED
+ self._state = self.FINISHED
+ logger.debug('finalized {}'.format(repr(self)))
if self._next_promise:
- self._next_promise.finalize()
+ self._next_promise._finalize()
def fail(self, e):
# type: (Exception) -> None
Sets the whole completion to be failed with this exception and ends the
evaluation.
"""
- assert self._state is self.INITIALIZED
+ assert self._state in (self.INITIALIZED, self.RUNNING)
logger.exception('_Promise failed')
self._exception = e
self._value = 'exception'
return self.progress == 1 and self._completion_has_result
def update(self):
- def run(progress):
+ def progress_run(progress):
self.progress = progress
if self.completion:
- c = self.completion().then(run)
+ c = self.completion().then(progress_run)
self.mgr.process([c._first_promise])
else:
self.progress = 1
def __init__(self,
_first_promise=None, # type: Optional["Completion"]
value=_Promise.NO_RESULT, # type: Any
- on_complete=None # type: Optional[Callable]
+ on_complete=None, # type: Optional[Callable]
+ name=None, # type: Optional[str]
):
- super(Completion, self).__init__(_first_promise, value, on_complete)
+ super(Completion, self).__init__(_first_promise, value, on_complete, name)
@property
def _progress_reference(self):
if self._progress_reference:
self._progress_reference.fail()
+ def finalize(self, result=_Promise.NO_RESULT):
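+ # Only the first call has any effect: once the chain has left INITIALIZED,
+ # repeated finalize() calls (process() may be invoked several times) are
+ # no-ops.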
+ if self._first_promise._state == self.INITIALIZED:
+ self._first_promise._finalize(result)
+
@property
def result(self):
"""
raise NotImplementedError()
def update_mons(self, num, hosts):
- # type: (int, List[Tuple[str,str]]) -> Completion
+ # type: (int, List[Tuple[str,str,str]]) -> Completion
"""
Update the number of cluster monitors.
raise NotImplementedError()
def add_rbd_mirror(self, spec):
- # type: (StatelessServiceSpec) -> WriteCompletion
+ # type: (StatelessServiceSpec) -> Completion
"""Create rbd-mirror cluster"""
raise NotImplementedError()
- def remove_rbd_mirror(self):
- # type: (str) -> WriteCompletion
+ def remove_rbd_mirror(self, name):
+ # type: (str) -> Completion
"""Remove rbd-mirror cluster"""
raise NotImplementedError()
def update_rbd_mirror(self, spec):
- # type: (StatelessServiceSpec) -> WriteCompletion
+ # type: (StatelessServiceSpec) -> Completion
"""
Update / redeploy rbd-mirror cluster
Like for example changing the number of service instances.
# type: (List[InventoryNode]) -> List[str]
return [node.name for node in nodes]
+ def __eq__(self, other):
+ return self.name == other.name and self.devices == other.devices
+
class DeviceLightLoc(namedtuple('DeviceLightLoc', ['host', 'dev'])):
"""
class RookCompletion(orchestrator.Completion):
def evaluate(self):
- self._first_promise.finalize(None)
+ self.finalize(None)
def deferred_read(f):
import orchestrator
if what == 'OrchestratorError':
c = orchestrator.TrivialReadCompletion(result=None)
- c.exception = orchestrator.OrchestratorError('hello', 'world')
+ c.fail(orchestrator.OrchestratorError('hello', 'world'))
return c
elif what == "ZeroDivisionError":
c = orchestrator.TrivialReadCompletion(result=None)
- c.exception = ZeroDivisionError('hello', 'world')
+ c.fail(ZeroDivisionError('hello', 'world'))
return c
assert False, repr(what)
class AsyncCompletion(orchestrator.Completion):
- def __init__(self, *args, **kwargs):
- self.__on_complete = None # type: Callable
- self.many = kwargs.pop('many', False)
- super(AsyncCompletion, self).__init__(*args, **kwargs)
-
- def propagate_to_next(self):
- # We don't have a synchronous result.
- pass
+ def __init__(self,
+ _first_promise=None, # type: Optional["Completion"]
+ value=orchestrator._Promise.NO_RESULT, # type: Any
+ on_complete=None, # type: Optional[Callable]
+ name=None, # type: Optional[str]
+ many=False, # type: bool
+ ):
+
+ assert SSHOrchestrator.instance is not None
+ self.many = many
+ if name is None and on_complete is not None:
+ name = on_complete.__name__
+ super(AsyncCompletion, self).__init__(_first_promise, value, on_complete, name)
@property
def _progress_reference(self):
return None
def callback(result):
- if self._next_promise:
- self._next_promise._value = result
- else:
- self._value = result
- super(AsyncCompletion, self).propagate_to_next()
+ try:
+ self.__on_complete = None
+ self._finalize(result)
+ except Exception as e:
+ self.fail(e)
+
+ def error_callback(e):
+ self.fail(e)
def run(value):
if self.many:
- SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
- callback=callback)
+ if not value:
+ # map_async() never fires its callback for an empty iterable,
+ # so finish synchronously and skip the pool entirely.
+ logger.info('calling map_async without values')
+ callback([])
+ return self.ASYNC_RESULT
+ if six.PY3:
+ SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
+ callback=callback,
+ error_callback=error_callback)
+ else:
+ SSHOrchestrator.instance._worker_pool.map_async(self.__on_complete, value,
+ callback=callback)
else:
- SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
- callback=callback)
+ if six.PY3:
+ SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+ callback=callback, error_callback=error_callback)
+ else:
+ SSHOrchestrator.instance._worker_pool.apply_async(self.__on_complete, (value,),
+ callback=callback)
+ return self.ASYNC_RESULT
return run
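+ # run() hands the actual work to the worker pool and returns ASYNC_RESULT,
+ # so _finalize() leaves this promise in RUNNING; callback()/error_callback()
+ # later complete or fail it from the pool thread.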
"""
def decorator(f):
@wraps(f)
- def wrapper(*args, **kwargs):
- return cls(on_complete=lambda _: f(*args, **kwargs), **c_kwargs)
+ def wrapper(*args):
+
+ name = f.__name__
+ many = c_kwargs.get('many', False)
+
+ # Some logic to make calling functions with multiple arguments work.
+ if len(args) == 1:
+ [value] = args
+ if many and value and isinstance(value[0], tuple):
+ return cls(on_complete=lambda x: f(*x), value=value, name=name, **c_kwargs)
+ else:
+ return cls(on_complete=f, value=value, name=name, **c_kwargs)
+ else:
+ if many:
+ self, value = args
+
+ def call_self(inner_args):
+ if not isinstance(inner_args, tuple):
+ inner_args = (inner_args, )
+ return f(self, *inner_args)
+
+ return cls(on_complete=call_self, value=value, name=name, **c_kwargs)
+ else:
+ return cls(on_complete=lambda x: f(*x), value=args, name=name, **c_kwargs)
+
return wrapper
return decorator
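+ # A sketch of the call shapes the wrapper above accepts (the decorated
+ # names here are purely illustrative):
+ #
+ #     @ssh_completion(many=True)
+ #     def _do(self, x): ...
+ #
+ #     self._do([1, 2])    # -> _do(self, 1), _do(self, 2) on the pool
+ #     self._do([(1, 2)])  # -> _do(self, 1, 2): tuple items are unpacked
+ #
+ #     @ssh_completion()
+ #     def free(a, b): ...
+ #
+ #     free(1, 2)          # -> all positional args forwarded in one call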
def async_map_completion(f):
# type: (Callable) -> Callable[..., AsyncCompletion]
+ """
+ kind of similar to
+
+ >>> def sync_map(f):
+ ... return lambda x: map(f, x)
+
+ Limitation: This does not work, as you cannot return completions from `f`
+
+ >>> @async_map_completion
+ ... def run(x):
+ ... return async_completion(str)(x)
+ """
return ssh_completion(many=True)(f)
return ssh_completion(cls=orchestrator.Completion)(f)
+def trivial_result(val):
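+ # Wraps an immediate value, e.g. ``trivial_result(["Reload is a no-op"])``;
+ # the returned completion only needs to be processed to expose ``val``.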
+ return AsyncCompletion(value=val, name='trivial_result')
+
+
class SSHOrchestrator(MgrModule, orchestrator.Orchestrator):
_STORE_HOST_PREFIX = "host"
try:
with open(path, 'r') as f:
self._ceph_daemon = f.read()
- except IOError as e:
+ except (IOError, TypeError) as e:
raise RuntimeError("unable to read ceph-daemon at '%s': %s" % (
path, str(e)))
if h not in self.inventory:
del self.service_cache[h]
+ def shutdown(self):
+ self.log.info('ssh: shutdown')
+ self._worker_pool.close()
+ self._worker_pool.join()
+ self._worker_pool = None
+
def config_notify(self):
"""
This method is called whenever one of our config options is changed.
"""
Kicks off the given completions, which are then processed in another thread.
"""
+ if completions:
+ self.log.info("wait: completions={0}".format(completions))
+
+ for p in completions:
+ p.finalize()
def _require_hosts(self, hosts):
"""
TODO:
- InventoryNode probably needs to be able to report labels
"""
- nodes = [orchestrator.InventoryNode(host_name, []) for host_name in self.inventory_cache]
+ return [orchestrator.InventoryNode(host_name) for host_name in self.inventory_cache]
+
"""
def add_host_label(self, host, label):
if host not in self.inventory:
self._worker_pool.apply_async(run, (host, label)))
"""
+ @async_map_completion
def _refresh_host_services(self, host):
out, code = self._run_ceph_daemon(
host, 'mon', 'ls', [], no_fsid=True)
data = json.loads(''.join(out))
self.log.debug('refreshed host %s services: %s' % (host, data))
self.service_cache[host] = orchestrator.OutdatableData(data)
return data
host, host_info.data))
in_cache.append(host_info.data)
- def _get_services_result(self, results):
+ def _get_services_result(results):
services = {}
- for host, c in zip(hosts, results + in_cache):
- services[host] = c.result[0]
+ for host, data in zip(hosts, results + in_cache):
+ services[host] = data
result = []
for host, ls in services.items():
result.append(sd)
return result
- return async_map_completion(self._refresh_host_services)(wait_for_args).then(
+ return self._refresh_host_services(wait_for_args).then(
_get_services_result)
service_id=service_id,
node_name=node_name,
refresh=refresh)
- return orchestrator.TrivialReadCompletion(result)
+ return result
def service_action(self, action, service_type,
service_name=None,
self.log.debug('service_action action %s type %s name %s id %s' % (
action, service_type, service_name, service_id))
if action == 'reload':
- return orchestrator.TrivialReadCompletion(
- ["Reload is a no-op"])
+ return trivial_result(["Reload is a no-op"])
daemons = self._get_services(
service_type,
service_name=service_name,
args.append((d.service_type, d.service_instance,
d.nodename, action))
if not args:
- n = service_name
- if n:
- n += '-*'
- raise orchestrator.OrchestratorError('Unable to find %s.%s daemon(s)' % (
- service_type, n))
- return async_map_completion(self._service_action)(args)
+ if service_name:
+ n = service_name + '-*'
+ else:
+ n = service_id
+ raise orchestrator.OrchestratorError(
+ 'Unable to find %s.%s daemon(s)' % (
+ service_type, n))
+ return self._service_action(args)
+ @async_map_completion
def _service_action(self, service_type, service_id, host, action):
if action == 'redeploy':
# recreate the systemd unit and then restart
# this implies the returned hosts are registered
hosts = self._get_hosts()
- def run(host_info):
- # type: (orchestrator.OutdatableData) -> orchestrator.InventoryNode
- host = host_info.data['name']
+ @async_map_completion
+ def _get_inventory(host, host_info):
+ # type: (str, orchestrator.OutdatableData) -> orchestrator.InventoryNode
if host_info.outdated(self.inventory_cache_timeout) or refresh:
self.log.info("refresh stale inventory for '{}'".format(host))
devices = inventory.Devices.from_json(host_info.data)
return orchestrator.InventoryNode(host, devices)
- return async_map_completion(run)(hosts.values())
+ return _get_inventory(hosts)
def blink_device_light(self, ident_fault, on, locs):
- # type: (str, bool, List[orchestrator.DeviceLightLoc]) -> SSHWriteCompletion
- def blink(host, dev, ident_fault_, on_):
- # type: (str, str, str, bool) -> str
+ @async_map_completion
+ def blink(host, dev):
cmd = [
'lsmcli',
'local-disk-%s-led-%s' % (
- ident_fault_,
- 'on' if on_ else 'off'),
+ ident_fault,
+ 'on' if on else 'off'),
'--path', '/dev/' + dev,
]
out, code = self._run_ceph_daemon(host, 'osd', 'shell', ['--'] + cmd,
if code:
raise RuntimeError(
'Unable to affect %s light for %s:%s. Command: %s' % (
- ident_fault_, host, dev, ' '.join(cmd)))
+ ident_fault, host, dev, ' '.join(cmd)))
return "Set %s light for %s:%s %s" % (
- ident_fault_, host, dev, 'on' if on_ else 'off')
+ ident_fault, host, dev, 'on' if on else 'off')
- return async_map_completion(blink)(locs)
+ return blink(locs)
@async_completion
def _create_osd(self, all_hosts_, drive_group):
- support batch creation
"""
- return self.get_hosts().then(self._create_osd)
+ return self.get_hosts().then(lambda hosts: self._create_osd(hosts, drive_group))
def remove_osds(self, name):
daemons = self._get_services('osd', service_id=name)
args = [('osd.%s' % d.service_instance, d.nodename) for d in daemons]
- return async_map_completion(self._remove_daemon)(args)
+ if not args:
+ raise orchestrator.OrchestratorError('Unable to find osd.%s' % name)
+ return self._remove_daemon(args)
def _create_daemon(self, daemon_type, daemon_id, host, keyring,
extra_args=[]):
self.log.info("create_daemon({}): finished".format(host))
conn.exit()
+ @async_map_completion
def _remove_daemon(self, name, host):
"""
Remove a daemon
return "Removed {} from host '{}'".format(name, host)
def _update_service(self, daemon_type, add_func, spec):
- daemons = self._get_services(daemon_type, service_name=spec.name)
- results = []
- if len(daemons) > spec.count:
- # remove some
- to_remove = len(daemons) - spec.count
- for d in daemons[0:to_remove]:
- results.append(self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
- d.nodename)))
- elif len(daemons) < spec.count:
- # add some
- spec.count -= len(daemons)
- return add_func(spec)
- return SSHWriteCompletion(results)
-
+ def ___update_service(daemons):
+ if len(daemons) > spec.count:
+ # remove some
+ to_remove = len(daemons) - spec.count
+ args = []
+ for d in daemons[0:to_remove]:
+ args.append(
+ ('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+ )
+ return self._remove_daemon(args)
+ elif len(daemons) < spec.count:
+ # add some
+ spec.count -= len(daemons)
+ return add_func(spec)
+ return []
+ return self._get_services(daemon_type, service_name=spec.name).then(___update_service)
+
+ @async_map_completion
def _create_mon(self, host, network, name):
"""
Create a new monitor on the given host.
# TODO: we may want to chain the creation of the monitors so they join
# the quorum one at a time.
- return async_map_completion(self._create_mon)(host_specs)
+ return self._create_mon(host_specs)
+ @async_map_completion
def _create_mgr(self, host, name):
"""
Create a new manager instance on a host.
"""
Adjust the number of cluster managers.
"""
- daemons = self._get_services('mgr')
+ return self._get_services('mgr').then(lambda daemons: self._update_mgrs(num, host_specs, daemons))
+
+ def _update_mgrs(self, num, host_specs, daemons):
num_mgrs = len(daemons)
if num == num_mgrs:
return orchestrator.Completion(value="The requested number of managers exist.")
for standby in mgr_map.get('standbys', []):
connected.append(standby.get('name', ''))
to_remove_daemons = []
- to_remove_mgr = []
for d in daemons:
if d.service_instance not in connected:
to_remove_daemons.append(('%s.%s' % (d.service_type, d.service_instance),
# otherwise, remove *any* mgr
if num_to_remove > 0:
- for daemon in daemons:
- to_remove_mgr.append((('%s.%s' % (d.service_type, d.service_instance), daemon.nodename))
+ for d in daemons:
+ to_remove_daemons.append(('%s.%s' % (d.service_type, d.service_instance), d.nodename))
num_to_remove -= 1
if num_to_remove == 0:
break
- return async_map_completion(self._remove_daemon)(to_remove_damons).then(
- lambda remove_daemon_result: async_map_completion(self._remove_daemon)(to_remove_mgr).then(
- lambda remove_mgr_result: remove_daemon_result + remove_mgr_result
- )
- )
+ return self._remove_daemon(to_remove_daemons)
else:
# we assume explicit placement by which there are the same number of
self.log.info("creating {} managers on hosts: '{}'".format(
num_new_mgrs, ",".join([spec.hostname for spec in host_specs])))
- for host_spec in host_specs:
- name = host_spec.name or self.get_unique_name(daemons)
- result = self._worker_pool.apply_async(self._create_mgr,
- (host_spec.hostname, name))
- results.append(result)
-
args = []
for host_spec in host_specs:
- name = host_spec.name or self.get_unique_name(daemons)
- host = host_spec.hostname
+ name = host_spec.name or self.get_unique_name(daemons)
+ host = host_spec.hostname
args.append((host, name))
- return async_map_completion(self._create_mgr)(args)
-
-
- return SSHWriteCompletion(results)
+ return self._create_mgr(args)
def add_mds(self, spec):
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
raise RuntimeError("must specify at least %d hosts" % spec.count)
- daemons = self._get_services('mds')
+ return self._get_services('mds').then(lambda ds: self._add_mds(ds, spec))
+
+ def _add_mds(self, daemons, spec):
args = []
num_added = 0
for host, _, name in spec.placement.nodes:
sd.nodename = host
daemons.append(sd)
num_added += 1
- return async_map_completion(self._create_mds)(args)
+ return self._create_mds(args)
def update_mds(self, spec):
return self._update_service('mds', self.add_mds, spec)
+ @async_map_completion
def _create_mds(self, mds_id, host):
# get mgr. key
ret, keyring, err = self.mon_command({
return self._create_daemon('mds', mds_id, host, keyring)
def remove_mds(self, name):
- daemons = self._get_services('mds')
self.log.debug("Attempting to remove volume: {}".format(name))
- if daemons:
- args = [('%s.%s' % (d.service_type, d.service_instance),
- d.nodename) for d in daemons]
- return async_map_completion(self._remove_daemon)(args)
- raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
+ def _remove_mds(daemons):
+ args = []
+ for d in daemons:
+ if d.service_instance == name or d.service_instance.startswith(name + '.'):
+ args.append(
+ ('%s.%s' % (d.service_type, d.service_instance), d.nodename)
+ )
+ if not args:
+ raise orchestrator.OrchestratorError('Unable to find mds.%s[-*] daemon(s)' % name)
+ return self._remove_daemon(args)
+ return self._get_services('mds').then(_remove_mds)
def add_rgw(self, spec):
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
'name': 'rgw_zone',
'value': spec.name,
})
- daemons = self._get_services('rgw')
- args = []
- num_added = 0
- for host, _, name in spec.placement.nodes:
- if num_added >= spec.count:
- break
- rgw_id = self.get_unique_name(daemons, spec.name, name)
- self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
- args.append((rgw_id, host))
- # add to daemon list so next name(s) will also be unique
- sd = orchestrator.ServiceDescription()
- sd.service_instance = rgw_id
- sd.service_type = 'rgw'
- sd.nodename = host
- daemons.append(sd)
- num_added += 1
- return async_map_completion(self._create_rgw)(args)
+ def _add_rgw(daemons):
+ args = []
+ num_added = 0
+ for host, _, name in spec.placement.nodes:
+ if num_added >= spec.count:
+ break
+ rgw_id = self.get_unique_name(daemons, spec.name, name)
+ self.log.debug('placing rgw.%s on host %s' % (rgw_id, host))
+ args.append((rgw_id, host))
+ # add to daemon list so next name(s) will also be unique
+ sd = orchestrator.ServiceDescription()
+ sd.service_instance = rgw_id
+ sd.service_type = 'rgw'
+ sd.nodename = host
+ daemons.append(sd)
+ num_added += 1
+ return self._create_rgw(args)
+
+ return self._get_services('rgw').then(_add_rgw)
+
+ @async_map_completion
def _create_rgw(self, rgw_id, host):
ret, keyring, err = self.mon_command({
'prefix': 'auth get-or-create',
return self._create_daemon('rgw', rgw_id, host, keyring)
def remove_rgw(self, name):
- daemons = self._get_services('rgw')
- args = []
- for d in daemons:
- if d.service_instance == name or d.service_instance.startswith(name + '.'):
- args.append(('%s.%s' % (d.service_type, d.service_instance),
- d.nodename))
- if args:
- return async_map_completion(self._remove_daemon)(args)
- raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
+ def _remove_rgw(daemons):
+ args = []
+ for d in daemons:
+ if d.service_instance == name or d.service_instance.startswith(name + '.'):
+ args.append(('%s.%s' % (d.service_type, d.service_instance),
+ d.nodename))
+ if args:
+ return self._remove_daemon(args)
+ raise RuntimeError('Unable to find rgw.%s[-*] daemon(s)' % name)
+
+ return self._get_services('rgw').then(_remove_rgw)
def update_rgw(self, spec):
- daemons = self._get_services('rgw', service_name=spec.name)
- if len(daemons) > spec.count:
- # remove some
- to_remove = len(daemons) - spec.count
- args = []
- for d in daemons[0:to_remove]:
- args.append((d.service_instance, d.nodename))
- return async_map_completion(self._remove_rgw)(args)
- elif len(daemons) < spec.count:
- # add some
- spec.count -= len(daemons)
- return self.add_rgw(spec)
+ return self._update_service('rgw', self.add_rgw, spec)
def add_rbd_mirror(self, spec):
if not spec.placement.nodes or len(spec.placement.nodes) < spec.count:
raise RuntimeError("must specify at least %d hosts" % spec.count)
self.log.debug('nodes %s' % spec.placement.nodes)
- daemons = self._get_services('rbd-mirror')
- results = []
- num_added = 0
- for host, _, name in spec.placement.nodes:
- if num_added >= spec.count:
- break
- daemon_id = self.get_unique_name(daemons, None, name)
- self.log.debug('placing rbd-mirror.%s on host %s' % (daemon_id,
- host))
- results.append(
- self._worker_pool.apply_async(self._create_rbd_mirror,
- (daemon_id, host))
- )
- # add to daemon list so next name(s) will also be unique
- sd = orchestrator.ServiceDescription()
- sd.service_instance = daemon_id
- sd.service_type = 'rbd-mirror'
- sd.nodename = host
- daemons.append(sd)
- num_added += 1
- return SSHWriteCompletion(results)
+ def _add_rbd_mirror(daemons):
+ args = []
+ num_added = 0
+ for host, _, name in spec.placement.nodes:
+ if num_added >= spec.count:
+ break
+ daemon_id = self.get_unique_name(daemons, None, name)
+ self.log.debug('placing rbd-mirror.%s on host %s' % (daemon_id,
+ host))
+ args.append((daemon_id, host))
+
+ # add to daemon list so next name(s) will also be unique
+ sd = orchestrator.ServiceDescription()
+ sd.service_instance = daemon_id
+ sd.service_type = 'rbd-mirror'
+ sd.nodename = host
+ daemons.append(sd)
+ num_added += 1
+ return self._create_rbd_mirror(args)
+
+ return self._get_services('rbd-mirror').then(_add_rbd_mirror)
+
+ @async_map_completion
def _create_rbd_mirror(self, daemon_id, host):
ret, keyring, err = self.mon_command({
'prefix': 'auth get-or-create',
return self._create_daemon('rbd-mirror', daemon_id, host, keyring)
def remove_rbd_mirror(self, name):
- daemons = self._get_services('rbd-mirror')
- results = []
- for d in daemons:
- if not name or d.service_instance == name:
- results.append(self._worker_pool.apply_async(
- self._remove_daemon,
- ('%s.%s' % (d.service_type, d.service_instance),
- d.nodename)))
- if not results and name:
- raise RuntimeError('Unable to find rbd-mirror.%s daemon' % name)
- return SSHWriteCompletion(results)
+ def _remove_rbd_mirror(daemons):
+ args = []
+ for d in daemons:
+ if not name or d.service_instance == name:
+ args.append(
+ ('%s.%s' % (d.service_type, d.service_instance),
+ d.nodename)
+ )
+ if not args and name:
+ raise RuntimeError('Unable to find rbd-mirror.%s daemon' % name)
+ return self._remove_daemon(args)
+
+ return self._get_services('rbd-mirror').then(_remove_rbd_mirror)
def update_rbd_mirror(self, spec):
return self._update_service('rbd-mirror', self.add_rbd_mirror, spec)
--- /dev/null
+from contextlib import contextmanager
+
+import pytest
+
+from ssh import SSHOrchestrator
+from tests import mock
+
+
+def set_store(self, k, v):
+ if v is None:
+ del self._store[k]
+ else:
+ self._store[k] = v
+
+
+def get_store(self, k):
+ return self._store[k]
+
+
+def get_store_prefix(self, prefix):
+ return {
+ k: v for k, v in self._store.items()
+ if k.startswith(prefix)
+ }
+
+
+@pytest.yield_fixture()
+def ssh_module():
+ with mock.patch("ssh.module.SSHOrchestrator.get_ceph_option", lambda _, key: __file__),\
+ mock.patch("ssh.module.SSHOrchestrator.set_store", set_store),\
+ mock.patch("ssh.module.SSHOrchestrator.get_store", get_store),\
+ mock.patch("ssh.module.SSHOrchestrator.get_store_prefix", get_store_prefix):
+ m = SSHOrchestrator.__new__(SSHOrchestrator)
+ m._store = {
+ 'ssh_config': '',
+ 'ssh_identity_key': '',
+ 'ssh_identity_pub': '',
+ }
+ m.__init__('ssh', 0, 0)
+ yield m
--- /dev/null
+import sys
+import time
+
+
+try:
+ from typing import Any
+except ImportError:
+ pass
+
+import pytest
+
+
+from orchestrator import raise_if_exception, Completion
+from .fixtures import ssh_module
+from ..module import trivial_completion, async_completion, async_map_completion, SSHOrchestrator
+
+
+class TestCompletion(object):
+ def _wait(self, m, c):
+ # type: (SSHOrchestrator, Completion) -> Any
+ m.process([c])
+ m.process([c])
+
+ for _ in range(30):
+ if c.is_finished:
+ raise_if_exception(c)
+ return c.result
+ time.sleep(0.1)
+ assert False, "timeout" + str(c._state)
+
+ def test_trivial(self, ssh_module):
+ @trivial_completion
+ def run(x):
+ return x+1
+ assert self._wait(ssh_module, run(1)) == 2
+
+ @pytest.mark.parametrize("input", [
+ ((1, ), ),
+ ((1, 2), ),
+ (("hallo", ), ),
+ (("hallo", "foo"), ),
+ ])
+ def test_async(self, input, ssh_module):
+ @async_completion
+ def run(*args):
+ return str(args)
+
+ assert self._wait(ssh_module, run(*input)) == str(input)
+
+ @pytest.mark.parametrize("input,expected", [
+ ([], []),
+ ([1], ["(1,)"]),
+ (["hallo"], ["('hallo',)"]),
+ ("hi", ["('h',)", "('i',)"]),
+ (list(range(5)), [str((x, )) for x in range(5)]),
+ ([(1, 2), (3, 4)], ["(1, 2)", "(3, 4)"]),
+ ])
+ def test_async_map(self, input, expected, ssh_module):
+ @async_map_completion
+ def run(*args):
+ return str(args)
+
+ c = run(input)
+ self._wait(ssh_module, c)
+ assert c.result == expected
+
+ def test_async_self(self, ssh_module):
+ class Run(object):
+ def __init__(self):
+ self.attr = 1
+
+ @async_completion
+ def run(self, x):
+ assert self.attr == 1
+ return x + 1
+
+ assert self._wait(ssh_module, Run().run(1)) == 2
+
+ @pytest.mark.parametrize("input,expected", [
+ ([], []),
+ ([1], ["(1,)"]),
+ (["hallo"], ["('hallo',)"]),
+ ("hi", ["('h',)", "('i',)"]),
+ (list(range(5)), [str((x, )) for x in range(5)]),
+ ([(1, 2), (3, 4)], ["(1, 2)", "(3, 4)"]),
+ ])
+ def test_async_map_self(self, input, expected, ssh_module):
+ class Run(object):
+ def __init__(self):
+ self.attr = 1
+
+ @async_map_completion
+ def run(self, *args):
+ assert self.attr == 1
+ return str(args)
+
+ c = Run().run(input)
+ self._wait(ssh_module, c)
+ assert c.result == expected
+
+ def test_then1(self, ssh_module):
+ @async_map_completion
+ def run(x):
+ return x+1
+
+ assert self._wait(ssh_module, run([1,2]).then(str)) == '[2, 3]'
+
+ def test_then2(self, ssh_module):
+ @async_map_completion
+ def run(x):
+ time.sleep(0.1)
+ return x+1
+
+ @async_completion
+ def async_str(results):
+ return str(results)
+
+ c = run([1,2]).then(async_str)
+
+ self._wait(ssh_module, c)
+ assert c.result == '[2, 3]'
+
+ def test_then3(self, ssh_module):
+ @async_map_completion
+ def run(x):
+ time.sleep(0.1)
+ return x+1
+
+ def async_str(results):
+ return async_completion(str)(results)
+
+ c = run([1,2]).then(async_str)
+
+ self._wait(ssh_module, c)
+ assert c.result == '[2, 3]'
+
+ def test_then4(self, ssh_module):
+ @async_map_completion
+ def run(x):
+ time.sleep(0.1)
+ return x+1
+
+ def async_str(results):
+ return async_completion(str)(results).then(lambda x: x + "hello")
+
+ c = run([1,2]).then(async_str)
+
+ self._wait(ssh_module, c)
+ assert c.result == '[2, 3]hello'
+
+ @pytest.mark.skip(reason="see limitation of async_map_completion")
+ def test_then5(self, ssh_module):
+ @async_map_completion
+ def run(x):
+ time.sleep(0.1)
+ return async_completion(str)(x+1)
+
+ c = run([1,2])
+
+ self._wait(ssh_module, c)
+ assert c.result == "['2', '3']"
+
+ @pytest.mark.skipif(sys.version_info < (3,0), reason="requires python3")
+ def test_raise(self, ssh_module):
+ @async_completion
+ def run(x):
+ raise ZeroDivisionError()
+
+ with pytest.raises(ZeroDivisionError):
+ self._wait(ssh_module, run(1))
+
-from orchestrator import ServiceDescription
+import json
+import time
+from contextlib import contextmanager
+
+from ceph.deployment.drive_group import DriveGroupSpec, DeviceSelection
+
+try:
+ from typing import Any
+except ImportError:
+ pass
+
+from orchestrator import ServiceDescription, raise_if_exception, Completion, InventoryNode, \
+ StatelessServiceSpec, PlacementSpec, RGWSpec, parse_host_specs
from ..module import SSHOrchestrator
from tests import mock
+from .fixtures import ssh_module
+
+
+"""
+TODOs:
+ There is really room for improvement here. I just quickly assembled these tests.
+ In general, everything should be tested in Teuthology as well; the reason for
+ also testing here is the faster development round-trip time.
+"""
+
+
+def _run_ceph_daemon(ret):
+ def foo(*args, **kwargs):
+ return ret, 0
+ return foo
+
+def mon_command(*args, **kwargs):
+ return 0, '', ''
+
+
+class TestSSH(object):
+ def _wait(self, m, c):
+ # type: (SSHOrchestrator, Completion) -> Any
+ m.process([c])
+ m.process([c])
+
+ for _ in range(30):
+ if c.is_finished:
+ raise_if_exception(c)
+ return c.result
+ time.sleep(0.1)
+ assert False, "timeout" + str(c._state)
+
+ @contextmanager
+ def _with_host(self, m, name):
+ self._wait(m, m.add_host(name))
+ yield
+ self._wait(m, m.remove_host(name))
+
+ def test_get_unique_name(self, ssh_module):
+ existing = [
+ ServiceDescription(service_instance='mon.a')
+ ]
+ new_mon = ssh_module.get_unique_name(existing, 'mon')
+ assert new_mon.startswith('mon.')
+ assert new_mon != 'mon.a'
+
+ def test_host(self, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ assert self._wait(ssh_module, ssh_module.get_hosts()) == [InventoryNode('test')]
+ c = ssh_module.get_hosts()
+ assert self._wait(ssh_module, c) == []
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+ def test_service_ls(self, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.describe_service()
+ assert self._wait(ssh_module, c) == []
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+ def test_device_ls(self, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.get_inventory()
+ assert self._wait(ssh_module, c) == [InventoryNode('test')]
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_mon_update(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.update_mons(1, [parse_host_specs('test:0.0.0.0')])
+ assert self._wait(ssh_module, c) == ["(Re)deployed mon.test on host 'test'"]
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('[]'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_mgr_update(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.update_mgrs(1, [parse_host_specs('test:0.0.0.0')])
+ [out] = self._wait(ssh_module, c)
+ assert "(Re)deployed mgr." in out
+ assert " on host 'test'" in out
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_create_osds(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ dg = DriveGroupSpec('test', DeviceSelection(paths=['']))
+ c = ssh_module.create_osds(dg)
+ assert self._wait(ssh_module, c) == "Created osd(s) on host 'test'"
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_mds(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ ps = PlacementSpec(nodes=['test'])
+ c = ssh_module.add_mds(StatelessServiceSpec('name', ps))
+ [out] = self._wait(ssh_module, c)
+ assert "(Re)deployed mds.name." in out
+ assert " on host 'test'" in out
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_rgw(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ ps = PlacementSpec(nodes=['test'])
+ c = ssh_module.add_rgw(RGWSpec('name', ps))
+ [out] = self._wait(ssh_module, c)
+ assert "(Re)deployed rgw.name." in out
+ assert " on host 'test'" in out
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon(
+ json.dumps([
+ dict(
+ name='rgw.myrgw.foobar',
+ style='ceph-daemon',
+ fsid='fsid',
+ container_id='container_id',
+ version='version',
+ state='running',
+ )
+ ])
+ ))
+ def test_remove_rgw(self, ssh_module):
+ ssh_module._cluster_fsid = "fsid"
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.remove_rgw('myrgw')
+ out = self._wait(ssh_module, c)
+ assert out == ["Removed rgw.myrgw.foobar from host 'test'"]
+
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_rbd_mirror(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ ps = PlacementSpec(nodes=['test'])
+ c = ssh_module.add_rbd_mirror(StatelessServiceSpec('name', ps))
+ [out] = self._wait(ssh_module, c)
+ assert "(Re)deployed rbd-mirror." in out
+ assert " on host 'test'" in out
+ @mock.patch("ssh.module.SSHOrchestrator._run_ceph_daemon", _run_ceph_daemon('{}'))
+ @mock.patch("ssh.module.SSHOrchestrator.send_command")
+ @mock.patch("ssh.module.SSHOrchestrator.mon_command", mon_command)
+ @mock.patch("ssh.module.SSHOrchestrator._get_connection")
+ def test_blink_device_light(self, _send_command, _get_connection, ssh_module):
+ with self._with_host(ssh_module, 'test'):
+ c = ssh_module.blink_device_light('ident', True, [('test', '')])
+ assert self._wait(ssh_module, c) == ['Set ident light for test: on']
-@mock.patch("ssh.module.SSHOrchestrator.get_ceph_option", lambda _,key: __file__)
-def test_get_unique_name():
- o = SSHOrchestrator('module_name', 0, 0)
- existing = [
- ServiceDescription(service_instance='mon.a')
- ]
- new_mon = o.get_unique_name(existing, 'mon')
- assert new_mon.startswith('mon.')
- assert new_mon != 'mon.a'
class TestCompletion(orchestrator.Completion):
def evaluate(self):
- self._first_promise.finalize(None)
+ self.finalize(None)
def deferred_read(f):
self._ceph_get = mock.MagicMock()
self._ceph_get_module_option = mock.MagicMock()
self._ceph_log = mock.MagicMock()
- self._ceph_get_option = mock.MagicMock()
self._ceph_get_store = lambda _: ''
self._ceph_get_store_prefix = lambda _: {}
from __future__ import absolute_import
import json
-try:
- from unittest.mock import MagicMock
-except ImportError:
- # py2
- from mock import MagicMock
+from tests import mock
import pytest
def test_promise_then():
p = Completion(value=3).then(lambda three: three + 1)
- p._first_promise.finalize()
+ p.finalize()
assert p.result == 4
def test_promise_monadic_then():
p = Completion(value=3)
p.then(lambda three: Completion(value=three + 1))
- p._first_promise.finalize()
+ p.finalize()
assert p.result == 4
c = Completion(value=3).then(
lambda three: Completion(value=three + 1).then(
lambda four: four + 1))
- return c._first_promise
+ return c
def test_promise_monadic_then_combined():
p = some_complex_completion()
- p._first_promise.finalize()
+ p.finalize()
assert p.result == 5
foo['x'] = x
foo['x'] = 1
- Completion(value=3).then(run)._first_promise.finalize()
+ Completion(value=3).then(run).finalize()
assert foo['x'] == 3
def test_progress():
c = some_complex_completion()
- mgr = MagicMock()
+ mgr = mock.MagicMock()
mgr.process = lambda cs: [c.finalize(None) for c in cs]
progress_val = 0.75
def test_with_progress():
- mgr = MagicMock()
+ mgr = mock.MagicMock()
mgr.process = lambda cs: [c.finalize(None) for c in cs]
def execute(y):
def run(x):
raise KeyError(x)
- c = Completion(value=3).then(run)._first_promise
+ c = Completion(value=3).then(run)
c.finalize()
-
- assert isinstance(c.exception, KeyError)
+ with pytest.raises(KeyError):
+ raise_if_exception(c)
def test_fail():
[testenv]
setenv = UNITTEST = true
deps = -rrequirements.txt
-commands = pytest --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
\ No newline at end of file
+commands = pytest -v --cov --cov-append --cov-report=term --doctest-modules {posargs:mgr_util.py tests/ ssh/}
\ No newline at end of file