-from cStringIO import StringIO
+from StringIO import StringIO
+import paramiko
import socket
from mock import MagicMock, patch
ConnectionLostError)
+def set_buffer_contents(buf, contents):
+ buf.seek(0)
+ if isinstance(contents, basestring):
+ buf.write(contents)
+ elif isinstance(contents, (list, tuple)):
+ buf.writelines(contents)
+ else:
+ raise TypeError(
+ "% is is a %s; should be a string, list or tuple" % (
+ contents, type(contents)
+ )
+ )
+ buf.seek(0)
+
+
class TestRun(object):
def setup(self):
self.start_patchers()
'teuthology.orchestra.run.RemoteProcess',
self.m_remote_process,
)
- self.m_channelfile = MagicMock(spec=run.ChannelFile)
- self.m_ssh = MagicMock()
+ self.m_channel = MagicMock(spec=paramiko.Channel)()
+ """
+ self.m_channelfile = MagicMock(wraps=paramiko.ChannelFile)
+ self.m_stdin_buf = self.m_channelfile(self.m_channel())
+ self.m_stdout_buf = self.m_channelfile(self.m_channel())
+ self.m_stderr_buf = self.m_channelfile(self.m_channel())
+ """
+ class M_ChannelFile(StringIO):
+ channel = MagicMock(spec=paramiko.Channel)()
+
+ self.m_channelfile = M_ChannelFile
self.m_stdin_buf = self.m_channelfile()
self.m_stdout_buf = self.m_channelfile()
self.m_stderr_buf = self.m_channelfile()
+ self.m_ssh = MagicMock()
self.m_ssh.exec_command.return_value = (
self.m_stdin_buf,
self.m_stdout_buf,
def test_capture_stdout(self):
output = 'foo\nbar'
-
- def m_copyfileobj(src, dest):
- print output
- print dest
- dest.write(output)
-
+ set_buffer_contents(self.m_stdout_buf, output)
self.m_stdout_buf.channel.recv_exit_status.return_value = 0
- with patch(
- 'teuthology.orchestra.run.shutil.copyfileobj',
- m_copyfileobj,
- ):
- proc = run.run(
- client=self.m_ssh,
- args=['foo', 'bar baz'],
- stdout=StringIO(),
- )
+ stdout = StringIO()
+ proc = run.run(
+ client=self.m_ssh,
+ args=['foo', 'bar baz'],
+ stdout=stdout,
+ )
+ assert proc.stdout is stdout
+ assert proc.stdout.read() == output
assert proc.stdout.getvalue() == output
+ def test_capture_stderr_newline(self):
+ output = 'foo\nbar\n'
+ set_buffer_contents(self.m_stderr_buf, output)
+ self.m_stderr_buf.channel.recv_exit_status.return_value = 0
+ stderr = StringIO()
+ proc = run.run(
+ client=self.m_ssh,
+ args=['foo', 'bar baz'],
+ stderr=stderr,
+ )
+ assert proc.stderr is stderr
+ assert proc.stderr.read() == output
+ assert proc.stderr.getvalue() == output
+
def test_status_bad(self):
self.m_stdout_buf.channel.recv_exit_status.return_value = 42
with raises(CommandFailedError) as exc:
def test_stdout_pipe(self):
self.m_stdout_buf.channel.recv_exit_status.return_value = 0
- self.m_stdout_buf.read.side_effect = [
- 'one', 'two', '',
- ]
+ lines = ['one\n', 'two', '']
+ set_buffer_contents(self.m_stdout_buf, lines)
proc = run.run(
client=self.m_ssh,
args=['foo'],
wait=False
)
assert proc.poll() is None
- assert proc.stdout.read() == 'one'
- assert proc.stdout.read() == 'two'
- assert proc.stdout.read() == ''
+ assert proc.stdout.readline() == lines[0]
+ assert proc.stdout.readline() == lines[1]
+ assert proc.stdout.readline() == lines[2]
code = proc.wait()
assert code == 0
assert proc.exitstatus == 0
def test_stderr_pipe(self):
self.m_stdout_buf.channel.recv_exit_status.return_value = 0
- self.m_stderr_buf.read.side_effect = [
- 'one', 'two', '',
- ]
+ lines = ['one\n', 'two', '']
+ set_buffer_contents(self.m_stderr_buf, lines)
proc = run.run(
client=self.m_ssh,
args=['foo'],
wait=False
)
assert proc.poll() is None
- assert proc.stderr.read() == 'one'
- assert proc.stderr.read() == 'two'
- assert proc.stderr.read() == ''
+ assert proc.stderr.readline() == lines[0]
+ assert proc.stderr.readline() == lines[1]
+ assert proc.stderr.readline() == lines[2]
code = proc.wait()
assert code == 0
assert proc.exitstatus == 0