injected_stdin = '...'
"""
import asyncio
+import asyncio.subprocess
import argparse
import datetime
import fcntl
from http.server import BaseHTTPRequestHandler, HTTPServer
import signal
import io
-from contextlib import closing, redirect_stdout
+from contextlib import redirect_stdout
import ssl
from enum import Enum
-from typing import cast, Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
+from typing import Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
import re
import uuid
-from functools import partial, wraps
+from functools import wraps
from glob import glob
from threading import Thread, RLock
VERBOSE = 3
-class StreamReaderProto(asyncio.SubprocessProtocol):
- def __init__(self,
- exited: asyncio.Future,
- desc: str,
- verbosity: CallVerbosity) -> None:
- self.exited = exited
- self.desc = desc
- self.verbosity = verbosity
- self.stdout = ''
- self.stderr = ''
-
- def pipe_data_received(self, fd: int, data: bytes) -> None:
- prefix = ''
- lines = data.decode('utf-8')
-
- if fd == sys.stdout.fileno():
- prefix = self.desc + 'stdout'
- self.stdout += lines
- elif fd == sys.stderr.fileno():
- prefix = self.desc + 'stderr'
- self.stderr += lines
- else:
- assert False, f"unknown data received from fd: {fd}"
+if sys.version_info < (3, 8):
+ import itertools
+ import threading
+ import warnings
+ from asyncio import events
+
+ class ThreadedChildWatcher(asyncio.AbstractChildWatcher):
+ """Threaded child watcher implementation.
+ The watcher uses a thread per process
+ for waiting for the process finish.
+ It doesn't require subscription on POSIX signal
+ but a thread creation is not free.
+ The watcher has O(1) complexity, its performance doesn't depend
+ on amount of spawn processes.
+ """
+
+ def __init__(self):
+ self._pid_counter = itertools.count(0)
+ self._threads = {}
+
+ def is_active(self):
+ return True
+
+ def close(self):
+ self._join_threads()
+
+ def _join_threads(self):
+ """Internal: Join all non-daemon threads"""
+ threads = [thread for thread in list(self._threads.values())
+ if thread.is_alive() and not thread.daemon]
+ for thread in threads:
+ thread.join()
- for line in lines.split('\n'):
- if self.verbosity == CallVerbosity.VERBOSE:
- logger.info(prefix + line)
- elif self.verbosity != CallVerbosity.SILENT:
- logger.debug(prefix + line)
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+ def __del__(self, _warn=warnings.warn):
+ threads = [thread for thread in list(self._threads.values())
+ if thread.is_alive()]
+ if threads:
+ _warn(f"{self.__class__} has registered but not finished child processes",
+ ResourceWarning,
+ source=self)
+
+ def add_child_handler(self, pid, callback, *args):
+ loop = events.get_event_loop()
+ thread = threading.Thread(target=self._do_waitpid,
+ name=f"waitpid-{next(self._pid_counter)}",
+ args=(loop, pid, callback, args),
+ daemon=True)
+ self._threads[pid] = thread
+ thread.start()
+
+ def remove_child_handler(self, pid):
+ # asyncio never calls remove_child_handler() !!!
+ # The method is no-op but is implemented because
+ # abstract base classe requires it
+ return True
+
+ def attach_loop(self, loop):
+ pass
+
+ def _do_waitpid(self, loop, expected_pid, callback, args):
+ assert expected_pid > 0
+
+ try:
+ pid, status = os.waitpid(expected_pid, 0)
+ except ChildProcessError:
+ # The child process is already reaped
+ # (may happen if waitpid() is called elsewhere).
+ pid = expected_pid
+ returncode = 255
+ logger.warning(
+ "Unknown child process pid %d, will report returncode 255",
+ pid)
+ else:
+ if os.WIFEXITED(status):
+ returncode = os.WEXITSTATUS(status)
+ elif os.WIFSIGNALED(status):
+ returncode = -os.WTERMSIG(status)
+ else:
+ raise ValueError(f'unknown wait status {status}')
+ if loop.get_debug():
+ logger.debug('process %s exited with returncode %s',
+ expected_pid, returncode)
+
+ if loop.is_closed():
+ logger.warning("Loop %r that handles pid %r is closed", loop, pid)
+ else:
+ loop.call_soon_threadsafe(callback, pid, returncode, *args)
+
+ self._threads.pop(expected_pid)
+
+ # unlike SafeChildWatcher which handles SIGCHLD in the main thread,
+ # ThreadedChildWatcher runs in a separated thread, hence allows us to
+ # run create_subprocess_exec() in non-main thread, see
+ # https://bugs.python.org/issue35621
+ asyncio.set_child_watcher(ThreadedChildWatcher())
- def process_exited(self) -> None:
- self.exited.set_result(True)
try:
from asyncio import run as async_run # type: ignore[attr-defined]
asyncio.set_event_loop(loop)
return loop.run_until_complete(coro)
finally:
- asyncio.set_event_loop(None)
- loop.close()
+ try:
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ asyncio.set_event_loop(None)
+ loop.close()
def call(ctx: CephadmContext,
command: List[str],
logger.debug("Running command: %s" % ' '.join(command))
- async def run_with_timeout():
- loop = asyncio.get_event_loop()
- proc_exited = loop.create_future()
- protocol_factory = partial(StreamReaderProto,
- proc_exited,
- prefix, verbosity)
- transport, protocol = await loop.subprocess_exec(
- protocol_factory,
+ async def tee(reader: asyncio.StreamReader) -> str:
+ collected = StringIO()
+ async for line in reader:
+ message = line.decode('utf-8')
+ collected.write(message)
+ if verbosity == CallVerbosity.VERBOSE:
+ logger.info(prefix + message.rstrip())
+ elif verbosity != CallVerbosity.SILENT:
+ logger.debug(prefix + message.rstrip())
+ return collected.getvalue()
+
+ async def run_with_timeout() -> Tuple[str, str, int]:
+ process = await asyncio.create_subprocess_exec(
*command,
- close_fds=True,
- **kwargs)
- proc_transport = cast(asyncio.SubprocessTransport, transport)
- proc_protocol = cast(StreamReaderProto, protocol)
- returncode = 0
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE)
+ assert process.stdout
+ assert process.stderr
try:
- if timeout:
- await asyncio.wait_for(proc_exited, timeout)
- else:
- await proc_exited
+ stdout, stderr = await asyncio.gather(tee(process.stdout),
+ tee(process.stderr))
+ returncode = await asyncio.wait_for(process.wait(), timeout)
except asyncio.TimeoutError:
logger.info(prefix + f'timeout after {timeout} seconds')
- returncode = 124
+ return '', '', 124
else:
- returncode = cast(int, proc_transport.get_returncode())
- finally:
- proc_transport.close()
- return (returncode,
- proc_protocol.stdout,
- proc_protocol.stderr)
+ return stdout, stderr, returncode
- returncode, out, err = async_run(run_with_timeout())
+ stdout, stderr, returncode = async_run(run_with_timeout())
if returncode != 0 and verbosity == CallVerbosity.VERBOSE_ON_FAILURE:
- # dump stdout + stderr
logger.info('Non-zero exit code %d from %s',
returncode, ' '.join(command))
- for line in out.splitlines():
+ for line in stdout.splitlines():
logger.info(prefix + 'stdout ' + line)
- for line in err.splitlines():
+ for line in stderr.splitlines():
logger.info(prefix + 'stderr ' + line)
-
- return out, err, returncode
+ return stdout, stderr, returncode
def call_throws(