From 30070be24860c2dfc0d27c79b0958d6c09316d87 Mon Sep 17 00:00:00 2001 From: Kefu Chai Date: Sun, 24 Jan 2021 14:58:51 +0800 Subject: [PATCH] cephadm: refactor call() using asyncio.StreamReader simpler this way, also fix a couple of issues: * create a child watcher explicitly, see https://bugs.python.org/issue35621 * use StringIO for collecting outputs for better performance, instead of appending the lines to an existing str * catch ValueError when reading from the stream reader, because StreamReader.readline() could raise ValueError when it reaches the buffer limit while looking for a separator. In this case, we should try again, in the hope that the spawned process can feed the reader with more data which contains the separator (i.e., b'\n'). * backport ThreadedChildWatcher from Python 3.8 so we can run create_subprocess_exec() in non-main threads. Signed-off-by: Kefu Chai --- src/cephadm/cephadm | 200 ++++++++++++++++++++++++++++++-------------- 1 file changed, 135 insertions(+), 65 deletions(-) diff --git a/src/cephadm/cephadm b/src/cephadm/cephadm index 1f0e178caaca2..106944e6e2831 100755 --- a/src/cephadm/cephadm +++ b/src/cephadm/cephadm @@ -38,6 +38,7 @@ You can invoke cephadm in two ways: injected_stdin = '...' 
""" import asyncio +import asyncio.subprocess import argparse import datetime import fcntl @@ -64,17 +65,17 @@ from socketserver import ThreadingMixIn from http.server import BaseHTTPRequestHandler, HTTPServer import signal import io -from contextlib import closing, redirect_stdout +from contextlib import redirect_stdout import ssl from enum import Enum -from typing import cast, Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO +from typing import Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO import re import uuid -from functools import partial, wraps +from functools import wraps from glob import glob from threading import Thread, RLock @@ -1185,38 +1186,108 @@ class CallVerbosity(Enum): VERBOSE = 3 -class StreamReaderProto(asyncio.SubprocessProtocol): - def __init__(self, - exited: asyncio.Future, - desc: str, - verbosity: CallVerbosity) -> None: - self.exited = exited - self.desc = desc - self.verbosity = verbosity - self.stdout = '' - self.stderr = '' - - def pipe_data_received(self, fd: int, data: bytes) -> None: - prefix = '' - lines = data.decode('utf-8') - - if fd == sys.stdout.fileno(): - prefix = self.desc + 'stdout' - self.stdout += lines - elif fd == sys.stderr.fileno(): - prefix = self.desc + 'stderr' - self.stderr += lines - else: - assert False, f"unknown data received from fd: {fd}" +if sys.version_info < (3, 8): + import itertools + import threading + import warnings + from asyncio import events + + class ThreadedChildWatcher(asyncio.AbstractChildWatcher): + """Threaded child watcher implementation. + The watcher uses a thread per process + for waiting for the process finish. + It doesn't require subscription on POSIX signal + but a thread creation is not free. + The watcher has O(1) complexity, its performance doesn't depend + on amount of spawn processes. 
+ """ + + def __init__(self): + self._pid_counter = itertools.count(0) + self._threads = {} + + def is_active(self): + return True + + def close(self): + self._join_threads() + + def _join_threads(self): + """Internal: Join all non-daemon threads""" + threads = [thread for thread in list(self._threads.values()) + if thread.is_alive() and not thread.daemon] + for thread in threads: + thread.join() - for line in lines.split('\n'): - if self.verbosity == CallVerbosity.VERBOSE: - logger.info(prefix + line) - elif self.verbosity != CallVerbosity.SILENT: - logger.debug(prefix + line) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def __del__(self, _warn=warnings.warn): + threads = [thread for thread in list(self._threads.values()) + if thread.is_alive()] + if threads: + _warn(f"{self.__class__} has registered but not finished child processes", + ResourceWarning, + source=self) + + def add_child_handler(self, pid, callback, *args): + loop = events.get_event_loop() + thread = threading.Thread(target=self._do_waitpid, + name=f"waitpid-{next(self._pid_counter)}", + args=(loop, pid, callback, args), + daemon=True) + self._threads[pid] = thread + thread.start() + + def remove_child_handler(self, pid): + # asyncio never calls remove_child_handler() !!! + # The method is no-op but is implemented because + # abstract base classe requires it + return True + + def attach_loop(self, loop): + pass + + def _do_waitpid(self, loop, expected_pid, callback, args): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, 0) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). 
+ pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + if os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + else: + raise ValueError(f'unknown wait status {status}') + if loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + if loop.is_closed(): + logger.warning("Loop %r that handles pid %r is closed", loop, pid) + else: + loop.call_soon_threadsafe(callback, pid, returncode, *args) + + self._threads.pop(expected_pid) + + # unlike SafeChildWatcher which handles SIGCHLD in the main thread, + # ThreadedChildWatcher runs in a separated thread, hence allows us to + # run create_subprocess_exec() in non-main thread, see + # https://bugs.python.org/issue35621 + asyncio.set_child_watcher(ThreadedChildWatcher()) - def process_exited(self) -> None: - self.exited.set_result(True) try: from asyncio import run as async_run # type: ignore[attr-defined] @@ -1227,8 +1298,11 @@ except ImportError: asyncio.set_event_loop(loop) return loop.run_until_complete(coro) finally: - asyncio.set_event_loop(None) - loop.close() + try: + loop.run_until_complete(loop.shutdown_asyncgens()) + finally: + asyncio.set_event_loop(None) + loop.close() def call(ctx: CephadmContext, command: List[str], @@ -1253,47 +1327,43 @@ def call(ctx: CephadmContext, logger.debug("Running command: %s" % ' '.join(command)) - async def run_with_timeout(): - loop = asyncio.get_event_loop() - proc_exited = loop.create_future() - protocol_factory = partial(StreamReaderProto, - proc_exited, - prefix, verbosity) - transport, protocol = await loop.subprocess_exec( - protocol_factory, + async def tee(reader: asyncio.StreamReader) -> str: + collected = StringIO() + async for line in reader: + message = line.decode('utf-8') + collected.write(message) + if verbosity == CallVerbosity.VERBOSE: + logger.info(prefix + 
message.rstrip()) + elif verbosity != CallVerbosity.SILENT: + logger.debug(prefix + message.rstrip()) + return collected.getvalue() + + async def run_with_timeout() -> Tuple[str, str, int]: + process = await asyncio.create_subprocess_exec( *command, - close_fds=True, - **kwargs) - proc_transport = cast(asyncio.SubprocessTransport, transport) - proc_protocol = cast(StreamReaderProto, protocol) - returncode = 0 + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + assert process.stdout + assert process.stderr try: - if timeout: - await asyncio.wait_for(proc_exited, timeout) - else: - await proc_exited + stdout, stderr = await asyncio.gather(tee(process.stdout), + tee(process.stderr)) + returncode = await asyncio.wait_for(process.wait(), timeout) except asyncio.TimeoutError: logger.info(prefix + f'timeout after {timeout} seconds') - returncode = 124 + return '', '', 124 else: - returncode = cast(int, proc_transport.get_returncode()) - finally: - proc_transport.close() - return (returncode, - proc_protocol.stdout, - proc_protocol.stderr) + return stdout, stderr, returncode - returncode, out, err = async_run(run_with_timeout()) + stdout, stderr, returncode = async_run(run_with_timeout()) if returncode != 0 and verbosity == CallVerbosity.VERBOSE_ON_FAILURE: - # dump stdout + stderr logger.info('Non-zero exit code %d from %s', returncode, ' '.join(command)) - for line in out.splitlines(): + for line in stdout.splitlines(): logger.info(prefix + 'stdout ' + line) - for line in err.splitlines(): + for line in stderr.splitlines(): logger.info(prefix + 'stderr ' + line) - - return out, err, returncode + return stdout, stderr, returncode def call_throws( -- 2.39.5