git.apps.os.sepia.ceph.com Git - teuthology.git/commitdiff
async
author Zack Cerza <zack@redhat.com>
Sat, 20 Jan 2024 00:38:59 +0000 (17:38 -0700)
committer Zack Cerza <zack@redhat.com>
Mon, 22 Jan 2024 21:48:48 +0000 (14:48 -0700)
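Replace gevent (greenlets, monkey patching, and the gevent-based parallel
helper) with asyncio across orchestra.run, orchestra.remote, parallel, nuke,
the install task and the pexec task; add pytest-asyncio so the updated
parallel tests can run.

A rough sketch of the new calling convention, taken from the updated
test_parallel.py (identity() is just the test helper; any callable works):

    async with parallel() as para:
        for i in in_set:
            para.spawn(identity, i, in_set)
        async for result in para:
            in_set.remove(result)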
scripts/nuke.py
setup.cfg
teuthology/__init__.py
teuthology/nuke/__init__.py
teuthology/orchestra/remote.py
teuthology/orchestra/run.py
teuthology/parallel.py
teuthology/task/install/__init__.py
teuthology/task/install/deb.py
teuthology/task/pexec.py
teuthology/test/test_parallel.py

index 0b1644c3e720219fc41ad81541532c52236cf1d3..33021618540fae3b64f23d7e7428cb00a27814b9 100644 (file)
--- a/scripts/nuke.py
+++ b/scripts/nuke.py
@@ -1,3 +1,4 @@
+import asyncio
 import docopt
 
 import teuthology.nuke
@@ -44,4 +45,4 @@ teuthology-nuke -t target.yaml --pid 1234 --unlock --owner user@host
 
 def main():
     args = docopt.docopt(doc)
-    teuthology.nuke.main(args)
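+    # teuthology.nuke.main() is a coroutine now; run it on a fresh event loop.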
+    asyncio.run(teuthology.nuke.main(args))
index be35d5ebddece015178f1a6b752e87b3ab872b69..ac2b072b5d8ea86462479a16be7e90dee5cd6fdf 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -102,6 +102,7 @@ test =
     mock
     nose
     pytest
+    pytest-asyncio
     pytest-cov
     toml
     tox
index d84f25a2eaef254759ace6cf4e542ee99ddb6fa1..6a9ca5d9e1aa9659377422817d59646391eae258 100644 (file)
--- a/teuthology/__init__.py
+++ b/teuthology/__init__.py
@@ -7,11 +7,6 @@ except ImportError:
 
 __version__ = importlib_metadata.version("teuthology")
 
-# Tell gevent not to patch os.waitpid() since it is susceptible to race
-# conditions. See:
-# http://www.gevent.org/gevent.monkey.html#gevent.monkey.patch_os
-os.environ['GEVENT_NOWAITPID'] = 'true'
-
 # Use manhole to give us a way to debug hung processes
 # https://pypi.python.org/pypi/manhole
 try:
@@ -23,21 +18,11 @@ try:
     )
 except ImportError:
     pass
-from gevent import monkey
-monkey.patch_all(
-    dns=False,
-    # Don't patch subprocess to avoid http://tracker.ceph.com/issues/14990
-    subprocess=False,
-)
 import sys
-from gevent.hub import Hub
 
 # Don't write pyc files
 sys.dont_write_bytecode = True
 
-from teuthology.orchestra import monkey
-monkey.patch_all()
-
 import logging
 
 # If we are running inside a virtualenv, ensure we have its 'bin' directory in
@@ -56,6 +41,9 @@ logging.getLogger('urllib3.connectionpool').setLevel(
 # We also don't need the "Converted retries value" messages
 logging.getLogger('urllib3.util.retry').setLevel(
     logging.WARN)
+# Quiet asyncio's debug-level log messages (TODO: re-check if still needed)
+logging.getLogger('asyncio').setLevel(
+    logging.INFO)
 
 logging.basicConfig(
     level=logging.INFO,
@@ -94,19 +82,3 @@ def install_except_hook():
                                                          exc_traceback))
         sys.__excepthook__(exc_type, exc_value, exc_traceback)
     sys.excepthook = log_exception
-
-
-def patch_gevent_hub_error_handler():
-    Hub._origin_handle_error = Hub.handle_error
-
-    def custom_handle_error(self, context, type, value, tb):
-        if context is None or issubclass(type, Hub.SYSTEM_ERROR):
-            self.handle_system_error(type, value)
-        elif issubclass(type, Hub.NOT_ERROR):
-            pass
-        else:
-            log.error("Uncaught exception (Hub)", exc_info=(type, value, tb))
-
-    Hub.handle_error = custom_handle_error
-
-patch_gevent_hub_error_handler()
index 8a2985b9eff5aa29458fe4f923bf2e47e935cc40..cde6eb7eaa69c3d1cba3b6836ee0c3769baad220 100644 (file)
--- a/teuthology/nuke/__init__.py
+++ b/teuthology/nuke/__init__.py
@@ -1,4 +1,5 @@
 import argparse
+import asyncio
 import datetime
 import json
 import logging
@@ -176,7 +177,7 @@ def openstack_remove_again():
         openstack_delete_volume(i)
 
 
-def main(args):
+async def main(args):
     ctx = FakeNamespace(args)
     if ctx.verbose:
         teuthology.log.setLevel(logging.DEBUG)
@@ -234,15 +235,22 @@ def main(args):
         else:
             subprocess.check_call(["kill", "-9", str(ctx.pid)])
 
-    nuke(ctx, ctx.unlock, ctx.synch_clocks, ctx.noipmi, ctx.keep_logs, not ctx.no_reboot)
+    await nuke(ctx, ctx.unlock, ctx.synch_clocks, ctx.noipmi, ctx.keep_logs, not ctx.no_reboot)
 
 
-def nuke(ctx, should_unlock, sync_clocks=True, noipmi=False, keep_logs=False, should_reboot=True):
+async def nuke(ctx, should_unlock, sync_clocks=True, noipmi=False, keep_logs=False, should_reboot=True):
     if 'targets' not in ctx.config:
         return
     total_unnuked = {}
     log.info('Checking targets against current locks')
-    with parallel() as p:
+    async with parallel() as p:
         for target, hostkey in ctx.config['targets'].items():
             status = get_status(target)
             if ctx.name and ctx.name not in (status.get('description') or ""):
@@ -256,18 +264,22 @@ def nuke(ctx, should_unlock, sync_clocks=True, noipmi=False, keep_logs=False, sh
                 total_unnuked[target] = hostkey
                 log.info(f"Not nuking {target} because it is down")
                 continue
             p.spawn(
                 nuke_one,
                 ctx,
                 {target: hostkey},
                 should_unlock,
                 sync_clocks,
                 ctx.config.get('check-locks', True),
                 noipmi,
                 keep_logs,
                 should_reboot,
             )
-        for unnuked in p:
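+        # parallel() yields whatever each spawned callable returned; nuke_one()
+        # is a coroutine function, so each yielded item still has to be awaited.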
+        async for task in p:
+            unnuked = await task
             if unnuked:
                 total_unnuked.update(unnuked)
     if total_unnuked:
@@ -278,7 +290,7 @@ def nuke(ctx, should_unlock, sync_clocks=True, noipmi=False, keep_logs=False, sh
                                   default_flow_style=False).splitlines()))
 
 
-def nuke_one(ctx, target, should_unlock, synch_clocks,
+async def nuke_one(ctx, target, should_unlock, synch_clocks,
              check_locks, noipmi, keep_logs, should_reboot):
     ret = None
     ctx = argparse.Namespace(
@@ -291,7 +303,7 @@ def nuke_one(ctx, target, should_unlock, synch_clocks,
         noipmi=noipmi,
     )
     try:
-        nuke_helper(ctx, should_unlock, keep_logs, should_reboot)
+        await nuke_helper(ctx, should_unlock, keep_logs, should_reboot)
     except Exception:
         log.exception('Could not nuke %s' % target)
         # not re-raising the so that parallel calls aren't killed
@@ -302,7 +314,7 @@ def nuke_one(ctx, target, should_unlock, synch_clocks,
     return ret
 
 
-def nuke_helper(ctx, should_unlock, keep_logs, should_reboot):
+async def nuke_helper(ctx, should_unlock, keep_logs, should_reboot):
     # ensure node is up with ipmi
     (target,) = ctx.config['targets'].keys()
     host = target.split('@')[-1]
@@ -324,7 +336,7 @@ def nuke_helper(ctx, should_unlock, keep_logs, should_reboot):
         provision.pelagos.park_node(host)
         return
     elif remote_.is_container:
-        remote_.run(
+        await remote_.run(
             args=['sudo', '/testnode_stop.sh'],
             check_status=False,
         )
index ce77a519cf36189c16a0e15c151b8e3c19467a24..9904397604dfc04c02dc2386563424fc5cfdab5b 100644 (file)
--- a/teuthology/orchestra/remote.py
+++ b/teuthology/orchestra/remote.py
@@ -474,20 +474,19 @@ class Remote(RemoteShell):
             self._shortname = host_shortname(self.hostname)
         return self._shortname
 
-    @property
-    def is_online(self):
+    async def is_online(self):
         if self.ssh is None:
             return False
         if self.ssh.get_transport() is None:
             return False
         try:
-            self.run(args="true")
+            await self.run(args="true")
         except Exception:
             return False
         return self.ssh.get_transport().is_active()
 
-    def ensure_online(self):
-        if self.is_online:
+    async def ensure_online(self):
+        if await self.is_online():
             return
         self.connect()
-        if not self.is_online:
+        if not await self.is_online():
@@ -509,7 +508,7 @@ class Remote(RemoteShell):
             name=self.name,
             )
 
-    def run(self, **kwargs):
+    async def run(self, **kwargs):
         """
         This calls `orchestra.run.run` with our SSH client.
 
@@ -520,7 +519,7 @@ class Remote(RemoteShell):
            not self.ssh.get_transport().is_active():
             if not self.reconnect():
                 raise ConnectionError(f'Failed to reconnect to {self.shortname}')
-        r = self._runner(client=self.ssh, name=self.shortname, **kwargs)
+        r = await self._runner(client=self.ssh, name=self.shortname, **kwargs)
         r.remote = self
         return r
 
index f31dfd0d7fc1db89feab1f67abdc95c809a9731d..de6a016a2bb1ce0bc02379a34f2fed37d1cb1e4f 100644 (file)
--- a/teuthology/orchestra/run.py
+++ b/teuthology/orchestra/run.py
@@ -2,20 +2,21 @@
 Paramiko run support
 """
 
+import asyncio
 import io
 
 from paramiko import ChannelFile
 
-import gevent
-import gevent.event
 import socket
 import pipes
 import logging
 import shutil
 
-from teuthology.contextutil import safe_while
-from teuthology.exceptions import (CommandCrashedError, CommandFailedError,
-                                   ConnectionLostError)
+from teuthology.exceptions import (
+    CommandCrashedError,
+    CommandFailedError,
+    ConnectionLostError,
+)
 
 log = logging.getLogger(__name__)
 
@@ -24,22 +25,44 @@ class RemoteProcess(object):
     """
     An object to begin and monitor execution of a process on a remote host
     """
+
     __slots__ = [
-        'client', 'args', 'check_status', 'command', 'hostname',
-        'stdin', 'stdout', 'stderr',
-        '_stdin_buf', '_stdout_buf', '_stderr_buf',
-        'returncode', 'exitstatus', 'timeout',
-        'greenlets',
-        '_wait', 'logger',
+        "client",
+        "args",
+        "check_status",
+        "command",
+        "hostname",
+        "stdin",
+        "stdout",
+        "stderr",
+        "_stdin_buf",
+        "_stdout_buf",
+        "_stderr_buf",
+        "returncode",
+        "exitstatus",
+        "timeout",
+        "tasks",
+        "_wait",
+        "logger",
         # for orchestra.remote.Remote to place a backreference
-        'remote',
-        'label',
-        ]
+        "remote",
+        "label",
+    ]
 
     deadlock_warning = "Using PIPE for %s without wait=False would deadlock"
 
-    def __init__(self, client, args, check_status=True, hostname=None,
-                 label=None, timeout=None, wait=True, logger=None, cwd=None):
+    def __init__(
+        self,
+        client,
+        args,
+        check_status=True,
+        hostname=None,
+        label=None,
+        timeout=None,
+        wait=True,
+        logger=None,
+        cwd=None,
+    ):
         """
         Create the object. Does not initiate command execution.
 
@@ -67,8 +90,7 @@ class RemoteProcess(object):
             self.command = args
 
         if cwd:
-            self.command = '(cd {cwd} && exec {cmd})'.format(
-                           cwd=cwd, cmd=self.command)
+            self.command = "(cd {cwd} && exec {cmd})".format(cwd=cwd, cmd=self.command)
 
         self.check_status = check_status
         self.label = label
@@ -79,7 +101,7 @@ class RemoteProcess(object):
         else:
             (self.hostname, port) = client.get_transport().getpeername()[0:2]
 
-        self.greenlets = []
+        self.tasks = set()
         self.stdin, self.stdout, self.stderr = (None, None, None)
         self.returncode = self.exitstatus = None
         self._wait = wait
@@ -89,43 +111,49 @@ class RemoteProcess(object):
         """
         Execute remote command
         """
-        for line in self.command.split('\n'):
-            log.getChild(self.hostname).debug('%s> %s' % (self.label or '', line))
-
-        if hasattr(self, 'timeout'):
-            (self._stdin_buf, self._stdout_buf, self._stderr_buf) = \
-                self.client.exec_command(self.command, timeout=self.timeout)
+        for line in self.command.split("\n"):
+            log.getChild(self.hostname).debug("%s> %s" % (self.label or "", line))
+
+        if hasattr(self, "timeout"):
+            (
+                self._stdin_buf,
+                self._stdout_buf,
+                self._stderr_buf,
+            ) = self.client.exec_command(self.command, timeout=self.timeout)
         else:
-            (self._stdin_buf, self._stdout_buf, self._stderr_buf) = \
-                self.client.exec_command(self.command)
-        (self.stdin, self.stdout, self.stderr) = \
-            (self._stdin_buf, self._stdout_buf, self._stderr_buf)
-
-    def add_greenlet(self, greenlet):
-        self.greenlets.append(greenlet)
+            (
+                self._stdin_buf,
+                self._stdout_buf,
+                self._stderr_buf,
+            ) = self.client.exec_command(self.command)
+        (self.stdin, self.stdout, self.stderr) = (
+            self._stdin_buf,
+            self._stdout_buf,
+            self._stderr_buf,
+        )
 
     def setup_stdin(self, stream_obj):
         self.stdin = KludgeFile(wrapped=self.stdin)
         if stream_obj is not PIPE:
-            greenlet = gevent.spawn(copy_and_close, stream_obj, self.stdin)
-            self.add_greenlet(greenlet)
+            self.tasks.add(asyncio.create_task(copy_and_close(stream_obj, self.stdin)))
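+            # The copy task is tracked in self.tasks and awaited (with a
+            # timeout) in wait().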
             self.stdin = None
         elif self._wait:
             # FIXME: Is this actually true?
-            raise RuntimeError(self.deadlock_warning % 'stdin')
+            raise RuntimeError(self.deadlock_warning % "stdin")
 
     def setup_output_stream(self, stream_obj, stream_name, quiet=False):
         if stream_obj is not PIPE:
             # Log the stream
             host_log = self.logger.getChild(self.hostname)
             stream_log = host_log.getChild(stream_name)
-            self.add_greenlet(
-                gevent.spawn(
-                    copy_file_to,
-                    getattr(self, stream_name),
-                    stream_log,
-                    stream_obj,
-                    quiet,
+            self.tasks.add(
+                asyncio.create_task(
+                    copy_file_to(
+                        getattr(self, stream_name),
+                        stream_log,
+                        stream_obj,
+                        quiet,
+                    )
                 )
             )
             setattr(self, stream_name, stream_obj)
@@ -133,7 +161,7 @@ class RemoteProcess(object):
             # FIXME: Is this actually true?
             raise RuntimeError(self.deadlock_warning % stream_name)
 
-    def wait(self):
+    async def wait(self):
         """
         Block until remote process finishes.
 
@@ -143,19 +171,20 @@ class RemoteProcess(object):
         status = self._get_exitstatus()
         if status != 0:
             log.debug("got remote process result: {}".format(status))
-        for greenlet in self.greenlets:
+        for task in self.tasks:
             try:
-                greenlet.get(block=True,timeout=60)
-            except gevent.Timeout:
-                log.debug("timed out waiting; will kill: {}".format(greenlet))
-                greenlet.kill(block=False)
-        for stream in ('stdout', 'stderr'):
+                await asyncio.wait_for(task, timeout=60)
+            except asyncio.TimeoutError:
+                log.debug("timed out waiting; will kill: {}".format(task))
+                task.cancel()
+        for stream in ("stdout", "stderr"):
             if hasattr(self, stream):
                 stream_obj = getattr(self, stream)
                 # Despite ChannelFile having a seek() method, it raises
                 # "IOError: File does not support seeking."
-                if hasattr(stream_obj, 'seek') and \
-                        not isinstance(stream_obj, ChannelFile):
+                if hasattr(stream_obj, "seek") and not isinstance(
+                    stream_obj, ChannelFile
+                ):
                     stream_obj.seek(0)
 
         self._raise_for_status()
@@ -171,16 +200,17 @@ class RemoteProcess(object):
                 transport = self.client.get_transport()
                 if transport is None or not transport.is_active():
                     # look like we lost the connection
-                    raise ConnectionLostError(command=self.command,
-                                              node=self.hostname)
+                    raise ConnectionLostError(command=self.command, node=self.hostname)
 
                 # connection seems healthy still, assuming it was a
                 # signal; sadly SSH does not tell us which signal
                 raise CommandCrashedError(command=self.command)
             if self.returncode != 0:
                 raise CommandFailedError(
-                    command=self.command, exitstatus=self.returncode,
-                    node=self.hostname, label=self.label
+                    command=self.command,
+                    exitstatus=self.returncode,
+                    node=self.hostname,
+                    label=self.label,
                 )
 
     def _get_exitstatus(self):
@@ -197,7 +227,7 @@ class RemoteProcess(object):
 
     @property
     def finished(self):
-        gevent.wait(self.greenlets, timeout=0.1)
+        # Only the remote exit status is polled here; the stream-copy tasks in
+        # self.tasks are awaited in wait().
         ready = self._stdout_buf.channel.exit_status_ready()
         if ready:
             self._get_exitstatus()
@@ -213,13 +243,13 @@ class RemoteProcess(object):
         return None
 
     def __repr__(self):
-        return '{classname}(client={client!r}, args={args!r}, check_status={check}, hostname={name!r})'.format(  # noqa
+        return "{classname}(client={client!r}, args={args!r}, check_status={check}, hostname={name!r})".format(  # noqa
             classname=self.__class__.__name__,
             client=self.client,
             args=self.args,
             check=self.check_status,
             name=self.hostname,
-            )
+        )
 
 
 class Raw(object):
@@ -227,14 +257,15 @@ class Raw(object):
     """
     Raw objects are passed to remote objects and are not processed locally.
     """
+
     def __init__(self, value):
         self.value = value
 
     def __repr__(self):
-        return '{cls}({value!r})'.format(
+        return "{cls}({value!r})".format(
             cls=self.__class__.__name__,
             value=self.value,
-            )
+        )
 
     def __eq__(self, value):
         return self.value == value
@@ -244,6 +275,7 @@ def quote(args):
     """
     Internal quote wrapper.
     """
+
     def _quote(args):
         """
         Handle quoted string, testing for raw charaters.
@@ -253,13 +285,14 @@ def quote(args):
                 yield a.value
             else:
                 yield pipes.quote(a)
+
     if isinstance(args, list):
-        return ' '.join(_quote(args))
+        return " ".join(_quote(args))
     else:
         return args
 
 
-def copy_to_log(f, logger, loglevel=logging.INFO, capture=None, quiet=False):
+async def copy_to_log(f, logger, loglevel=logging.INFO, capture=None, quiet=False):
     """
     Copy line by line from file in f to the log from logger
 
@@ -279,7 +312,7 @@ def copy_to_log(f, logger, loglevel=logging.INFO, capture=None, quiet=False):
                 if isinstance(line, str):
                     capture.write(line)
                 else:
-                    capture.write(line.decode('utf-8', 'replace'))
+                    capture.write(line.decode("utf-8", "replace"))
             elif isinstance(capture, io.BytesIO):
                 if isinstance(line, str):
                     capture.write(line.encode())
@@ -291,13 +324,13 @@ def copy_to_log(f, logger, loglevel=logging.INFO, capture=None, quiet=False):
             continue
         try:
             if isinstance(line, bytes):
-                line = line.decode('utf-8', 'replace')
+                line = line.decode("utf-8", "replace")
             logger.log(loglevel, line)
         except (UnicodeDecodeError, UnicodeEncodeError):
             logger.exception("Encountered unprintable line in command output")
 
 
-def copy_and_close(src, fdst):
+async def copy_and_close(src, fdst):
     """
     copyfileobj call wrapper.
     """
@@ -310,7 +343,7 @@ def copy_and_close(src, fdst):
     fdst.close()
 
 
-def copy_file_to(src, logger, stream=None, quiet=False):
+async def copy_file_to(src, logger, stream=None, quiet=False):
     """
     Copy file
     :param src: file to be copied.
@@ -320,33 +353,7 @@ def copy_file_to(src, logger, stream=None, quiet=False):
     :param quiet: disable logger usage if True, useful in combination
                   with `stream` parameter, defaults False.
     """
-    copy_to_log(src, logger, capture=stream, quiet=quiet)
-
-def spawn_asyncresult(fn, *args, **kwargs):
-    """
-    Spawn a Greenlet and pass it's results to an AsyncResult.
-
-    This function is useful to shuffle data from a Greenlet to
-    AsyncResult, which then again is useful because any Greenlets that
-    raise exceptions will cause tracebacks to be shown on stderr by
-    gevent, even when ``.link_exception`` has been called. Using an
-    AsyncResult avoids this.
-    """
-    r = gevent.event.AsyncResult()
-
-    def wrapper():
-        """
-        Internal wrapper.
-        """
-        try:
-            value = fn(*args, **kwargs)
-        except Exception as e:
-            r.set_exception(e)
-        else:
-            r.set(value)
-    gevent.spawn(wrapper)
-
-    return r
+    await copy_to_log(src, logger, capture=stream, quiet=quiet)
 
 
 class Sentinel(object):
@@ -354,13 +361,15 @@ class Sentinel(object):
     """
     Sentinel -- used to define PIPE file-like object.
     """
+
     def __init__(self, name):
         self.name = name
 
     def __str__(self):
         return self.name
 
-PIPE = Sentinel('PIPE')
+
+PIPE = Sentinel("PIPE")
 
 
 class KludgeFile(object):
@@ -369,6 +378,7 @@ class KludgeFile(object):
     Wrap Paramiko's ChannelFile in a way that lets ``f.close()``
     actually cause an EOF for the remote command.
     """
+
     def __init__(self, wrapped):
         self._wrapped = wrapped
 
@@ -383,9 +393,12 @@ class KludgeFile(object):
         self._wrapped.channel.shutdown_write()
 
 
-def run(
-    client, args,
-    stdin=None, stdout=None, stderr=None,
+async def run(
+    client,
+    args,
+    stdin=None,
+    stdout=None,
+    stderr=None,
     logger=None,
     check_status=True,
     wait=True,
@@ -395,7 +408,7 @@ def run(
     timeout=None,
     cwd=None,
     # omit_sudo is used by vstart_runner.py
-    omit_sudo=False
+    omit_sudo=False,
 ):
     """
     Run a command remotely.  If any of 'args' contains shell metacharacters
@@ -419,9 +432,7 @@ def run(
     :param check_status: Whether to raise CommandFailedError on non-zero exit
                          status, and . Defaults to True. All signals and
                          connection loss are made to look like SIGHUP.
-    :param wait: Whether to wait for process to exit. If False, returned
-                 ``r.exitstatus`` s a `gevent.event.AsyncResult`, and the
-                 actual status is available via ``.get()``.
+    :param wait: Whether to wait for process to exit.
     :param name: Human readable name (probably hostname) of the destination
                  host
     :param label: Can be used to label or describe what the command is doing.
@@ -444,19 +455,27 @@ def run(
 
     if timeout:
         log.info("Running command with timeout %d", timeout)
-    r = RemoteProcess(client, args, check_status=check_status, hostname=name,
-                      label=label, timeout=timeout, wait=wait, logger=logger,
-                      cwd=cwd)
+    r = RemoteProcess(
+        client,
+        args,
+        check_status=check_status,
+        hostname=name,
+        label=label,
+        timeout=timeout,
+        wait=wait,
+        logger=logger,
+        cwd=cwd,
+    )
     r.execute()
     r.setup_stdin(stdin)
-    r.setup_output_stream(stderr, 'stderr', quiet)
-    r.setup_output_stream(stdout, 'stdout', quiet)
+    r.setup_output_stream(stderr, "stderr", quiet)
+    r.setup_output_stream(stdout, "stdout", quiet)
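+    # With wait=False the caller is responsible for awaiting r.wait() itself.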
     if wait:
-        r.wait()
+        await r.wait()
     return r
 
 
-def wait(processes, timeout=None):
+async def wait(processes, timeout=None):
     """
     Wait for all given processes to exit.
 
@@ -466,14 +485,4 @@ def wait(processes, timeout=None):
     """
     if timeout:
         log.info("waiting for %d", timeout)
-    if timeout and timeout > 0:
-        with safe_while(tries=(timeout // 6)) as check_time:
-            not_ready = list(processes)
-            while len(not_ready) > 0:
-                check_time()
-                for proc in list(not_ready):
-                    if proc.finished:
-                        not_ready.remove(proc)
-
-    for proc in processes:
-        proc.wait()
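+    # RemoteProcess.wait() is a coroutine now; gather one wait() per process
+    # so they are awaited concurrently.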
+    await asyncio.wait_for(
+        asyncio.gather(*[proc.wait() for proc in processes]), timeout)
index 0a7d3ab35a008ab1339a89c03cb8ecd81200dd7e..88c3614e9004c5d5b59638ee34b6a3c40fc70c7f 100644 (file)
--- a/teuthology/parallel.py
+++ b/teuthology/parallel.py
@@ -1,39 +1,12 @@
+import asyncio
 import logging
-import sys
 
-import gevent
-import gevent.pool
-import gevent.queue
+from typing import List
 
 
 log = logging.getLogger(__name__)
 
 
-class ExceptionHolder(object):
-    def __init__(self, exc_info):
-        self.exc_info = exc_info
-
-
-def capture_traceback(func, *args, **kwargs):
-    """
-    Utility function to capture tracebacks of any exception func
-    raises.
-    """
-    try:
-        return func(*args, **kwargs)
-    except Exception:
-        return ExceptionHolder(sys.exc_info())
-
-
-def resurrect_traceback(exc):
-    if isinstance(exc, ExceptionHolder):
-        raise exc.exc_info[1]
-    elif isinstance(exc, BaseException):
-        raise exc
-    else:
-        return
-
-
 class parallel(object):
     """
     This class is a context manager for running functions in parallel.
@@ -61,55 +34,41 @@ class parallel(object):
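+    With asyncio the class is used as an asynchronous context manager and
+    asynchronous iterator, e.g. (``fn`` being any callable; see
+    test_parallel.py)::
+
+        async with parallel() as p:
+            for i in range(10):
+                p.spawn(fn, i)
+            async for result in p:
+                ...
+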
     """
 
     def __init__(self):
-        self.group = gevent.pool.Group()
-        self.results = gevent.queue.Queue()
+        self.results = []
         self.count = 0
-        self.any_spawned = False
         self.iteration_stopped = False
+        self.tasks: List[asyncio.Task] = []
+        self.any_spawned = False
 
     def spawn(self, func, *args, **kwargs):
-        self.count += 1
         self.any_spawned = True
-        greenlet = self.group.spawn(capture_traceback, func, *args, **kwargs)
-        greenlet.link(self._finish)
-
-    def __enter__(self):
+        self.count += 1
+        async def wrapper():
+            # If func is itself a coroutine function, this returns its
+            # (un-awaited) coroutine; callers that spawn coroutine functions
+            # await the yielded value themselves (see nuke()).
+            return func(*args, **kwargs)
+        self.tasks.append(asyncio.create_task(wrapper()))
+
+    async def __aenter__(self):
         return self
 
-    def __exit__(self, type_, value, traceback):
-        if value is not None:
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        if exc_value is not None:
             return False
-
-        # raises if any greenlets exited with an exception
-        for result in self:
-            log.debug('result is %s', repr(result))
+        self.results = await asyncio.gather(*self.tasks)
 
         return True
 
-    def __iter__(self):
+    def __aiter__(self):
         return self
 
-    def __next__(self):
-        if not self.any_spawned or self.iteration_stopped:
-            raise StopIteration()
-        result = self.results.get()
-
-        try:
-            resurrect_traceback(result)
-        except StopIteration:
-            self.iteration_stopped = True
-            raise
-
-        return result
-
-    next = __next__
-
-    def _finish(self, greenlet):
-        if greenlet.successful():
-            self.results.put(greenlet.value)
-        else:
-            self.results.put(greenlet.exception)
-
-        self.count -= 1
-        if self.count <= 0:
-            self.results.put(StopIteration())
+    async def __anext__(self):
+        if not self.tasks:
+            raise StopAsyncIteration
+        task = self.tasks.pop(0)
+        return await task
index 0f1bb63dacab49cc11db4a4833ccb8e43107055a..1cb7f70b7a4409a0a4e9ede8e3d7e7ee3bd8e19e 100644 (file)
--- a/teuthology/task/install/__init__.py
+++ b/teuthology/task/install/__init__.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import copy
 import logging
@@ -7,7 +8,6 @@ import yaml
 
 from teuthology import misc as teuthology
 from teuthology import contextutil, packaging
-from teuthology.parallel import parallel
 from teuthology.task import ansible
 
 from distutils.version import LooseVersion
@@ -63,7 +63,7 @@ def verify_package_version(ctx, config, remote):
         )
 
 
-def install_packages(ctx, pkgs, config):
+async def install_packages(ctx, pkgs, config):
     """
     Installs packages on each remote in ctx.
 
@@ -75,19 +75,19 @@ def install_packages(ctx, pkgs, config):
         "deb": deb._update_package_list_and_install,
         "rpm": rpm._update_package_list_and_install,
     }
-    with parallel() as p:
-        for remote in ctx.cluster.remotes.keys():
-            system_type = teuthology.get_system_type(remote)
-            p.spawn(
-                install_pkgs[system_type],
-                ctx, remote, pkgs[system_type], config)
+    tasks = set()
+    for remote in ctx.cluster.remotes.keys():
+        system_type = teuthology.get_system_type(remote)
+        install_fn = install_pkgs[system_type]
+        tasks.add(
+            asyncio.create_task(install_fn(ctx, remote, pkgs[system_type], config)))
+    # Wait for every install to finish before verifying package versions below.
+    await asyncio.gather(*tasks)
 
     for remote in ctx.cluster.remotes.keys():
         # verifies that the install worked as expected
         verify_package_version(ctx, config, remote)
 
 
-def remove_packages(ctx, config, pkgs):
+async def remove_packages(ctx, config, pkgs):
     """
     Removes packages from each remote in ctx.
 
@@ -100,15 +100,16 @@ def remove_packages(ctx, config, pkgs):
         "rpm": rpm._remove,
     }
     cleanup = config.get('cleanup', False)
-    with parallel() as p:
-        for remote in ctx.cluster.remotes.keys():
-            if not remote.is_reimageable or cleanup:
-                system_type = teuthology.get_system_type(remote)
-                p.spawn(remove_pkgs[
-                        system_type], ctx, config, remote, pkgs[system_type])
+    tasks = set()
+    for remote in ctx.cluster.remotes.keys():
+        if not remote.is_reimageable or cleanup:
+            system_type = teuthology.get_system_type(remote)
+            remove_fn = remove_pkgs[system_type]
+            tasks.add(
+                asyncio.create_task(remove_fn(ctx, config, remote, pkgs[system_type])))
+    # The old parallel() block waited for all removals to finish; do the same.
+    await asyncio.gather(*tasks)
 
 
-def remove_sources(ctx, config):
+async def remove_sources(ctx, config):
     """
     Removes repo source files from each remote in ctx.
 
@@ -121,13 +122,13 @@ def remove_sources(ctx, config):
     }
     cleanup = config.get('cleanup', False)
     project = config.get('project', 'ceph')
-    with parallel() as p:
-        for remote in ctx.cluster.remotes.keys():
-            if not remote.is_reimageable or cleanup:
-                log.info("Removing {p} sources lists on {r}"
-                         .format(p=project,r=remote))
-                remove_fn = remove_sources_pkgs[remote.os.package_type]
-                p.spawn(remove_fn, ctx, config, remote)
+    tasks = set()
+    for remote in ctx.cluster.remotes.keys():
+        if not remote.is_reimageable or cleanup:
+            log.info("Removing {p} sources lists on {r}"
+                     .format(p=project,r=remote))
+            remove_fn = remove_sources_pkgs[remote.os.package_type]
+            tasks.add(asyncio.create_task(remove_fn(ctx, config, remote)))
+    # Likewise, wait for all source-list removals to complete.
+    await asyncio.gather(*tasks)
 
 
 def get_package_list(ctx, config):
@@ -179,8 +180,8 @@ def get_package_list(ctx, config):
     return package_list
 
 
-@contextlib.contextmanager
-def install(ctx, config):
+@contextlib.asynccontextmanager
+async def install(ctx, config):
     """
     The install task. Installs packages for a given project on all hosts in
     ctx. May work for projects besides ceph, but may not. Patches welcomed!
@@ -215,12 +216,12 @@ def install(ctx, config):
                 'python-ceph']
         rpms = ['ceph-fuse', 'librbd1', 'librados2', 'ceph-test', 'python-ceph']
     package_list = dict(deb=debs, rpm=rpms)
-    install_packages(ctx, package_list, config)
+    await install_packages(ctx, package_list, config)
     try:
         yield
     finally:
-        remove_packages(ctx, config, package_list)
-        remove_sources(ctx, config)
+        await remove_packages(ctx, config, package_list)
+        await remove_sources(ctx, config)
 
 
 def upgrade_old_style(ctx, node, remote, pkgs, system_type):
index e1a290f78af65351a78ec4eda3e1a4709608a9f3..00ae089cb487dee5b7902083d0681f07f7873300 100644 (file)
--- a/teuthology/task/install/deb.py
+++ b/teuthology/task/install/deb.py
@@ -165,7 +165,7 @@ def _remove(ctx, config, remote, debs):
 def _remove_sources_list(ctx, config, remote):
     builder = _get_builder_project(ctx, remote, config)
     builder.remove_repo()
-    remote.run(
+    return remote.run(
         args=[
             'sudo', 'apt-get', 'update',
         ],
index 4d18d27193078c387ecdf64c1d05ad7002c225ee..5832044465e63180967f414675bf55506b76b20e 100644 (file)
--- a/teuthology/task/pexec.py
+++ b/teuthology/task/pexec.py
@@ -1,38 +1,15 @@
 """
 Handle parallel execution on remote hosts
 """
+import asyncio
 import logging
 
 from teuthology import misc as teuthology
-from teuthology.parallel import parallel
-from teuthology.orchestra import run as tor
+from teuthology.orchestra.run import PIPE, wait
 
 log = logging.getLogger(__name__)
 
-from gevent import queue as queue
-from gevent import event as event
-
-def _init_barrier(barrier_queue, remote):
-    """current just queues a remote host""" 
-    barrier_queue.put(remote)
-
-def _do_barrier(barrier, barrier_queue, remote):
-    """special case for barrier"""
-    barrier_queue.get()
-    if barrier_queue.empty():
-        barrier.set()
-        barrier.clear()
-    else:
-        barrier.wait()
-
-    barrier_queue.put(remote)
-    if barrier_queue.full():
-        barrier.set()
-        barrier.clear()
-    else:
-        barrier.wait()
-
-def _exec_host(barrier, barrier_queue, remote, sudo, testdir, ls):
+async def _exec_host(remote, sudo, testdir, ls):
     """Execute command remotely"""
     log.info('Running commands on host %s', remote.name)
     args = [
@@ -43,21 +20,17 @@ def _exec_host(barrier, barrier_queue, remote, sudo, testdir, ls):
     if sudo:
         args.insert(0, 'sudo')
     
-    r = remote.run( args=args, stdin=tor.PIPE, wait=False)
+    r = await remote.run(args=args, stdin=PIPE, wait=False)
     r.stdin.writelines(['set -e\n'])
     r.stdin.flush()
     for l in ls:
         l.replace('$TESTDIR', testdir)
-        if l == "barrier":
-            _do_barrier(barrier, barrier_queue, remote)
-            continue
-
         r.stdin.writelines([l, '\n'])
         r.stdin.flush()
     r.stdin.writelines(['\n'])
     r.stdin.flush()
     r.stdin.close()
-    tor.wait([r])
+    return await r.wait()
 
 def _generate_remotes(ctx, config):
     """Return remote roles and the type of role specified in config"""
@@ -109,23 +82,6 @@ def task(ctx, config):
         - pexec:
             clients:
               - dd if=/dev/zero of={testdir}/mnt.* count=1024 bs=1024
-
-    You can also ensure that parallel commands are synchronized with the
-    special 'barrier' statement:
-
-    tasks:
-    - pexec:
-        clients:
-          - cd {testdir}/mnt.*
-          - while true; do
-          -   barrier
-          -   dd if=/dev/zero of=./foo count=1024 bs=1024
-          - done
-
-    The above writes to the file foo on all clients over and over, but ensures that
-    all clients perform each write command in sync.  If one client takes longer to
-    write, all the other clients will wait.
-
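+
+    The special 'barrier' statement is no longer supported.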
     """
     log.info('Executing custom commands...')
     assert isinstance(config, dict), "task pexec got invalid config"
@@ -138,12 +94,11 @@ def task(ctx, config):
     testdir = teuthology.get_testdir(ctx)
 
     remotes = list(_generate_remotes(ctx, config))
-    count = len(remotes)
-    barrier_queue = queue.Queue(count)
-    barrier = event.Event()
-
+    # _exec_host() is now a coroutine function (remote.run() is async), so
+    # collect one coroutine per remote and run them all concurrently below.
+    tasks = set()
     for remote in remotes:
-        _init_barrier(barrier_queue, remote[0])
-    with parallel() as p:
-        for remote in remotes:
-            p.spawn(_exec_host, barrier, barrier_queue, remote[0], sudo, testdir, remote[1])
+        tasks.add(_exec_host(remote[0], sudo, testdir, remote[1]))
+
+    async def _run_all():
+        await asyncio.gather(*tasks)
+
+    # task() is still invoked synchronously, so (assuming no event loop is
+    # already running here) drive the coroutines to completion with asyncio.run().
+    asyncio.run(_run_all())
index bba1d57bf79af44b82ea2b88b5a46e095b71e7e1..26bec98ed968e2ebbfdf1cadf0eb77f877c0577f 100644 (file)
--- a/teuthology/test/test_parallel.py
+++ b/teuthology/test/test_parallel.py
@@ -1,3 +1,5 @@
+import pytest
+
 from teuthology.parallel import parallel
 
 
@@ -10,19 +12,22 @@ def identity(item, input_set=None, remove=False):
 
 
 class TestParallel(object):
-    def test_basic(self):
+    @pytest.mark.asyncio
+    async def test_basic(self):
         in_set = set(range(10))
-        with parallel() as para:
+        async with parallel() as para:
             for i in in_set:
                 para.spawn(identity, i, in_set, remove=True)
                 assert para.any_spawned is True
             assert para.count == len(in_set)
 
-    def test_result(self):
+    @pytest.mark.asyncio
+    async def test_result(self):
         in_set = set(range(10))
-        with parallel() as para:
+        async with parallel() as para:
             for i in in_set:
                 para.spawn(identity, i, in_set)
-            for result in para:
+            async for result in para:
                 in_set.remove(result)