]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
cephadm: rewrite call() with asyncio 39035/head
authorKefu Chai <kchai@redhat.com>
Sat, 23 Jan 2021 05:18:56 +0000 (13:18 +0800)
committerKefu Chai <kchai@redhat.com>
Sat, 23 Jan 2021 17:43:33 +0000 (01:43 +0800)
for better readability, also return 124 when subprocess times out

Signed-off-by: Kefu Chai <kchai@redhat.com>
src/cephadm/cephadm

index b2084e2424028b818e5af773873f8fc24dcad826..1f0e178caaca22e177abb34d67b8933912cc54e8 100755 (executable)
@@ -37,6 +37,7 @@ You can invoke cephadm in two ways:
 
        injected_stdin = '...'
 """
+import asyncio
 import argparse
 import datetime
 import fcntl
@@ -63,17 +64,17 @@ from socketserver import ThreadingMixIn
 from http.server import BaseHTTPRequestHandler, HTTPServer
 import signal
 import io
-from contextlib import redirect_stdout
+from contextlib import closing, redirect_stdout
 import ssl
 from enum import Enum
 
 
-from typing import Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
+from typing import cast, Dict, List, Tuple, Optional, Union, Any, NoReturn, Callable, IO
 
 import re
 import uuid
 
-from functools import wraps
+from functools import partial, wraps
 from glob import glob
 from threading import Thread, RLock
 
@@ -1184,6 +1185,51 @@ class CallVerbosity(Enum):
     VERBOSE = 3
 
 
+class StreamReaderProto(asyncio.SubprocessProtocol):
+    def __init__(self,
+                 exited: asyncio.Future,
+                 desc: str,
+                 verbosity: CallVerbosity) -> None:
+        self.exited = exited
+        self.desc = desc
+        self.verbosity = verbosity
+        self.stdout = ''
+        self.stderr = ''
+
+    def pipe_data_received(self, fd: int, data: bytes) -> None:
+        prefix = ''
+        lines = data.decode('utf-8')
+
+        if fd == sys.stdout.fileno():
+            prefix = self.desc + 'stdout'
+            self.stdout += lines
+        elif fd == sys.stderr.fileno():
+            prefix = self.desc + 'stderr'
+            self.stderr += lines
+        else:
+            assert False, f"unknown data received from fd: {fd}"
+
+        for line in lines.split('\n'):
+            if self.verbosity == CallVerbosity.VERBOSE:
+                logger.info(prefix + line)
+            elif self.verbosity != CallVerbosity.SILENT:
+                logger.debug(prefix + line)
+
+    def process_exited(self) -> None:
+        self.exited.set_result(True)
+
+try:
+    from asyncio import run as async_run   # type: ignore[attr-defined]
+except ImportError:
+    def async_run(coro):  # type: ignore
+        loop = asyncio.new_event_loop()
+        try:
+            asyncio.set_event_loop(loop)
+            return loop.run_until_complete(coro)
+        finally:
+            asyncio.set_event_loop(None)
+            loop.close()
+
 def call(ctx: CephadmContext,
          command: List[str],
          desc: Optional[str] = None,
@@ -1200,117 +1246,52 @@ def call(ctx: CephadmContext,
     :param timeout: timeout in seconds
     """
 
-    if desc is None:
-        desc = command[0]
-    if desc:
-        desc += ': '
+    prefix = command[0] if desc is None else desc
+    if prefix:
+        prefix += ': '
     timeout = timeout or ctx.timeout
 
     logger.debug("Running command: %s" % ' '.join(command))
-    process = subprocess.Popen(
-        command,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        close_fds=True,
-        **kwargs
-    )
-    # get current p.stdout flags, add O_NONBLOCK
-    assert process.stdout is not None
-    assert process.stderr is not None
-    stdout_flags = fcntl.fcntl(process.stdout, fcntl.F_GETFL)
-    stderr_flags = fcntl.fcntl(process.stderr, fcntl.F_GETFL)
-    fcntl.fcntl(process.stdout, fcntl.F_SETFL, stdout_flags | os.O_NONBLOCK)
-    fcntl.fcntl(process.stderr, fcntl.F_SETFL, stderr_flags | os.O_NONBLOCK)
-
-    out = ''
-    err = ''
-    reads = None
-    stop = False
-    out_buffer = ''   # partial line (no newline yet)
-    err_buffer = ''   # partial line (no newline yet)
-    start_time = time.time()
-    end_time = None
-    if timeout:
-        end_time = start_time + timeout
-    while not stop:
-        if end_time and (time.time() >= end_time):
-            stop = True
-            if process.poll() is None:
-                logger.info(desc + 'timeout after %s seconds' % timeout)
-                process.kill()
-        if reads and process.poll() is not None:
-            # we want to stop, but first read off anything remaining
-            # on stdout/stderr
-            stop = True
-        else:
-            reads, _, _ = select.select(
-                [process.stdout.fileno(), process.stderr.fileno()],
-                [], [], timeout
-            )
-        for fd in reads:
-            try:
-                message = str()
-                message_b = os.read(fd, 1024)
-                if isinstance(message_b, bytes):
-                    message = message_b.decode('utf-8')
-                elif isinstance(message_b, str):
-                    message = message_b
-                else:
-                    assert False
-                if stop and message:
-                    # process has terminated, but have more to read still, so not stopping yet
-                    # (os.read returns '' when it encounters EOF)
-                    stop = False
-                if not message:
-                    continue
-                if fd == process.stdout.fileno():
-                    out += message
-                    message = out_buffer + message
-                    lines = message.split('\n')
-                    out_buffer = lines.pop()
-                    for line in lines:
-                        if verbosity == CallVerbosity.VERBOSE:
-                            logger.info(desc + 'stdout ' + line)
-                        elif verbosity != CallVerbosity.SILENT:
-                            logger.debug(desc + 'stdout ' + line)
-                elif fd == process.stderr.fileno():
-                    err += message
-                    message = err_buffer + message
-                    lines = message.split('\n')
-                    err_buffer = lines.pop()
-                    for line in lines:
-                        if verbosity == CallVerbosity.VERBOSE:
-                            logger.info(desc + 'stderr ' + line)
-                        elif verbosity != CallVerbosity.SILENT:
-                            logger.debug(desc + 'stderr ' + line)
-                else:
-                    assert False
-            except (IOError, OSError):
-                pass
-        if verbosity == CallVerbosity.VERBOSE:
-            logger.debug(desc + 'profile rt=%s, stop=%s, exit=%s, reads=%s'
-                % (time.time()-start_time, stop, process.poll(), reads))
-
-    returncode = process.wait()
-
-    if out_buffer != '':
-        if verbosity == CallVerbosity.VERBOSE:
-            logger.info(desc + 'stdout ' + out_buffer)
-        elif verbosity != CallVerbosity.SILENT:
-            logger.debug(desc + 'stdout ' + out_buffer)
-    if err_buffer != '':
-        if verbosity == CallVerbosity.VERBOSE:
-            logger.info(desc + 'stderr ' + err_buffer)
-        elif verbosity != CallVerbosity.SILENT:
-            logger.debug(desc + 'stderr ' + err_buffer)
 
+    async def run_with_timeout():
+        loop = asyncio.get_event_loop()
+        proc_exited = loop.create_future()
+        protocol_factory = partial(StreamReaderProto,
+                                   proc_exited,
+                                   prefix, verbosity)
+        transport, protocol = await loop.subprocess_exec(
+            protocol_factory,
+            *command,
+            close_fds=True,
+            **kwargs)
+        proc_transport = cast(asyncio.SubprocessTransport, transport)
+        proc_protocol = cast(StreamReaderProto, protocol)
+        returncode = 0
+        try:
+            if timeout:
+                await asyncio.wait_for(proc_exited, timeout)
+            else:
+                await proc_exited
+        except asyncio.TimeoutError:
+            logger.info(prefix + f'timeout after {timeout} seconds')
+            returncode = 124
+        else:
+            returncode = cast(int, proc_transport.get_returncode())
+        finally:
+            proc_transport.close()
+        return (returncode,
+                proc_protocol.stdout,
+                proc_protocol.stderr)
+
+    returncode, out, err = async_run(run_with_timeout())
     if returncode != 0 and verbosity == CallVerbosity.VERBOSE_ON_FAILURE:
         # dump stdout + stderr
-        logger.info('Non-zero exit code %d from %s' % (returncode, ' '.join(command)))
+        logger.info('Non-zero exit code %d from %s',
+                    returncode, ' '.join(command))
         for line in out.splitlines():
-            logger.info(desc + 'stdout ' + line)
+            logger.info(prefix + 'stdout ' + line)
         for line in err.splitlines():
-            logger.info(desc + 'stderr ' + line)
+            logger.info(prefix + 'stderr ' + line)
 
     return out, err, returncode