git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
cephadm: refactor call() using asyncio.StreamReader
author: Kefu Chai <kchai@redhat.com>
Sun, 24 Jan 2021 06:58:51 +0000 (14:58 +0800)
committer: Sebastian Wagner <sebastian.wagner@suse.com>
Fri, 29 Jan 2021 12:42:38 +0000 (13:42 +0100)
This is simpler, and it also fixes a couple of issues:

* create a child watcher explicitly, see
  https://bugs.python.org/issue35621
* use StringIO for collecting outputs for better performance,
  instead of appending the lines to an existing str
* catch ValueError when reading from the stream reader,
  because StreamReader.readline() can raise ValueError when
  it reaches the buffer limit while looking for a separator.
  In this case, we should try again, in the hope that the spawned
  process feeds the reader more data that contains
  the separator (i.e., b'\n').
* backport ThreadedChildWatcher from Python 3.8 so we can
  run create_subprocess_exec() in non-main threads.

Signed-off-by: Kefu Chai <kchai@redhat.com>
(cherry picked from commit 30070be24860c2dfc0d27c79b0958d6c09316d87)

src/cephadm/cephadm

index 96d7415dc8cc31e29c0af5ee8137e80cd69383bc..7e8641eecfeb07b368ed54ac9e5822d5cc611db5 100755 (executable)
@@ -39,6 +39,7 @@ You can invoke cephadm in two ways:
        injected_stdin = '...'
 """
 import asyncio
+import asyncio.subprocess
 import argparse
 import datetime
 import fcntl
@@ -65,17 +66,17 @@ from socketserver import ThreadingMixIn
 from http.server import BaseHTTPRequestHandler, HTTPServer
 import signal
 import io
-from contextlib import closing, redirect_stdout
+from contextlib import redirect_stdout
 import ssl
 from enum import Enum
 
 
-from typing import cast, Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
+from typing import Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
 
 import re
 import uuid
 
-from functools import partial, wraps
+from functools import wraps
 from glob import glob
 from threading import Thread, RLock
 
@@ -1186,38 +1187,108 @@ class CallVerbosity(Enum):
     VERBOSE = 3
 
 
-class StreamReaderProto(asyncio.SubprocessProtocol):
-    def __init__(self,
-                 exited: asyncio.Future,
-                 desc: str,
-                 verbosity: CallVerbosity) -> None:
-        self.exited = exited
-        self.desc = desc
-        self.verbosity = verbosity
-        self.stdout = ''
-        self.stderr = ''
-
-    def pipe_data_received(self, fd: int, data: bytes) -> None:
-        prefix = ''
-        lines = data.decode('utf-8')
-
-        if fd == sys.stdout.fileno():
-            prefix = self.desc + 'stdout'
-            self.stdout += lines
-        elif fd == sys.stderr.fileno():
-            prefix = self.desc + 'stderr'
-            self.stderr += lines
-        else:
-            assert False, f"unknown data received from fd: {fd}"
+if sys.version_info < (3, 8):
+    import itertools
+    import threading
+    import warnings
+    from asyncio import events
+
+    class ThreadedChildWatcher(asyncio.AbstractChildWatcher):
+        """Threaded child watcher implementation.
+        The watcher uses a thread per process
+        for waiting for the process finish.
+        It doesn't require subscription on POSIX signal
+        but a thread creation is not free.
+        The watcher has O(1) complexity, its performance doesn't depend
+        on amount of spawn processes.
+        """
+
+        def __init__(self):
+            self._pid_counter = itertools.count(0)
+            self._threads = {}
+
+        def is_active(self):
+                return True
+
+        def close(self):
+            self._join_threads()
+
+        def _join_threads(self):
+            """Internal: Join all non-daemon threads"""
+            threads = [thread for thread in list(self._threads.values())
+                       if thread.is_alive() and not thread.daemon]
+            for thread in threads:
+                thread.join()
 
-        for line in lines.split('\n'):
-            if self.verbosity == CallVerbosity.VERBOSE:
-                logger.info(prefix + line)
-            elif self.verbosity != CallVerbosity.SILENT:
-                logger.debug(prefix + line)
+        def __enter__(self):
+            return self
+
+        def __exit__(self, exc_type, exc_val, exc_tb):
+            pass
+
+        def __del__(self, _warn=warnings.warn):
+            threads = [thread for thread in list(self._threads.values())
+                       if thread.is_alive()]
+            if threads:
+                _warn(f"{self.__class__} has registered but not finished child processes",
+                      ResourceWarning,
+                      source=self)
+
+        def add_child_handler(self, pid, callback, *args):
+            loop = events.get_event_loop()
+            thread = threading.Thread(target=self._do_waitpid,
+                                      name=f"waitpid-{next(self._pid_counter)}",
+                                      args=(loop, pid, callback, args),
+                                      daemon=True)
+            self._threads[pid] = thread
+            thread.start()
+
+        def remove_child_handler(self, pid):
+            # asyncio never calls remove_child_handler() !!!
+            # The method is no-op but is implemented because
+            # abstract base class requires it
+            return True
+
+        def attach_loop(self, loop):
+            pass
+
+        def _do_waitpid(self, loop, expected_pid, callback, args):
+            assert expected_pid > 0
+
+            try:
+                pid, status = os.waitpid(expected_pid, 0)
+            except ChildProcessError:
+                # The child process is already reaped
+                # (may happen if waitpid() is called elsewhere).
+                pid = expected_pid
+                returncode = 255
+                logger.warning(
+                    "Unknown child process pid %d, will report returncode 255",
+                    pid)
+            else:
+                if os.WIFEXITED(status):
+                    returncode = os.WEXITSTATUS(status)
+                elif os.WIFSIGNALED(status):
+                    returncode = -os.WTERMSIG(status)
+                else:
+                    raise ValueError(f'unknown wait status {status}')
+                if loop.get_debug():
+                    logger.debug('process %s exited with returncode %s',
+                                 expected_pid, returncode)
+
+            if loop.is_closed():
+                logger.warning("Loop %r that handles pid %r is closed", loop, pid)
+            else:
+                loop.call_soon_threadsafe(callback, pid, returncode, *args)
+
+            self._threads.pop(expected_pid)
+
+    # unlike SafeChildWatcher, which handles SIGCHLD in the main thread,
+    # ThreadedChildWatcher runs in a separate thread, and hence allows us to
+    # run create_subprocess_exec() in a non-main thread, see
+    # https://bugs.python.org/issue35621
+    asyncio.set_child_watcher(ThreadedChildWatcher())
 
-    def process_exited(self) -> None:
-        self.exited.set_result(True)
 
 try:
     from asyncio import run as async_run   # type: ignore[attr-defined]
@@ -1228,8 +1299,11 @@ except ImportError:
             asyncio.set_event_loop(loop)
             return loop.run_until_complete(coro)
         finally:
-            asyncio.set_event_loop(None)
-            loop.close()
+            try:
+                loop.run_until_complete(loop.shutdown_asyncgens())
+            finally:
+                asyncio.set_event_loop(None)
+                loop.close()
 
 def call(ctx: CephadmContext,
          command: List[str],
@@ -1254,47 +1328,43 @@ def call(ctx: CephadmContext,
 
     logger.debug("Running command: %s" % ' '.join(command))
 
-    async def run_with_timeout():
-        loop = asyncio.get_event_loop()
-        proc_exited = loop.create_future()
-        protocol_factory = partial(StreamReaderProto,
-                                   proc_exited,
-                                   prefix, verbosity)
-        transport, protocol = await loop.subprocess_exec(
-            protocol_factory,
+    async def tee(reader: asyncio.StreamReader) -> str:
+        collected = StringIO()
+        async for line in reader:
+            message = line.decode('utf-8')
+            collected.write(message)
+            if verbosity == CallVerbosity.VERBOSE:
+                logger.info(prefix + message.rstrip())
+            elif verbosity != CallVerbosity.SILENT:
+                logger.debug(prefix + message.rstrip())
+        return collected.getvalue()
+
+    async def run_with_timeout() -> Tuple[str, str, int]:
+        process = await asyncio.create_subprocess_exec(
             *command,
-            close_fds=True,
-            **kwargs)
-        proc_transport = cast(asyncio.SubprocessTransport, transport)
-        proc_protocol = cast(StreamReaderProto, protocol)
-        returncode = 0
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE)
+        assert process.stdout
+        assert process.stderr
         try:
-            if timeout:
-                await asyncio.wait_for(proc_exited, timeout)
-            else:
-                await proc_exited
+            stdout, stderr = await asyncio.gather(tee(process.stdout),
+                                                  tee(process.stderr))
+            returncode = await asyncio.wait_for(process.wait(), timeout)
         except asyncio.TimeoutError:
             logger.info(prefix + f'timeout after {timeout} seconds')
-            returncode = 124
+            return '', '', 124
         else:
-            returncode = cast(int, proc_transport.get_returncode())
-        finally:
-            proc_transport.close()
-        return (returncode,
-                proc_protocol.stdout,
-                proc_protocol.stderr)
+            return stdout, stderr, returncode
 
-    returncode, out, err = async_run(run_with_timeout())
+    stdout, stderr, returncode = async_run(run_with_timeout())
     if returncode != 0 and verbosity == CallVerbosity.VERBOSE_ON_FAILURE:
-        # dump stdout + stderr
         logger.info('Non-zero exit code %d from %s',
                     returncode, ' '.join(command))
-        for line in out.splitlines():
+        for line in stdout.splitlines():
             logger.info(prefix + 'stdout ' + line)
-        for line in err.splitlines():
+        for line in stderr.splitlines():
             logger.info(prefix + 'stderr ' + line)
-
-    return out, err, returncode
+    return stdout, stderr, returncode
 
 
 def call_throws(