]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
Allow easy writing to stdin of remote processes.
authorTommi Virtanen <tommi.virtanen@dreamhost.com>
Tue, 24 May 2011 20:00:44 +0000 (13:00 -0700)
committerTommi Virtanen <tommi.virtanen@dreamhost.com>
Tue, 24 May 2011 20:00:44 +0000 (13:00 -0700)
orchestra/run.py
orchestra/test/test_run.py

index 15adf906b3162b208859bb5dfa7ecf0fb780b47d..f98daedd20a88c30aeccde811973a8fdd4c7c1a2 100644 (file)
@@ -9,29 +9,13 @@ import shutil
 log = logging.getLogger(__name__)
 
 class RemoteProcess(object):
-    __slots__ = ['command', 'stdin', 'stdout', 'stderr', '_get_exitstatus']
-    def __init__(self, command, stdin, stdout, stderr, get_exitstatus):
+    __slots__ = ['command', 'stdin', 'stdout', 'stderr', 'exitstatus']
+    def __init__(self, command, stdin, stdout, stderr, exitstatus):
         self.command = command
         self.stdin = stdin
         self.stdout = stdout
         self.stderr = stderr
-        self._get_exitstatus = get_exitstatus
-
-    @property
-    def exitstatus(self):
-        """
-        Wait for exit and return exit status.
-
-        Will return None on signals and connection loss.
-
-        This will likely block until you've closed stdin and consumed
-        stdout and stderr.
-        """
-        status = self._get_exitstatus()
-        # -1 on connection loss *and* signals; map to more pythonic None
-        if status == -1:
-            status = None
-        return status
+        self.exitstatus = exitstatus
 
 def execute(client, args):
     """
@@ -42,15 +26,28 @@ def execute(client, args):
     :param client: SSHConnection to run the command with
     :param args: command to run
     :type args: list of string
+
+    Returns a RemoteProcess, where exitstatus is a callable that will
+    block until the exit status is available.
     """
     cmd = ' '.join(pipes.quote(a) for a in args)
     (in_, out, err) = client.exec_command(cmd)
+
+    def get_exitstatus():
+        status = out.channel.recv_exit_status()
+        # -1 on connection loss *and* signals; map to more pythonic None
+        if status == -1:
+            status = None
+        return status
+
     r = RemoteProcess(
         command=cmd,
         stdin=in_,
         stdout=out,
         stderr=err,
-        get_exitstatus=out.channel.recv_exit_status,
+        # this is a callable that will block until the status is
+        # available
+        exitstatus=get_exitstatus,
         )
     return r
 
@@ -139,6 +136,15 @@ def spawn_asyncresult(fn, *args, **kwargs):
 
     return r
 
+class Sentinel(object):
+    def __init__(self, name):
+        self.name = name
+
+    def __str__(self):
+        return self.name
+
+PIPE = Sentinel('PIPE')
+
 def run(
     client, args,
     stdin=None, stdout=None, stderr=None,
@@ -152,7 +158,7 @@ def run(
     :param client: SSHConnection to run the command with
     :param args: command to run
     :type args: list of string
-    :param stdin: Standard input to send; either a string, a file-like object, or None.
+    :param stdin: Standard input to send; either a string, a file-like object, None, or `PIPE`. `PIPE` means caller is responsible for closing stdin, or command may never exit.
     :param stdout: What to do with standard output. Either a file-like object, a `logging.Logger`, or `None` for copying to default log.
     :param stderr: What to do with standard error. See `stdout`.
     :param logger: If logging, write stdout/stderr to "out" and "err" children of this logger. Defaults to logger named after this module.
@@ -161,7 +167,12 @@ def run(
     """
     r = execute(client, args)
 
-    g_in = gevent.spawn(copy_and_close, stdin, r.stdin)
+    g_in = None
+    if stdin is not PIPE:
+        g_in = gevent.spawn(copy_and_close, stdin, r.stdin)
+        r.stdin = None
+    else:
+        assert not wait, "Using PIPE for stdin without wait=False would deadlock."
 
     if logger is None:
         logger = log
@@ -169,16 +180,19 @@ def run(
     if stderr is None:
         stderr = logger.getChild('err')
     g_err = gevent.spawn(copy_file_to, r.stderr, stderr)
+    r.stderr = stderr
 
     if stdout is None:
         stdout = logger.getChild('out')
     copy_file_to(r.stdout, stdout)
+    r.stdout = stdout
 
     g_err.get()
-    g_in.get()
+    if g_in is not None:
+        g_in.get()
 
-    def get_status():
-        status = r.exitstatus
+    def _check_status(status):
+        status = status()
         if check_status:
             if status is None:
                 # command either died due to a signal, or the connection
@@ -196,13 +210,8 @@ def run(
         return status
 
     if wait:
-        status = get_status()
+        r.exitstatus = _check_status(r.exitstatus)
     else:
-        status = spawn_asyncresult(get_status)
+        r.exitstatus = spawn_asyncresult(_check_status, r.exitstatus)
 
-    return CommandResult(
-        command=r.command,
-        stdout=stdout,
-        stderr=stderr,
-        exitstatus=status,
-        )
+    return r
index ae1edf2e9ee1d73a8cf107217c08341f09cdf424..79bfd7508b6a2e651a2f0e7e86aa5b6ac9e25080 100644 (file)
@@ -270,3 +270,34 @@ def test_run_nowait():
         )
     eq(e.exitstatus, 42)
     eq(str(e), "Command failed with status 42: 'foo'")
+
+
+@nose.with_setup(fudge.clear_expectations)
+@fudge.with_fakes
+def test_run_stdin_pipe():
+    ssh = fudge.Fake('SSHConnection')
+    cmd = ssh.expects('exec_command')
+    cmd.with_args("foo")
+    in_ = fudge.Fake('ChannelFile').is_a_stub()
+    out = fudge.Fake('ChannelFile').is_a_stub()
+    err = fudge.Fake('ChannelFile').is_a_stub()
+    cmd.returns((in_, out, err))
+    out.expects('xreadlines').with_args().returns([])
+    err.expects('xreadlines').with_args().returns([])
+    logger = fudge.Fake('logger').is_a_stub()
+    channel = fudge.Fake('channel')
+    out.has_attr(channel=channel)
+    channel.expects('recv_exit_status').with_args().returns(0)
+    r = run.run(
+        client=ssh,
+        logger=logger,
+        args=['foo'],
+        stdin=run.PIPE,
+        wait=False,
+        )
+    r.stdin.write('bar')
+    eq(r.command, 'foo')
+    assert isinstance(r.exitstatus, gevent.event.AsyncResult)
+    eq(r.exitstatus.ready(), False)
+    got = r.exitstatus.get()
+    eq(got, 0)