From: Alfredo Deza Date: Wed, 9 Oct 2013 14:31:55 +0000 (-0400) Subject: process will not pass the actual sudo command X-Git-Tag: 0.0.6~9 X-Git-Url: http://git-server-git.apps.pok.os.sepia.ceph.com/?a=commitdiff_plain;h=83fb51573a7c55265fb172cd8a47e42bb80192d1;p=remoto.git process will not pass the actual sudo command --- 83fb51573a7c55265fb172cd8a47e42bb80192d1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b327d00 --- /dev/null +++ b/.gitignore @@ -0,0 +1,33 @@ +*.py[cod] + +# C extensions +*.so + +# Packages +*.egg +*.egg-info +dist +build +eggs +parts +var +sdist +develop-eggs +.installed.cfg +lib64 + +# Installer logs +pip-log.txt + +# Unit test / coverage reports +.coverage +.tox +nosetests.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject diff --git a/CHANGELOG.rst b/CHANGELOG.rst new file mode 100644 index 0000000..f0ac6fc --- /dev/null +++ b/CHANGELOG.rst @@ -0,0 +1,21 @@ +0.0.5 +----- +* Allow more than one thread to be started in the connection +* log at debug level the name of the function to be remotely + executed + +0.0.4 +----- +* Create a way to execute functions remotely + +0.0.3 +----- +* If the hostname passed in to the connection matches the local + hostname, then do a local connection (not an ssh one) + +0.0.2 +----- +* Allow a context manager for running one-off commands with the connection + object. +* ``process.run`` can now take in a timeout value so that it does not hang in + remote processes diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3589596 --- /dev/null +++ b/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) +Copyright (c) 2013 Alfredo Deza + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000..3c01b12 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include setup.py +include README.rst diff --git a/README.rst b/README.rst new file mode 100644 index 0000000..2bba175 --- /dev/null +++ b/README.rst @@ -0,0 +1,119 @@ +remoto +====== +A very simplistic remote-command-executor using ``ssh`` and Python in the +remote end. + +All the heavy lifting is done by execnet, while this minimal API provides the +bare minimum to handle easy logging and connections from the remote end. + +``remoto`` is a bit opinionated as it was conceived to replace helpers and +remote utilities for ``ceph-deploy`` a tool to run remote commands to configure +and setup the distributed file system Ceph. + + +Example Usage +------------- +The usage aims to be extremely straightforward, with a very minimal set of +helpers and utilities for remote processes and logging output. + +The most basic example will use the ``run`` helper to execute a command on the +remote end. It does require a logging object, which needs to be one that, at +the very least, has both ``error`` and ``debug``. Those are called for +``stderr`` and ``stdout`` respectively. + +This is how it would look with a basic logger passed in:: + + >>> logger = logging.getLogger('hostname') + >>> conn = remoto.Connection('hostname', logger=logger) + >>> run(conn, ['ls', '-a']) + 2013-09-07 15:32:06,662 [hostname][DEBUG] . + 2013-09-07 15:32:06,662 [hostname][DEBUG] .. + 2013-09-07 15:32:06,662 [hostname][DEBUG] .bash_history + 2013-09-07 15:32:06,662 [hostname][DEBUG] .bash_logout + 2013-09-07 15:32:06,662 [hostname][DEBUG] .bashrc + 2013-09-07 15:32:06,662 [hostname][DEBUG] .cache + 2013-09-07 15:32:06,664 [hostname][DEBUG] .profile + 2013-09-07 15:32:06,664 [hostname][DEBUG] .ssh + +The ``run`` helper will display the ``stderr`` and ``stdout`` as ``ERROR`` and +``DEBUG`` respectively. + +For other types of usage (like checking exit status codes, or raising upon +them) ``remoto`` does provide them too. + + +Remote Commands +=============== + +``process.run`` +--------------- +Calling remote commands can be done in a few different ways. The most simple +one is with ``process.run``:: + + >>> from remoto.process import run + >>> from remoto import Connection + >>> logger = my_logging_setup('hostname') + >>> conn = Connection('hostname') + >>> run(conn, ['whoami']) + 2013-09-07 15:32:06,664 [hostname][DEBUG] root + +Note however, that you are not capturing results or information from the remote +end. The intention here is only to be able to run a command and log its output. +It is a *fire and forget* call. + + +``process.check`` +----------------- +This callable, allows the caller to deal with the ``stderr``, ``stdout`` and +exit code. It returns it in a 3 item tuple:: + + >>> from remoto.process import check + >>> check(conn, ['ls', '/nonexistent/path']) + ([], ['ls: cannot access /nonexistent/path: No such file or directory'], 2) + +Note that the ``stdout`` and ``stderr`` items are returned as lists with the ``\n`` +characters removed. + +This is useful if you need to process the information back locally, as opposed +to just firing and forgetting (while logging, like ``process.run``). + + +Remote Functions +================ + +To execute remote functions (ideally) you would need to define them in a module +and add the following to the end of that module:: + + if __name__ == '__channelexec__': + for item in channel: + channel.send(eval(item)) + + +If you had a function in a module named ``foo`` that looks like this:: + + import os + + def listdir(path): + return os.listdir(path) + +To be able to execute that ``listdir`` function remotely you would need to pass +the module to the connection object and then call that function:: + + >>> import foo + >>> conn = Connection('hostname') + >>> remote_foo = conn.import_module(foo) + >>> remote_foo.listdir('.') + ['.bash_logout', + '.profile', + '.veewee_version', + '.lesshst', + 'python', + '.vbox_version', + 'ceph', + '.cache', + '.ssh'] + +Note that functions to be executed remotely **cannot** accept objects as +arguments, just normal Python data structures, like tuples, lists and +dictionaries. Also safe to use are ints and strings. + diff --git a/remoto/__init__.py b/remoto/__init__.py new file mode 100644 index 0000000..a418339 --- /dev/null +++ b/remoto/__init__.py @@ -0,0 +1,4 @@ +from .connection import Connection + + +__version__ = '0.0.5' diff --git a/remoto/connection.py b/remoto/connection.py new file mode 100644 index 0000000..e24a15d --- /dev/null +++ b/remoto/connection.py @@ -0,0 +1,108 @@ +import socket +from .lib import execnet + + +# +# Connection Object +# + +class Connection(object): + + def __init__(self, hostname, logger=None, sudo=False, threads=1): + self.hostname = hostname + self.gateway = self._make_gateway(hostname) + self.logger = logger or FakeRemoteLogger() + self.sudo = sudo + self.channel = None + self.gateway.remote_init_threads(threads) + + def _make_gateway(self, hostname): + if needs_ssh(hostname): + return execnet.makegateway('ssh=%s' % hostname) + return execnet.makegateway() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.exit() + return False + + def execute(self, function, **kw): + return self.gateway.remote_exec(function, **kw) + + def exit(self): + self.gateway.exit() + + def import_module(self, module): + return ModuleExecute(self.gateway, module, self.logger) + + +class ModuleExecute(object): + + def __init__(self, gateway, module, logger=None): + self.channel = gateway.remote_exec(module) + self.module = module + self.logger = logger + + def __getattr__(self, name): + if not hasattr(self.module, name): + msg = "module %s does not have attribute %s" % (str(self.module), name) + raise AttributeError(msg) + docstring = self._get_func_doc(getattr(self.module, name)) + + def wrapper(*args): + arguments = self._convert_args(args) + if docstring: + self.logger.debug(docstring) + self.channel.send("%s(%s)" % (name, arguments)) + return self.channel.receive() + return wrapper + + def _get_func_doc(self, func): + try: + return getattr(func, 'func_doc').strip() + except AttributeError: + return '' + + def _convert_args(self, args): + if args: + if len(args) > 1: + arguments = str(args).rstrip(')').lstrip('(') + else: + arguments = str(args).rstrip(',)').lstrip('(') + else: + arguments = '' + return arguments + + +# +# FIXME this is getting ridiculous +# + +class FakeRemoteLogger: + + def error(self, *a, **kw): + pass + + def debug(self, *a, **kw): + pass + + def info(self, *a, **kw): + pass + + def warning(self, *a, **kw): + pass + + +def needs_ssh(hostname, _socket=None): + """ + Obtains remote hostname of the socket and cuts off the domain part + of its FQDN. + """ + _socket = _socket or socket + local_hostname = _socket.gethostname() + local_short_hostname = local_hostname.split('.')[0] + if local_hostname == hostname or local_short_hostname == hostname: + return False + return True diff --git a/remoto/exc.py b/remoto/exc.py new file mode 100644 index 0000000..9b3acf9 --- /dev/null +++ b/remoto/exc.py @@ -0,0 +1,6 @@ +from .lib import execnet + +HostNotFound = execnet.HostNotFound +RemoteError = execnet.RemoteError +TimeoutError = execnet.TimeoutError +DataFormatError = execnet.DataFormatError diff --git a/remoto/lib/__init__.py b/remoto/lib/__init__.py new file mode 100644 index 0000000..3d64ef0 --- /dev/null +++ b/remoto/lib/__init__.py @@ -0,0 +1,7 @@ +import sys +import os +this_dir = os.path.abspath(os.path.dirname(__file__)) + +if this_dir not in sys.path: + sys.path.insert(0, this_dir) +import execnet diff --git a/remoto/lib/execnet/__init__.py b/remoto/lib/execnet/__init__.py new file mode 100644 index 0000000..45309e7 --- /dev/null +++ b/remoto/lib/execnet/__init__.py @@ -0,0 +1,26 @@ +""" +execnet: pure python lib for connecting to local and remote Python Interpreters. + +(c) 2012, Holger Krekel and others +""" +__version__ = '1.1' + +from . import apipkg + +apipkg.initpkg(__name__, { + 'PopenGateway': '.deprecated:PopenGateway', + 'SocketGateway': '.deprecated:SocketGateway', + 'SshGateway': '.deprecated:SshGateway', + 'makegateway': '.multi:makegateway', + 'HostNotFound': '.gateway_bootstrap:HostNotFound', + 'RemoteError': '.gateway_base:RemoteError', + 'TimeoutError': '.gateway_base:TimeoutError', + 'XSpec': '.xspec:XSpec', + 'Group': '.multi:Group', + 'MultiChannel': '.multi:MultiChannel', + 'RSync': '.rsync:RSync', + 'default_group': '.multi:default_group', + 'dumps': '.gateway_base:dumps', + 'loads': '.gateway_base:loads', + 'DataFormatError': '.gateway_base:DataFormatError', +}) diff --git a/remoto/lib/execnet/apipkg.py b/remoto/lib/execnet/apipkg.py new file mode 100644 index 0000000..a4576c0 --- /dev/null +++ b/remoto/lib/execnet/apipkg.py @@ -0,0 +1,167 @@ +""" +apipkg: control the exported namespace of a python package. + +see http://pypi.python.org/pypi/apipkg + +(c) holger krekel, 2009 - MIT license +""" +import os +import sys +from types import ModuleType + +__version__ = '1.2' + +def initpkg(pkgname, exportdefs, attr=dict()): + """ initialize given package from the export definitions. """ + oldmod = sys.modules.get(pkgname) + d = {} + f = getattr(oldmod, '__file__', None) + if f: + f = os.path.abspath(f) + d['__file__'] = f + if hasattr(oldmod, '__version__'): + d['__version__'] = oldmod.__version__ + if hasattr(oldmod, '__loader__'): + d['__loader__'] = oldmod.__loader__ + if hasattr(oldmod, '__path__'): + d['__path__'] = [os.path.abspath(p) for p in oldmod.__path__] + if '__doc__' not in exportdefs and getattr(oldmod, '__doc__', None): + d['__doc__'] = oldmod.__doc__ + d.update(attr) + if hasattr(oldmod, "__dict__"): + oldmod.__dict__.update(d) + mod = ApiModule(pkgname, exportdefs, implprefix=pkgname, attr=d) + sys.modules[pkgname] = mod + +def importobj(modpath, attrname): + module = __import__(modpath, None, None, ['__doc__']) + if not attrname: + return module + + retval = module + names = attrname.split(".") + for x in names: + retval = getattr(retval, x) + return retval + +class ApiModule(ModuleType): + def __docget(self): + try: + return self.__doc + except AttributeError: + if '__doc__' in self.__map__: + return self.__makeattr('__doc__') + def __docset(self, value): + self.__doc = value + __doc__ = property(__docget, __docset) + + def __init__(self, name, importspec, implprefix=None, attr=None): + self.__name__ = name + self.__all__ = [x for x in importspec if x != '__onfirstaccess__'] + self.__map__ = {} + self.__implprefix__ = implprefix or name + if attr: + for name, val in attr.items(): + #print "setting", self.__name__, name, val + setattr(self, name, val) + for name, importspec in importspec.items(): + if isinstance(importspec, dict): + subname = '%s.%s'%(self.__name__, name) + apimod = ApiModule(subname, importspec, implprefix) + sys.modules[subname] = apimod + setattr(self, name, apimod) + else: + parts = importspec.split(':') + modpath = parts.pop(0) + attrname = parts and parts[0] or "" + if modpath[0] == '.': + modpath = implprefix + modpath + + if not attrname: + subname = '%s.%s'%(self.__name__, name) + apimod = AliasModule(subname, modpath) + sys.modules[subname] = apimod + if '.' not in name: + setattr(self, name, apimod) + else: + self.__map__[name] = (modpath, attrname) + + def __repr__(self): + l = [] + if hasattr(self, '__version__'): + l.append("version=" + repr(self.__version__)) + if hasattr(self, '__file__'): + l.append('from ' + repr(self.__file__)) + if l: + return '' % (self.__name__, " ".join(l)) + return '' % (self.__name__,) + + def __makeattr(self, name): + """lazily compute value for name or raise AttributeError if unknown.""" + #print "makeattr", self.__name__, name + target = None + if '__onfirstaccess__' in self.__map__: + target = self.__map__.pop('__onfirstaccess__') + importobj(*target)() + try: + modpath, attrname = self.__map__[name] + except KeyError: + if target is not None and name != '__onfirstaccess__': + # retry, onfirstaccess might have set attrs + return getattr(self, name) + raise AttributeError(name) + else: + result = importobj(modpath, attrname) + setattr(self, name, result) + try: + del self.__map__[name] + except KeyError: + pass # in a recursive-import situation a double-del can happen + return result + + __getattr__ = __makeattr + + def __dict__(self): + # force all the content of the module to be loaded when __dict__ is read + dictdescr = ModuleType.__dict__['__dict__'] + dict = dictdescr.__get__(self) + if dict is not None: + hasattr(self, 'some') + for name in self.__all__: + try: + self.__makeattr(name) + except AttributeError: + pass + return dict + __dict__ = property(__dict__) + + +def AliasModule(modname, modpath, attrname=None): + mod = [] + + def getmod(): + if not mod: + x = importobj(modpath, None) + if attrname is not None: + x = getattr(x, attrname) + mod.append(x) + return mod[0] + + class AliasModule(ModuleType): + + def __repr__(self): + x = modpath + if attrname: + x += "." + attrname + return '' % (modname, x) + + def __getattribute__(self, name): + return getattr(getmod(), name) + + def __setattr__(self, name, value): + setattr(getmod(), name, value) + + def __delattr__(self, name): + delattr(getmod(), name) + + return AliasModule(modname) diff --git a/remoto/lib/execnet/deprecated.py b/remoto/lib/execnet/deprecated.py new file mode 100644 index 0000000..aef4626 --- /dev/null +++ b/remoto/lib/execnet/deprecated.py @@ -0,0 +1,43 @@ +""" +some deprecated calls + +(c) 2008-2009, Holger Krekel and others +""" +import execnet + +def PopenGateway(python=None): + """ instantiate a gateway to a subprocess + started with the given 'python' executable. + """ + APIWARN("1.0.0b4", "use makegateway('popen')") + spec = execnet.XSpec("popen") + spec.python = python + return execnet.default_group.makegateway(spec) + +def SocketGateway(host, port): + """ This Gateway provides interaction with a remote process + by connecting to a specified socket. On the remote + side you need to manually start a small script + (py/execnet/script/socketserver.py) that accepts + SocketGateway connections or use the experimental + new_remote() method on existing gateways. + """ + APIWARN("1.0.0b4", "use makegateway('socket=host:port')") + spec = execnet.XSpec("socket=%s:%s" %(host, port)) + return execnet.default_group.makegateway(spec) + +def SshGateway(sshaddress, remotepython=None, ssh_config=None): + """ instantiate a remote ssh process with the + given 'sshaddress' and remotepython version. + you may specify an ssh_config file. + """ + APIWARN("1.0.0b4", "use makegateway('ssh=host')") + spec = execnet.XSpec("ssh=%s" % sshaddress) + spec.python = remotepython + spec.ssh_config = ssh_config + return execnet.default_group.makegateway(spec) + +def APIWARN(version, msg, stacklevel=3): + import warnings + Warn = DeprecationWarning("(since version %s) %s" %(version, msg)) + warnings.warn(Warn, stacklevel=stacklevel) diff --git a/remoto/lib/execnet/gateway.py b/remoto/lib/execnet/gateway.py new file mode 100644 index 0000000..b2889d7 --- /dev/null +++ b/remoto/lib/execnet/gateway.py @@ -0,0 +1,211 @@ +""" +gateway code for initiating popen, socket and ssh connections. +(c) 2004-2009, Holger Krekel and others +""" + +import sys, os, inspect, types, linecache +import textwrap +import execnet +from execnet.gateway_base import Message +from execnet.gateway_io import Popen2IOMaster +from execnet import gateway_base +importdir = os.path.dirname(os.path.dirname(execnet.__file__)) + +class Gateway(gateway_base.BaseGateway): + """ Gateway to a local or remote Python Intepreter. """ + + def __init__(self, io, id): + super(Gateway, self).__init__(io=io, id=id, _startcount=1) + self._initreceive() + + @property + def remoteaddress(self): + return self._io.remoteaddress + + def __repr__(self): + """ return string representing gateway type and status. """ + try: + r = (self.hasreceiver() and 'receive-live' or 'not-receiving') + i = len(self._channelfactory.channels()) + except AttributeError: + r = "uninitialized" + i = "no" + return "<%s id=%r %s, %s active channels>" %( + self.__class__.__name__, self.id, r, i) + + def exit(self): + """ trigger gateway exit. Defer waiting for finishing + of receiver-thread and subprocess activity to when + group.terminate() is called. + """ + self._trace("gateway.exit() called") + if self not in self._group: + self._trace("gateway already unregistered with group") + return + self._group._unregister(self) + self._trace("--> sending GATEWAY_TERMINATE") + try: + self._send(Message.GATEWAY_TERMINATE) + self._io.close_write() + except IOError: + v = sys.exc_info()[1] + self._trace("io-error: could not send termination sequence") + self._trace(" exception: %r" % v) + + def reconfigure(self, py2str_as_py3str=True, py3str_as_py2str=False): + """ + set the string coercion for this gateway + the default is to try to convert py2 str as py3 str, + but not to try and convert py3 str to py2 str + """ + self._strconfig = (py2str_as_py3str, py3str_as_py2str) + data = gateway_base.dumps_internal(self._strconfig) + self._send(Message.RECONFIGURE, data=data) + + + def _rinfo(self, update=False): + """ return some sys/env information from remote. """ + if update or not hasattr(self, '_cache_rinfo'): + ch = self.remote_exec(rinfo_source) + self._cache_rinfo = RInfo(ch.receive()) + return self._cache_rinfo + + def hasreceiver(self): + """ return True if gateway is able to receive data. """ + return self._receiverthread.isAlive() # approxmimation + + def remote_status(self): + """ return information object about remote execution status. """ + channel = self.newchannel() + self._send(Message.STATUS, channel.id) + statusdict = channel.receive() + # the other side didn't actually instantiate a channel + # so we just delete the internal id/channel mapping + self._channelfactory._local_close(channel.id) + return RemoteStatus(statusdict) + + def remote_exec(self, source, **kwargs): + """ return channel object and connect it to a remote + execution thread where the given ``source`` executes. + + * ``source`` is a string: execute source string remotely + with a ``channel`` put into the global namespace. + * ``source`` is a pure function: serialize source and + call function with ``**kwargs``, adding a + ``channel`` object to the keyword arguments. + * ``source`` is a pure module: execute source of module + with a ``channel`` in its global namespace + + In all cases the binding ``__name__='__channelexec__'`` + will be available in the global namespace of the remotely + executing code. + """ + call_name = None + if isinstance(source, types.ModuleType): + linecache.updatecache(inspect.getsourcefile(source)) + source = inspect.getsource(source) + elif isinstance(source, types.FunctionType): + call_name = source.__name__ + source = _source_of_function(source) + else: + source = textwrap.dedent(str(source)) + + if call_name is None and kwargs: + raise TypeError("can't pass kwargs to non-function remote_exec") + + channel = self.newchannel() + self._send(Message.CHANNEL_EXEC, + channel.id, + gateway_base.dumps_internal((source, call_name, kwargs))) + return channel + + def remote_init_threads(self, num=None): + """ start up to 'num' threads for subsequent + remote_exec() invocations to allow concurrent + execution. + """ + if hasattr(self, '_remotechannelthread'): + raise IOError("remote threads already running") + from execnet import threadpool + source = inspect.getsource(threadpool) + self._remotechannelthread = self.remote_exec(source) + self._remotechannelthread.send(num) + status = self._remotechannelthread.receive() + assert status == "ok", status + +class RInfo: + def __init__(self, kwargs): + self.__dict__.update(kwargs) + + def __repr__(self): + info = ", ".join(["%s=%s" % item + for item in self.__dict__.items()]) + return "" % info + +RemoteStatus = RInfo + +def rinfo_source(channel): + import sys, os + channel.send(dict( + executable = sys.executable, + version_info = sys.version_info[:5], + platform = sys.platform, + cwd = os.getcwd(), + pid = os.getpid(), + )) + + +def _find_non_builtin_globals(source, codeobj): + try: + import ast + except ImportError: + return None + try: + import __builtin__ + except ImportError: + import builtins as __builtin__ + + vars = dict.fromkeys(codeobj.co_varnames) + all = [] + for node in ast.walk(ast.parse(source)): + if (isinstance(node, ast.Name) and node.id not in vars and + node.id not in __builtin__.__dict__): + all.append(node.id) + return all + + +def _source_of_function(function): + if function.__name__ == '': + raise ValueError("can't evaluate lambda functions'") + #XXX: we dont check before remote instanciation + # if arguments are used propperly + args, varargs, keywords, defaults = inspect.getargspec(function) + if args[0] != 'channel': + raise ValueError('expected first function argument to be `channel`') + + if sys.version_info < (3,0): + closure = function.func_closure + codeobj = function.func_code + else: + closure = function.__closure__ + codeobj = function.__code__ + + if closure is not None: + raise ValueError("functions with closures can't be passed") + + try: + source = inspect.getsource(function) + except IOError: + raise ValueError("can't find source file for %s" % function) + + source = textwrap.dedent(source) # just for inner functions + + used_globals = _find_non_builtin_globals(source, codeobj) + if used_globals: + raise ValueError( + "the use of non-builtin globals isn't supported", + used_globals, + ) + + return source + diff --git a/remoto/lib/execnet/gateway_base.py b/remoto/lib/execnet/gateway_base.py new file mode 100644 index 0000000..f60b65b --- /dev/null +++ b/remoto/lib/execnet/gateway_base.py @@ -0,0 +1,1215 @@ +""" +base execnet gateway code send to the other side for bootstrapping. + +NOTE: aims to be compatible to Python 2.3-3.1, Jython and IronPython + +(C) 2004-2009 Holger Krekel, Armin Rigo, Benjamin Peterson, and others +""" +import sys, os, weakref +import threading, traceback, struct +try: + import queue +except ImportError: + import Queue as queue + +try: + from io import BytesIO +except: + from StringIO import StringIO as BytesIO + +ISPY3 = sys.version_info >= (3, 0) +if ISPY3: + exec("def do_exec(co, loc): exec(co, loc)\n" + "def reraise(cls, val, tb): raise val\n") + unicode = str + _long_type = int + from _thread import interrupt_main +else: + exec("def do_exec(co, loc): exec co in loc\n" + "def reraise(cls, val, tb): raise cls, val, tb\n") + bytes = str + _long_type = long + try: + from thread import interrupt_main + except ImportError: + interrupt_main = None + +sysex = (KeyboardInterrupt, SystemExit) + + +DEBUG = os.environ.get('EXECNET_DEBUG') +pid = os.getpid() +if DEBUG == '2': + def trace(*msg): + try: + line = " ".join(map(str, msg)) + sys.stderr.write("[%s] %s\n" % (pid, line)) + sys.stderr.flush() + except Exception: + pass # nothing we can do, likely interpreter-shutdown +elif DEBUG: + import tempfile, os.path + fn = os.path.join(tempfile.gettempdir(), 'execnet-debug-%d' % pid) + #sys.stderr.write("execnet-debug at %r" %(fn,)) + debugfile = open(fn, 'w') + def trace(*msg): + try: + line = " ".join(map(str, msg)) + debugfile.write(line + "\n") + debugfile.flush() + except Exception: + try: + v = sys.exc_info()[1] + sys.stderr.write( + "[%s] exception during tracing: %r\n" % (pid, v)) + except Exception: + pass # nothing we can do, likely interpreter-shutdown +else: + notrace = trace = lambda *msg: None + +class Popen2IO: + error = (IOError, OSError, EOFError) + + def __init__(self, outfile, infile): + # we need raw byte streams + self.outfile, self.infile = outfile, infile + if sys.platform == "win32": + import msvcrt + try: + msvcrt.setmode(infile.fileno(), os.O_BINARY) + msvcrt.setmode(outfile.fileno(), os.O_BINARY) + except (AttributeError, IOError): + pass + self._read = getattr(infile, "buffer", infile).read + self._write = getattr(outfile, "buffer", outfile).write + + def read(self, numbytes): + """Read exactly 'numbytes' bytes from the pipe. """ + # a file in non-blocking mode may return less bytes, so we loop + buf = bytes() + while numbytes > len(buf): + data = self._read(numbytes-len(buf)) + if not data: + raise EOFError("expected %d bytes, got %d" %(numbytes, len(buf))) + buf += data + return buf + + def write(self, data): + """write out all data bytes. """ + assert isinstance(data, bytes) + self._write(data) + self.outfile.flush() + + def close_read(self): + self.infile.close() + + def close_write(self): + self.outfile.close() + +class Message: + """ encapsulates Messages and their wire protocol. """ + _types = [] + + def __init__(self, msgcode, channelid=0, data=''): + self.msgcode = msgcode + self.channelid = channelid + self.data = data + + @staticmethod + def from_io(io): + try: + header = io.read(9) # type 1, channel 4, payload 4 + except EOFError: + e = sys.exc_info()[1] + raise EOFError('couldnt load message header, ' + e.args[0]) + msgtype, channel, payload = struct.unpack('!bii', header) + return Message(msgtype, channel, io.read(payload)) + + def to_io(self, io): + header = struct.pack('!bii', self.msgcode, self.channelid, len(self.data)) + io.write(header+self.data) + + def received(self, gateway): + self._types[self.msgcode](self, gateway) + + def __repr__(self): + class FakeChannel(object): + _strconfig = False, False # never transform, never fail + def __init__(self, id): + self.id = id + def __repr__(self): + return '' % self.id + FakeChannel.new = FakeChannel + FakeChannel.gateway = FakeChannel + name = self._types[self.msgcode].__name__.upper() + try: + data = loads_internal(self.data, FakeChannel) + except LoadError: + data = self.data + r = repr(data) + if len(r) > 90: + return "" %(name, + self.channelid, len(r)) + else: + return "" %(name, + self.channelid, r) + +def _setupmessages(): + def status(message, gateway): + # we use the channelid to send back information + # but don't instantiate a channel object + active_channels = gateway._channelfactory.channels() + numexec = 0 + for ch in active_channels: + if getattr(ch, '_executing', False): + numexec += 1 + d = {'execqsize': gateway._execqueue.qsize(), + 'numchannels': len(active_channels), + 'numexecuting': numexec + } + gateway._send(Message.CHANNEL_DATA, message.channelid, dumps_internal(d)) + + def channel_exec(message, gateway): + channel = gateway._channelfactory.new(message.channelid) + gateway._local_schedulexec(channel=channel,sourcetask=message.data) + + def channel_data(message, gateway): + gateway._channelfactory._local_receive(message.channelid, message.data) + + def channel_close(message, gateway): + gateway._channelfactory._local_close(message.channelid) + + def channel_close_error(message, gateway): + remote_error = RemoteError(loads_internal(message.data)) + gateway._channelfactory._local_close(message.channelid, remote_error) + + def channel_last_message(message, gateway): + gateway._channelfactory._local_close(message.channelid, sendonly=True) + + def gateway_terminate(message, gateway): + gateway._terminate_execution() + raise SystemExit(0) + + def reconfigure(message, gateway): + if message.channelid == 0: + target = gateway + else: + target = gateway._channelfactory.new(message.channelid) + target._strconfig = loads_internal(message.data, gateway) + + types = [ + status, reconfigure, gateway_terminate, + channel_exec, channel_data, channel_close, + channel_close_error, channel_last_message, + ] + for i, handler in enumerate(types): + Message._types.append(handler) + setattr(Message, handler.__name__.upper(), i) + +_setupmessages() + +def geterrortext(excinfo, + format_exception=traceback.format_exception, sysex=sysex): + try: + l = format_exception(*excinfo) + errortext = "".join(l) + except sysex: + raise + except: + errortext = '%s: %s' % (excinfo[0].__name__, + excinfo[1]) + return errortext + +class RemoteError(Exception): + """ Exception containing a stringified error from the other side. """ + def __init__(self, formatted): + self.formatted = formatted + Exception.__init__(self) + + def __str__(self): + return self.formatted + + def __repr__(self): + return "%s: %s" %(self.__class__.__name__, self.formatted) + + def warn(self): + if self.formatted != INTERRUPT_TEXT: + # XXX do this better + sys.stderr.write("Warning: unhandled %r\n" % (self,)) + +class TimeoutError(IOError): + """ Exception indicating that a timeout was reached. """ + + +NO_ENDMARKER_WANTED = object() + +class Channel(object): + """Communication channel between two Python Interpreter execution points.""" + RemoteError = RemoteError + TimeoutError = TimeoutError + _INTERNALWAKEUP = 1000 + _executing = False + + def __init__(self, gateway, id): + assert isinstance(id, int) + self.gateway = gateway + #XXX: defaults copied from Unserializer + self._strconfig = getattr(gateway, '_strconfig', (True, False)) + self.id = id + self._items = queue.Queue() + self._closed = False + self._receiveclosed = threading.Event() + self._remoteerrors = [] + + def _trace(self, *msg): + self.gateway._trace(self.id, *msg) + + def setcallback(self, callback, endmarker=NO_ENDMARKER_WANTED): + """ set a callback function for receiving items. + + All already queued items will immediately trigger the callback. + Afterwards the callback will execute in the receiver thread + for each received data item and calls to ``receive()`` will + raise an error. + If an endmarker is specified the callback will eventually + be called with the endmarker when the channel closes. + """ + _callbacks = self.gateway._channelfactory._callbacks + _receivelock = self.gateway._receivelock + _receivelock.acquire() + try: + if self._items is None: + raise IOError("%r has callback already registered" %(self,)) + items = self._items + self._items = None + while 1: + try: + olditem = items.get(block=False) + except queue.Empty: + if not (self._closed or self._receiveclosed.isSet()): + _callbacks[self.id] = ( + callback, + endmarker, + self._strconfig, + ) + break + else: + if olditem is ENDMARKER: + items.put(olditem) # for other receivers + if endmarker is not NO_ENDMARKER_WANTED: + callback(endmarker) + break + else: + callback(olditem) + finally: + _receivelock.release() + + def __repr__(self): + flag = self.isclosed() and "closed" or "open" + return "" % (self.id, flag) + + def __del__(self): + if self.gateway is None: # can be None in tests + return + self._trace("channel.__del__") + # no multithreading issues here, because we have the last ref to 'self' + if self._closed: + # state transition "closed" --> "deleted" + for error in self._remoteerrors: + error.warn() + elif self._receiveclosed.isSet(): + # state transition "sendonly" --> "deleted" + # the remote channel is already in "deleted" state, nothing to do + pass + else: + # state transition "opened" --> "deleted" + if self._items is None: # has_callback + msgcode = Message.CHANNEL_LAST_MESSAGE + else: + msgcode = Message.CHANNEL_CLOSE + try: + self.gateway._send(msgcode, self.id) + except (IOError, ValueError): # ignore problems with sending + pass + + def _getremoteerror(self): + try: + return self._remoteerrors.pop(0) + except IndexError: + try: + return self.gateway._error + except AttributeError: + pass + return None + + # + # public API for channel objects + # + def isclosed(self): + """ return True if the channel is closed. A closed + channel may still hold items. + """ + return self._closed + + def makefile(self, mode='w', proxyclose=False): + """ return a file-like object. + mode can be 'w' or 'r' for writeable/readable files. + if proxyclose is true file.close() will also close the channel. + """ + if mode == "w": + return ChannelFileWrite(channel=self, proxyclose=proxyclose) + elif mode == "r": + return ChannelFileRead(channel=self, proxyclose=proxyclose) + raise ValueError("mode %r not availabe" %(mode,)) + + def close(self, error=None): + """ close down this channel with an optional error message. + Note that closing of a channel tied to remote_exec happens + automatically at the end of execution and cannot be done explicitely. + """ + if self._executing: + raise IOError("cannot explicitly close channel within remote_exec") + if self._closed: + self.gateway._trace(self, "ignoring redundant call to close()") + if not self._closed: + # state transition "opened/sendonly" --> "closed" + # threads warning: the channel might be closed under our feet, + # but it's never damaging to send too many CHANNEL_CLOSE messages + # however, if the other side triggered a close already, we + # do not send back a closed message. + if not self._receiveclosed.isSet(): + put = self.gateway._send + if error is not None: + put(Message.CHANNEL_CLOSE_ERROR, self.id, dumps_internal(error)) + else: + put(Message.CHANNEL_CLOSE, self.id) + self._trace("sent channel close message") + if isinstance(error, RemoteError): + self._remoteerrors.append(error) + self._closed = True # --> "closed" + self._receiveclosed.set() + queue = self._items + if queue is not None: + queue.put(ENDMARKER) + self.gateway._channelfactory._no_longer_opened(self.id) + + def waitclose(self, timeout=None): + """ wait until this channel is closed (or the remote side + otherwise signalled that no more data was being sent). + The channel may still hold receiveable items, but not receive + any more after waitclose() has returned. Exceptions from executing + code on the other side are reraised as local channel.RemoteErrors. + EOFError is raised if the reading-connection was prematurely closed, + which often indicates a dying process. + self.TimeoutError is raised after the specified number of seconds + (default is None, i.e. wait indefinitely). + """ + self._receiveclosed.wait(timeout=timeout) # wait for non-"opened" state + if not self._receiveclosed.isSet(): + raise self.TimeoutError("Timeout after %r seconds" % timeout) + error = self._getremoteerror() + if error: + raise error + + def send(self, item): + """sends the given item to the other side of the channel, + possibly blocking if the sender queue is full. + The item must be a simple python type and will be + copied to the other side by value. IOError is + raised if the write pipe was prematurely closed. + """ + if self.isclosed(): + raise IOError("cannot send to %r" %(self,)) + self.gateway._send(Message.CHANNEL_DATA, self.id, dumps_internal(item)) + + def receive(self, timeout=-1): + """receive a data item that was sent from the other side. + timeout: -1 [default] blocked waiting, but wake up periodically + to let CTRL-C through. A positive number indicates the + number of seconds after which a channel.TimeoutError exception + will be raised if no item was received. + Note that exceptions from the remotely executing code will be + reraised as channel.RemoteError exceptions containing + a textual representation of the remote traceback. + """ + itemqueue = self._items + if itemqueue is None: + raise IOError("cannot receive(), channel has receiver callback") + if timeout < 0: + internal_timeout = self._INTERNALWAKEUP + else: + internal_timeout = timeout + + while 1: + try: + x = itemqueue.get(timeout=internal_timeout) + break + except queue.Empty: + if timeout < 0: + continue + raise self.TimeoutError("no item after %r seconds" %(timeout)) + if x is ENDMARKER: + itemqueue.put(x) # for other receivers + raise self._getremoteerror() or EOFError() + else: + return x + + def __iter__(self): + return self + + def next(self): + try: + return self.receive() + except EOFError: + raise StopIteration + __next__ = next + + + def reconfigure(self, py2str_as_py3str=True, py3str_as_py2str=False): + """ + set the string coercion for this channel + the default is to try to convert py2 str as py3 str, + but not to try and convert py3 str to py2 str + """ + self._strconfig = (py2str_as_py3str, py3str_as_py2str) + data = dumps_internal(self._strconfig) + self.gateway._send(Message.RECONFIGURE, self.id, data=data) + +ENDMARKER = object() +INTERRUPT_TEXT = "keyboard-interrupted" + +class ChannelFactory(object): + def __init__(self, gateway, startcount=1): + self._channels = weakref.WeakValueDictionary() + self._callbacks = {} + self._writelock = threading.Lock() + self.gateway = gateway + self.count = startcount + self.finished = False + self._list = list # needed during interp-shutdown + + def new(self, id=None): + """ create a new Channel with 'id' (or create new id if None). """ + self._writelock.acquire() + try: + if self.finished: + raise IOError("connexion already closed: %s" % (self.gateway,)) + if id is None: + id = self.count + self.count += 2 + try: + channel = self._channels[id] + except KeyError: + channel = self._channels[id] = Channel(self.gateway, id) + return channel + finally: + self._writelock.release() + + def channels(self): + return self._list(self._channels.values()) + + # + # internal methods, called from the receiver thread + # + def _no_longer_opened(self, id): + try: + del self._channels[id] + except KeyError: + pass + try: + callback, endmarker, strconfig = self._callbacks.pop(id) + except KeyError: + pass + else: + if endmarker is not NO_ENDMARKER_WANTED: + callback(endmarker) + + def _local_close(self, id, remoteerror=None, sendonly=False): + channel = self._channels.get(id) + if channel is None: + # channel already in "deleted" state + if remoteerror: + remoteerror.warn() + self._no_longer_opened(id) + else: + # state transition to "closed" state + if remoteerror: + channel._remoteerrors.append(remoteerror) + queue = channel._items + if queue is not None: + queue.put(ENDMARKER) + self._no_longer_opened(id) + if not sendonly: # otherwise #--> "sendonly" + channel._closed = True # --> "closed" + channel._receiveclosed.set() + + def _local_receive(self, id, data): + # executes in receiver thread + try: + callback, endmarker, strconfig= self._callbacks[id] + channel = self._channels.get(id) + except KeyError: + channel = self._channels.get(id) + queue = channel and channel._items + if queue is None: + pass # drop data + else: + queue.put(loads_internal(data, channel)) + else: + try: + data = loads_internal(data, channel, strconfig) + callback(data) # even if channel may be already closed + except KeyboardInterrupt: + raise + except: + excinfo = sys.exc_info() + self.gateway._trace("exception during callback: %s" % excinfo[1]) + errortext = self.gateway._geterrortext(excinfo) + self.gateway._send(Message.CHANNEL_CLOSE_ERROR, id, dumps_internal(errortext)) + self._local_close(id, errortext) + + def _finished_receiving(self): + self._writelock.acquire() + try: + self.finished = True + finally: + self._writelock.release() + for id in self._list(self._channels): + self._local_close(id, sendonly=True) + for id in self._list(self._callbacks): + self._no_longer_opened(id) + +class ChannelFile(object): + def __init__(self, channel, proxyclose=True): + self.channel = channel + self._proxyclose = proxyclose + + def isatty(self): + return False + + def close(self): + if self._proxyclose: + self.channel.close() + + def __repr__(self): + state = self.channel.isclosed() and 'closed' or 'open' + return '' %(self.channel.id, state) + +class ChannelFileWrite(ChannelFile): + def write(self, out): + self.channel.send(out) + + def flush(self): + pass + +class ChannelFileRead(ChannelFile): + def __init__(self, channel, proxyclose=True): + super(ChannelFileRead, self).__init__(channel, proxyclose) + self._buffer = None + + def read(self, n): + try: + if self._buffer is None: + self._buffer = self.channel.receive() + while len(self._buffer) < n: + self._buffer += self.channel.receive() + except EOFError: + self.close() + if self._buffer is None: + ret = "" + else: + ret = self._buffer[:n] + self._buffer = self._buffer[n:] + return ret + + def readline(self): + if self._buffer is not None: + i = self._buffer.find("\n") + if i != -1: + return self.read(i+1) + line = self.read(len(self._buffer)+1) + else: + line = self.read(1) + while line and line[-1] != "\n": + c = self.read(1) + if not c: + break + line += c + return line + +class BaseGateway(object): + exc_info = sys.exc_info + _sysex = sysex + id = "" + + class _StopExecLoop(Exception): + pass + + def __init__(self, io, id, _startcount=2): + self._io = io + self.id = id + self._strconfig = Unserializer.py2str_as_py3str, Unserializer.py3str_as_py2str + self._channelfactory = ChannelFactory(self, _startcount) + self._receivelock = threading.RLock() + # globals may be NONE at process-termination + self.__trace = trace + self._geterrortext = geterrortext + + def _trace(self, *msg): + self.__trace(self.id, *msg) + + def _initreceive(self): + self._receiverthread = threading.Thread(name="receiver", + target=self._thread_receiver) + self._receiverthread.setDaemon(1) + self._receiverthread.start() + + def _thread_receiver(self): + self._trace("RECEIVERTHREAD: starting to run") + eof = False + io = self._io + try: + try: + while 1: + msg = Message.from_io(io) + self._trace("received", msg) + _receivelock = self._receivelock + _receivelock.acquire() + try: + msg.received(self) + del msg + finally: + _receivelock.release() + except self._sysex: + self._trace("RECEIVERTHREAD: doing io.close_read()") + self._io.close_read() + except EOFError: + self._trace("RECEIVERTHREAD: got EOFError") + self._trace("RECEIVERTHREAD: traceback was: ", + self._geterrortext(self.exc_info())) + self._error = self.exc_info()[1] + eof = True + except: + self._trace("RECEIVERTHREAD", self._geterrortext(self.exc_info())) + finally: + try: + self._trace('RECEIVERTHREAD', 'entering finalization') + if eof: + self._terminate_execution() + self._channelfactory._finished_receiving() + self._trace('RECEIVERTHREAD', 'leaving finalization') + except: + pass # XXX be silent at interp-shutdown + + def _terminate_execution(self): + pass + + def _send(self, msgcode, channelid=0, data=bytes()): + message = Message(msgcode, channelid, data) + try: + message.to_io(self._io) + self._trace('sent', message) + except (IOError, ValueError): + e = sys.exc_info()[1] + self._trace('failed to send', message, e) + raise + + + def _local_schedulexec(self, channel, sourcetask): + channel.close("execution disallowed") + + # _____________________________________________________________________ + # + # High Level Interface + # _____________________________________________________________________ + # + def newchannel(self): + """ return a new independent channel. """ + return self._channelfactory.new() + + def join(self, timeout=None): + """ Wait for receiverthread to terminate. """ + current = threading.currentThread() + if self._receiverthread.isAlive(): + self._trace("joining receiver thread") + self._receiverthread.join(timeout) + else: + self._trace("gateway.join() called while receiverthread " + "already finished") + +class SlaveGateway(BaseGateway): + def _local_schedulexec(self, channel, sourcetask): + sourcetask = loads_internal(sourcetask) + self._execqueue.put((channel, sourcetask)) + + def _terminate_execution(self): + # called from receiverthread + self._trace("putting None to execqueue") + self._execqueue.put(None) + if interrupt_main: + self._trace("calling interrupt_main()") + interrupt_main() + self._execfinished.wait(10.0) + if not self._execfinished.isSet(): + self._trace("execution did not finish in 10 secs, calling os._exit()") + os._exit(1) + + def serve(self, joining=True): + try: + try: + self._execqueue = queue.Queue() + self._execfinished = threading.Event() + self._initreceive() + while 1: + item = self._execqueue.get() + if item is None: + break + try: + self.executetask(item) + except self._StopExecLoop: + break + finally: + self._execfinished.set() + self._trace("io.close_write()") + self._io.close_write() + self._trace("slavegateway.serve finished") + if joining: + self.join() + except KeyboardInterrupt: + # in the slave we can't really do anything sensible + self._trace("swallowing keyboardinterrupt in main-thread") + + def executetask(self, item): + try: + channel, (source, call_name, kwargs) = item + if not ISPY3 and kwargs: + # some python2 versions do not accept unicode keyword params + # note: Unserializer generally turns py2-str to py3-str objects + newkwargs = {} + for name, value in kwargs.items(): + if isinstance(name, unicode): + name = name.encode('ascii') + newkwargs[name] = value + kwargs = newkwargs + loc = {'channel' : channel, '__name__': '__channelexec__'} + self._trace("execution starts[%s]: %s" % + (channel.id, repr(source)[:50])) + channel._executing = True + try: + co = compile(source+'\n', '', 'exec') + do_exec(co, loc) + if call_name: + self._trace('calling %s(**%60r)' % (call_name, kwargs)) + function = loc[call_name] + function(channel, **kwargs) + finally: + channel._executing = False + self._trace("execution finished") + except self._StopExecLoop: + channel.close() + raise + except KeyboardInterrupt: + channel.close(INTERRUPT_TEXT) + raise + except: + excinfo = self.exc_info() + self._trace("got exception: %s" % (excinfo[1],)) + errortext = self._geterrortext(excinfo) + channel.close(errortext) + else: + channel.close() + +# +# Cross-Python pickling code, tested from test_serializer.py +# + +class DataFormatError(Exception): + pass + +class DumpError(DataFormatError): + """Error while serializing an object.""" + +class LoadError(DataFormatError): + """Error while unserializing an object.""" + +if ISPY3: + def bchr(n): + return bytes([n]) +else: + bchr = chr + +DUMPFORMAT_VERSION = bchr(1) + +FOUR_BYTE_INT_MAX = 2147483647 + +FLOAT_FORMAT = "!d" +FLOAT_FORMAT_SIZE = struct.calcsize(FLOAT_FORMAT) + +class _Stop(Exception): + pass + +class Unserializer(object): + num2func = {} # is filled after this class definition + py2str_as_py3str = True # True + py3str_as_py2str = False # false means py2 will get unicode + + def __init__(self, stream, channel_or_gateway=None, strconfig=None): + gateway = getattr(channel_or_gateway, 'gateway', channel_or_gateway) + strconfig = getattr(channel_or_gateway, '_strconfig', strconfig) + if strconfig: + self.py2str_as_py3str, self.py3str_as_py2str = strconfig + self.stream = stream + self.channelfactory = getattr(gateway, '_channelfactory', gateway) + + def load(self, versioned=False): + if versioned: + ver = self.stream.read(1) + if ver != DUMPFORMAT_VERSION: + raise LoadError("wrong dumpformat version") + self.stack = [] + try: + while True: + opcode = self.stream.read(1) + if not opcode: + raise EOFError + try: + loader = self.num2func[opcode] + except KeyError: + raise LoadError("unkown opcode %r - " + "wire protocol corruption?" % (opcode,)) + loader(self) + except _Stop: + if len(self.stack) != 1: + raise LoadError("internal unserialization error") + return self.stack.pop(0) + else: + raise LoadError("didn't get STOP") + + def load_none(self): + self.stack.append(None) + + def load_true(self): + self.stack.append(True) + + def load_false(self): + self.stack.append(False) + + def load_int(self): + i = self._read_int4() + self.stack.append(i) + + def load_longint(self): + s = self._read_byte_string() + self.stack.append(int(s)) + + if ISPY3: + load_long = load_int + load_longlong = load_longint + else: + def load_long(self): + i = self._read_int4() + self.stack.append(long(i)) + + def load_longlong(self): + l = self._read_byte_string() + self.stack.append(long(l)) + + def load_float(self): + binary = self.stream.read(FLOAT_FORMAT_SIZE) + self.stack.append(struct.unpack(FLOAT_FORMAT, binary)[0]) + + def _read_int4(self): + return struct.unpack("!i", self.stream.read(4))[0] + + def _read_byte_string(self): + length = self._read_int4() + as_bytes = self.stream.read(length) + return as_bytes + + def load_py3string(self): + as_bytes = self._read_byte_string() + if not ISPY3 and self.py3str_as_py2str: + # XXX Should we try to decode into latin-1? + self.stack.append(as_bytes) + else: + self.stack.append(as_bytes.decode("utf-8")) + + def load_py2string(self): + as_bytes = self._read_byte_string() + if ISPY3 and self.py2str_as_py3str: + s = as_bytes.decode("latin-1") + else: + s = as_bytes + self.stack.append(s) + + def load_bytes(self): + s = self._read_byte_string() + self.stack.append(s) + + def load_unicode(self): + self.stack.append(self._read_byte_string().decode("utf-8")) + + def load_newlist(self): + length = self._read_int4() + self.stack.append([None] * length) + + def load_setitem(self): + if len(self.stack) < 3: + raise LoadError("not enough items for setitem") + value = self.stack.pop() + key = self.stack.pop() + self.stack[-1][key] = value + + def load_newdict(self): + self.stack.append({}) + + def _load_collection(self, type_): + length = self._read_int4() + if length: + res = type_(self.stack[-length:]) + del self.stack[-length:] + self.stack.append(res) + else: + self.stack.append(type_()) + + def load_buildtuple(self): + self._load_collection(tuple) + + def load_set(self): + self._load_collection(set) + + def load_frozenset(self): + self._load_collection(frozenset) + + def load_stop(self): + raise _Stop + + def load_channel(self): + id = self._read_int4() + newchannel = self.channelfactory.new(id) + self.stack.append(newchannel) + +# automatically build opcodes and byte-encoding + +class opcode: + """ container for name -> num mappings. """ + +def _buildopcodes(): + l = [] + for name, func in Unserializer.__dict__.items(): + if name.startswith("load_"): + opname = name[5:].upper() + l.append((opname, func)) + l.sort() + for i,(opname, func) in enumerate(l): + assert i < 26, "xxx" + i = bchr(64+i) + Unserializer.num2func[i] = func + setattr(opcode, opname, i) + +_buildopcodes() + +def dumps(obj): + """ return a serialized bytestring of the given obj. + + The obj and all contained objects must be of a builtin + python type (so nested dicts, sets, etc. are all ok but + not user-level instances). + """ + return _Serializer().save(obj, versioned=True) + +def loads(bytestring, py2str_as_py3str=False, py3str_as_py2str=False): + """ return the object as deserialized from the given bytestring. + + py2str_as_py3str: if true then string (str) objects previously + dumped on Python2 will be loaded as Python3 + strings which really are text objects. + py3str_as_py2str: if true then string (str) objects previously + dumped on Python3 will be loaded as Python2 + strings instead of unicode objects. + + if the bytestring was dumped with an incompatible protocol + version or if the bytestring is corrupted, the + ``execnet.DataFormatError`` will be raised. + """ + strconfig=(py2str_as_py3str, py3str_as_py2str) + io = BytesIO(bytestring) + return Unserializer(io, strconfig=strconfig).load(versioned=True) + +def loads_internal(bytestring, channelfactory=None, strconfig=None): + io = BytesIO(bytestring) + return Unserializer(io, channelfactory, strconfig).load() + +def dumps_internal(obj): + return _Serializer().save(obj) + + +class _Serializer(object): + _dispatch = {} + + def __init__(self): + self._streamlist = [] + + def _write(self, data): + self._streamlist.append(data) + + def save(self, obj, versioned=False): + # calling here is not re-entrant but multiple instances + # may write to the same stream because of the common platform + # atomic-write guaruantee (concurrent writes each happen atomicly) + if versioned: + self._write(DUMPFORMAT_VERSION) + self._save(obj) + self._write(opcode.STOP) + s = type(self._streamlist[0])().join(self._streamlist) + return s + + def _save(self, obj): + tp = type(obj) + try: + dispatch = self._dispatch[tp] + except KeyError: + methodname = 'save_' + tp.__name__ + meth = getattr(self.__class__, methodname, None) + if meth is None: + raise DumpError("can't serialize %s" % (tp,)) + dispatch = self._dispatch[tp] = meth + dispatch(self, obj) + + def save_NoneType(self, non): + self._write(opcode.NONE) + + def save_bool(self, boolean): + if boolean: + self._write(opcode.TRUE) + else: + self._write(opcode.FALSE) + + def save_bytes(self, bytes_): + self._write(opcode.BYTES) + self._write_byte_sequence(bytes_) + + if ISPY3: + def save_str(self, s): + self._write(opcode.PY3STRING) + self._write_unicode_string(s) + else: + def save_str(self, s): + self._write(opcode.PY2STRING) + self._write_byte_sequence(s) + + def save_unicode(self, s): + self._write(opcode.UNICODE) + self._write_unicode_string(s) + + def _write_unicode_string(self, s): + try: + as_bytes = s.encode("utf-8") + except UnicodeEncodeError: + raise DumpError("strings must be utf-8 encodable") + self._write_byte_sequence(as_bytes) + + def _write_byte_sequence(self, bytes_): + self._write_int4(len(bytes_), "string is too long") + self._write(bytes_) + + def _save_integral(self, i, short_op, long_op): + if i <= FOUR_BYTE_INT_MAX: + self._write(short_op) + self._write_int4(i) + else: + self._write(long_op) + self._write_byte_sequence(str(i).rstrip("L").encode("ascii")) + + def save_int(self, i): + self._save_integral(i, opcode.INT, opcode.LONGINT) + + def save_long(self, l): + self._save_integral(l, opcode.LONG, opcode.LONGLONG) + + def save_float(self, flt): + self._write(opcode.FLOAT) + self._write(struct.pack(FLOAT_FORMAT, flt)) + + def _write_int4(self, i, error="int must be less than %i" % + (FOUR_BYTE_INT_MAX,)): + if i > FOUR_BYTE_INT_MAX: + raise DumpError(error) + self._write(struct.pack("!i", i)) + + def save_list(self, L): + self._write(opcode.NEWLIST) + self._write_int4(len(L), "list is too long") + for i, item in enumerate(L): + self._write_setitem(i, item) + + def _write_setitem(self, key, value): + self._save(key) + self._save(value) + self._write(opcode.SETITEM) + + def save_dict(self, d): + self._write(opcode.NEWDICT) + for key, value in d.items(): + self._write_setitem(key, value) + + def save_tuple(self, tup): + for item in tup: + self._save(item) + self._write(opcode.BUILDTUPLE) + self._write_int4(len(tup), "tuple is too long") + + def _write_set(self, s, op): + for item in s: + self._save(item) + self._write(op) + self._write_int4(len(s), "set is too long") + + def save_set(self, s): + self._write_set(s, opcode.SET) + + def save_frozenset(self, s): + self._write_set(s, opcode.FROZENSET) + + def save_Channel(self, channel): + self._write(opcode.CHANNEL) + self._write_int4(channel.id) + +def init_popen_io(): + if not hasattr(os, 'dup'): # jython + io = Popen2IO(sys.stdout, sys.stdin) + import tempfile + sys.stdin = tempfile.TemporaryFile('r') + sys.stdout = tempfile.TemporaryFile('w') + else: + try: + devnull = os.devnull + except AttributeError: + if os.name == 'nt': + devnull = 'NUL' + else: + devnull = '/dev/null' + # stdin + stdin = os.fdopen(os.dup(0), 'r', 1) + fd = os.open(devnull, os.O_RDONLY) + os.dup2(fd, 0) + os.close(fd) + + # stdout + stdout = os.fdopen(os.dup(1), 'w', 1) + fd = os.open(devnull, os.O_WRONLY) + os.dup2(fd, 1) + + # stderr for win32 + if os.name == 'nt': + sys.stderr = os.fdopen(os.dup(2), 'w', 1) + os.dup2(fd, 2) + os.close(fd) + io = Popen2IO(stdout, stdin) + sys.stdin = os.fdopen(0, 'r', 1) + sys.stdout = os.fdopen(1, 'w', 1) + return io + +def serve(io, id): + trace("creating slavegateway on %r" %(io,)) + SlaveGateway(io=io, id=id, _startcount=2).serve() diff --git a/remoto/lib/execnet/gateway_bootstrap.py b/remoto/lib/execnet/gateway_bootstrap.py new file mode 100644 index 0000000..ed125b9 --- /dev/null +++ b/remoto/lib/execnet/gateway_bootstrap.py @@ -0,0 +1,83 @@ +""" +code to initialize the remote side of a gateway once the io is created +""" +import os +import inspect +import execnet +from execnet import gateway_base +from execnet.gateway import Gateway +importdir = os.path.dirname(os.path.dirname(execnet.__file__)) + + +class HostNotFound(Exception): + pass + + +def bootstrap_popen(io, spec): + sendexec(io, + "import sys", + "sys.path.insert(0, %r)" % importdir, + "from execnet.gateway_base import serve, init_popen_io", + "sys.stdout.write('1')", + "sys.stdout.flush()", + "serve(init_popen_io(), id='%s-slave')" % spec.id, + ) + s = io.read(1) + assert s == "1".encode('ascii') + + +def bootstrap_ssh(io, spec): + try: + sendexec(io, + inspect.getsource(gateway_base), + 'io = init_popen_io()', + "io.write('1'.encode('ascii'))", + "serve(io, id='%s-slave')" % spec.id, + ) + s = io.read(1) + assert s == "1".encode('ascii') + except EOFError: + ret = io.wait() + if ret == 255: + raise HostNotFound(io.remoteaddress) + + +def bootstrap_socket(io, id): + #XXX: switch to spec + from execnet.gateway_socket import SocketIO + + sendexec(io, + inspect.getsource(gateway_base), + 'import socket', + inspect.getsource(SocketIO), + "io = SocketIO(clientsock)", + "io.write('1'.encode('ascii'))", + "serve(io, id='%s-slave')" % id, + ) + s = io.read(1) + assert s == "1".encode('ascii') + + +def sendexec(io, *sources): + source = "\n".join(sources) + io.write((repr(source)+ "\n").encode('ascii')) + + +def bootstrap(io, spec): + if spec.popen: + bootstrap_popen(io, spec) + elif spec.ssh: + bootstrap_ssh(io, spec) + elif spec.socket: + bootstrap_socket(io, spec) + else: + raise ValueError('unknown gateway type, cant bootstrap') + gw = Gateway(io, spec.id) + if hasattr(io, 'popen'): + # fix for jython 2.5.1 + if io.popen.pid is None: + io.popen.pid = gw.remote_exec( + "import os; channel.send(os.getpid())").receive() + return gw + + diff --git a/remoto/lib/execnet/gateway_io.py b/remoto/lib/execnet/gateway_io.py new file mode 100644 index 0000000..d8b7111 --- /dev/null +++ b/remoto/lib/execnet/gateway_io.py @@ -0,0 +1,174 @@ +""" +execnet io initialization code + +creates io instances used for gateway io +""" +import os +import sys +from subprocess import Popen, PIPE + +try: + from execnet.gateway_base import Popen2IO, Message +except ImportError: + from __main__ import Popen2IO, Message + +class Popen2IOMaster(Popen2IO): + def __init__(self, args): + self.popen = p = Popen(args, stdin=PIPE, stdout=PIPE) + Popen2IO.__init__(self, p.stdin, p.stdout) + + def wait(self): + try: + return self.popen.wait() + except OSError: + pass # subprocess probably dead already + + def kill(self): + killpopen(self.popen) + +def killpopen(popen): + try: + if hasattr(popen, 'kill'): + popen.kill() + else: + killpid(popen.pid) + except EnvironmentError: + sys.stderr.write("ERROR killing: %s\n" %(sys.exc_info()[1])) + sys.stderr.flush() + +def killpid(pid): + if hasattr(os, 'kill'): + os.kill(pid, 15) + elif sys.platform == "win32" or getattr(os, '_name', None) == 'nt': + try: + import ctypes + except ImportError: + import subprocess + # T: treekill, F: Force + cmd = ("taskkill /T /F /PID %d" %(pid)).split() + ret = subprocess.call(cmd) + if ret != 0: + raise EnvironmentError("taskkill returned %r" %(ret,)) + else: + PROCESS_TERMINATE = 1 + handle = ctypes.windll.kernel32.OpenProcess( + PROCESS_TERMINATE, False, pid) + ctypes.windll.kernel32.TerminateProcess(handle, -1) + ctypes.windll.kernel32.CloseHandle(handle) + else: + raise EnvironmentError("no method to kill %s" %(pid,)) + + + +popen_bootstrapline = "import sys;exec(eval(sys.stdin.readline()))" + + +def popen_args(spec): + python = spec.python or sys.executable + args = [str(python), '-u'] + if spec is not None and spec.dont_write_bytecode: + args.append("-B") + # Slight gymnastics in ordering these arguments because CPython (as of + # 2.7.1) ignores -B if you provide `python -c "something" -B` + args.extend(['-c', popen_bootstrapline]) + return args + +def ssh_args(spec): + remotepython = spec.python or 'python' + args = ['ssh', '-C' ] + if spec.ssh_config is not None: + args.extend(['-F', str(spec.ssh_config)]) + remotecmd = '%s -c "%s"' %(remotepython, popen_bootstrapline) + args.extend([spec.ssh, remotecmd]) + return args + + + +def create_io(spec): + if spec.popen: + args = popen_args(spec) + return Popen2IOMaster(args) + if spec.ssh: + args = ssh_args(spec) + io = Popen2IOMaster(args) + io.remoteaddress = spec.ssh + return io + +RIO_KILL = 1 +RIO_WAIT = 2 +RIO_REMOTEADDRESS = 3 +RIO_CLOSE_WRITE = 4 + +class RemoteIO(object): + def __init__(self, master_channel): + self.iochan = master_channel.gateway.newchannel() + self.controlchan = master_channel.gateway.newchannel() + master_channel.send((self.iochan, self.controlchan)) + self.io = self.iochan.makefile('r') + + + def read(self, nbytes): + return self.io.read(nbytes) + + def write(self, data): + return self.iochan.send(data) + + def _controll(self, event): + self.controlchan.send(event) + return self.controlchan.receive() + + def close_write(self): + self._controll(RIO_CLOSE_WRITE) + + def kill(self): + self._controll(RIO_KILL) + + def wait(self): + return self._controll(RIO_WAIT) + + def __repr__(self): + return '' % (self.iochan.gateway.id, ) + + +def serve_remote_io(channel): + class PseudoSpec(object): + def __getattr__(self, name): + return None + spec = PseudoSpec() + spec.__dict__.update(channel.receive()) + io = create_io(spec) + io_chan, control_chan = channel.receive() + io_target = io_chan.makefile() + + def iothread(): + initial = io.read(1) + assert initial == '1'.encode('ascii') + channel.gateway._trace('initializing transfer io for', spec.id) + io_target.write(initial) + while True: + message = Message.from_io(io) + message.to_io(io_target) + import threading + thread = threading.Thread(name='io-forward-'+spec.id, + target=iothread) + thread.setDaemon(True) + thread.start() + + def iocallback(data): + io.write(data) + io_chan.setcallback(iocallback) + + + def controll(data): + if data==RIO_WAIT: + control_chan.send(io.wait()) + elif data==RIO_KILL: + control_chan.send(io.kill()) + elif data==RIO_REMOTEADDRESS: + control_chan.send(io.remoteaddress) + elif data==RIO_CLOSE_WRITE: + control_chan.send(io.close_write()) + control_chan.setcallback(controll) + +if __name__ == "__channelexec__": + serve_remote_io(channel) diff --git a/remoto/lib/execnet/gateway_socket.py b/remoto/lib/execnet/gateway_socket.py new file mode 100644 index 0000000..7b98bff --- /dev/null +++ b/remoto/lib/execnet/gateway_socket.py @@ -0,0 +1,91 @@ +import socket +from execnet.gateway import Gateway +from execnet.gateway_bootstrap import HostNotFound +import os, sys, inspect + + +try: bytes +except NameError: bytes = str + +class SocketIO: + + error = (socket.error, EOFError) + def __init__(self, sock): + self.sock = sock + try: + sock.setsockopt(socket.SOL_IP, socket.IP_TOS, 0x10)# IPTOS_LOWDELAY + sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) + except (AttributeError, socket.error): + sys.stderr.write("WARNING: cannot set socketoption") + + def read(self, numbytes): + "Read exactly 'bytes' bytes from the socket." + buf = bytes() + while len(buf) < numbytes: + t = self.sock.recv(numbytes - len(buf)) + if not t: + raise EOFError + buf += t + return buf + + def write(self, data): + self.sock.sendall(data) + + def close_read(self): + try: + self.sock.shutdown(0) + except socket.error: + pass + def close_write(self): + try: + self.sock.shutdown(1) + except socket.error: + pass + + def wait(self): + pass + + def kill(self): + pass + + +def start_via(gateway, hostport=None): + """ return a host, port tuple, + after instanciating a socketserver on the given gateway + """ + if hostport is None: + host, port = ('localhost', 0) + else: + host, port = hostport + + from execnet.script import socketserver + + # execute the above socketserverbootstrap on the other side + channel = gateway.remote_exec(socketserver) + channel.send((host, port)) + (realhost, realport) = channel.receive() + #self._trace("new_remote received" + # "port=%r, hostname = %r" %(realport, hostname)) + if not realhost or realhost=="0.0.0.0": + realhost = "localhost" + return realhost, realport + + +def create_io(spec, group): + assert not spec.python, ( + "socket: specifying python executables not yet supported") + gateway_id = spec.installvia + if gateway_id: + host, port = start_via(group[gateway_id]) + else: + host, port = spec.socket.split(":") + port = int(port) + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + io = SocketIO(sock) + io.remoteaddress = '%s:%d' % (host, port) + try: + sock.connect((host, port)) + except socket.gaierror: + raise HostNotFound(str(sys.exc_info()[1])) + return io diff --git a/remoto/lib/execnet/multi.py b/remoto/lib/execnet/multi.py new file mode 100644 index 0000000..a3417ce --- /dev/null +++ b/remoto/lib/execnet/multi.py @@ -0,0 +1,259 @@ +""" +Managing Gateway Groups and interactions with multiple channels. + +(c) 2008-2009, Holger Krekel and others +""" + +import os, sys, atexit +import time +import execnet +from execnet.threadpool import WorkerPool + +from execnet import XSpec +from execnet import gateway, gateway_io, gateway_bootstrap +from execnet.gateway_base import queue, reraise, trace, TimeoutError + +NO_ENDMARKER_WANTED = object() + +class Group: + """ Gateway Groups. """ + defaultspec = "popen" + def __init__(self, xspecs=()): + """ initialize group and make gateways as specified. """ + # Gateways may evolve to become GC-collectable + self._gateways = [] + self._autoidcounter = 0 + self._gateways_to_join = [] + for xspec in xspecs: + self.makegateway(xspec) + atexit.register(self._cleanup_atexit) + + def __repr__(self): + idgateways = [gw.id for gw in self] + return "" %(idgateways) + + def __getitem__(self, key): + if isinstance(key, int): + return self._gateways[key] + for gw in self._gateways: + if gw == key or gw.id == key: + return gw + raise KeyError(key) + + def __contains__(self, key): + try: + self[key] + return True + except KeyError: + return False + + def __len__(self): + return len(self._gateways) + + def __iter__(self): + return iter(list(self._gateways)) + + def makegateway(self, spec=None): + """create and configure a gateway to a Python interpreter. + The ``spec`` string encodes the target gateway type + and configuration information. The general format is:: + + key1=value1//key2=value2//... + + If you leave out the ``=value`` part a True value is assumed. + Valid types: ``popen``, ``ssh=hostname``, ``socket=host:port``. + Valid configuration:: + + id= specifies the gateway id + python= specifies which python interpreter to execute + chdir= specifies to which directory to change + nice= specifies process priority of new process + env:NAME=value specifies a remote environment variable setting. + + If no spec is given, self.defaultspec is used. + """ + if not spec: + spec = self.defaultspec + if not isinstance(spec, XSpec): + spec = XSpec(spec) + self.allocate_id(spec) + if spec.via: + assert not spec.socket + master = self[spec.via] + channel = master.remote_exec(gateway_io) + channel.send(vars(spec)) + io = gateway_io.RemoteIO(channel) + gw = gateway_bootstrap.bootstrap(io, spec) + elif spec.popen or spec.ssh: + io = gateway_io.create_io(spec) + gw = gateway_bootstrap.bootstrap(io, spec) + elif spec.socket: + from execnet import gateway_socket + io = gateway_socket.create_io(spec, self) + gw = gateway_bootstrap.bootstrap(io, spec) + else: + raise ValueError("no gateway type found for %r" % (spec._spec,)) + gw.spec = spec + self._register(gw) + if spec.chdir or spec.nice or spec.env: + channel = gw.remote_exec(""" + import os + path, nice, env = channel.receive() + if path: + if not os.path.exists(path): + os.mkdir(path) + os.chdir(path) + if nice and hasattr(os, 'nice'): + os.nice(nice) + if env: + for name, value in env.items(): + os.environ[name] = value + """) + nice = spec.nice and int(spec.nice) or 0 + channel.send((spec.chdir, nice, spec.env)) + channel.waitclose() + return gw + + def allocate_id(self, spec): + """ allocate id for the given xspec object. """ + if spec.id is None: + id = "gw" + str(self._autoidcounter) + self._autoidcounter += 1 + if id in self: + raise ValueError("already have gateway with id %r" %(id,)) + spec.id = id + + def _register(self, gateway): + assert not hasattr(gateway, '_group') + assert gateway.id + assert id not in self + self._gateways.append(gateway) + gateway._group = self + + def _unregister(self, gateway): + self._gateways.remove(gateway) + self._gateways_to_join.append(gateway) + + def _cleanup_atexit(self): + trace("=== atexit cleanup %r ===" %(self,)) + self.terminate(timeout=1.0) + + def terminate(self, timeout=None): + """ trigger exit of member gateways and wait for termination + of member gateways and associated subprocesses. After waiting + timeout seconds try to to kill local sub processes of popen- + and ssh-gateways. Timeout defaults to None meaning + open-ended waiting and no kill attempts. + """ + + while self: + from execnet.threadpool import WorkerPool + vias = {} + for gw in self: + if gw.spec.via: + vias[gw.spec.via] = True + for gw in self: + if gw.id not in vias: + gw.exit() + + def join_wait(gw): + gw.join() + gw._io.wait() + def kill(gw): + trace("Gateways did not come down after timeout: %r" % gw) + gw._io.kill() + + safe_terminate(timeout, [ + (lambda: join_wait(gw), lambda: kill(gw)) + for gw in self._gateways_to_join]) + self._gateways_to_join[:] = [] + + def remote_exec(self, source, **kwargs): + """ remote_exec source on all member gateways and return + MultiChannel connecting to all sub processes. + """ + channels = [] + for gw in self: + channels.append(gw.remote_exec(source, **kwargs)) + return MultiChannel(channels) + +class MultiChannel: + def __init__(self, channels): + self._channels = channels + + def __len__(self): + return len(self._channels) + + def __iter__(self): + return iter(self._channels) + + def __getitem__(self, key): + return self._channels[key] + + def __contains__(self, chan): + return chan in self._channels + + def send_each(self, item): + for ch in self._channels: + ch.send(item) + + def receive_each(self, withchannel=False): + assert not hasattr(self, '_queue') + l = [] + for ch in self._channels: + obj = ch.receive() + if withchannel: + l.append((ch, obj)) + else: + l.append(obj) + return l + + def make_receive_queue(self, endmarker=NO_ENDMARKER_WANTED): + try: + return self._queue + except AttributeError: + self._queue = queue.Queue() + for ch in self._channels: + def putreceived(obj, channel=ch): + self._queue.put((channel, obj)) + if endmarker is NO_ENDMARKER_WANTED: + ch.setcallback(putreceived) + else: + ch.setcallback(putreceived, endmarker=endmarker) + return self._queue + + + def waitclose(self): + first = None + for ch in self._channels: + try: + ch.waitclose() + except ch.RemoteError: + if first is None: + first = sys.exc_info() + if first: + reraise(*first) + + + +def safe_terminate(timeout, list_of_paired_functions): + workerpool = WorkerPool(len(list_of_paired_functions)*2) + + def termkill(termfunc, killfunc): + termreply = workerpool.dispatch(termfunc) + try: + termreply.get(timeout=timeout) + except IOError: + killfunc() + + replylist = [] + for termfunc, killfunc in list_of_paired_functions: + reply = workerpool.dispatch(termkill, termfunc, killfunc) + replylist.append(reply) + for reply in replylist: + reply.get() + + +default_group = Group() +makegateway = default_group.makegateway + diff --git a/remoto/lib/execnet/rsync.py b/remoto/lib/execnet/rsync.py new file mode 100644 index 0000000..ccfad91 --- /dev/null +++ b/remoto/lib/execnet/rsync.py @@ -0,0 +1,207 @@ +""" +1:N rsync implemenation on top of execnet. + +(c) 2006-2009, Armin Rigo, Holger Krekel, Maciej Fijalkowski +""" +import os, stat + +try: + from hashlib import md5 +except ImportError: + from md5 import md5 + +try: + from queue import Queue +except ImportError: + from Queue import Queue + +import execnet.rsync_remote + +class RSync(object): + """ This class allows to send a directory structure (recursively) + to one or multiple remote filesystems. + + There is limited support for symlinks, which means that symlinks + pointing to the sourcetree will be send "as is" while external + symlinks will be just copied (regardless of existance of such + a path on remote side). + """ + def __init__(self, sourcedir, callback=None, verbose=True): + self._sourcedir = str(sourcedir) + self._verbose = verbose + assert callback is None or hasattr(callback, '__call__') + self._callback = callback + self._channels = {} + self._receivequeue = Queue() + self._links = [] + + def filter(self, path): + return True + + def _end_of_channel(self, channel): + if channel in self._channels: + # too early! we must have got an error + channel.waitclose() + # or else we raise one + raise IOError('connection unexpectedly closed: %s ' % ( + channel.gateway,)) + + def _process_link(self, channel): + for link in self._links: + channel.send(link) + # completion marker, this host is done + channel.send(42) + + def _done(self, channel): + """ Call all callbacks + """ + finishedcallback = self._channels.pop(channel) + if finishedcallback: + finishedcallback() + channel.waitclose() + + def _list_done(self, channel): + # sum up all to send + if self._callback: + s = sum([self._paths[i] for i in self._to_send[channel]]) + self._callback("list", s, channel) + + def _send_item(self, channel, data): + """ Send one item + """ + modified_rel_path, checksum = data + modifiedpath = os.path.join(self._sourcedir, *modified_rel_path) + try: + f = open(modifiedpath, 'rb') + data = f.read() + except IOError: + data = None + + # provide info to progress callback function + modified_rel_path = "/".join(modified_rel_path) + if data is not None: + self._paths[modified_rel_path] = len(data) + else: + self._paths[modified_rel_path] = 0 + if channel not in self._to_send: + self._to_send[channel] = [] + self._to_send[channel].append(modified_rel_path) + #print "sending", modified_rel_path, data and len(data) or 0, checksum + + if data is not None: + f.close() + if checksum is not None and checksum == md5(data).digest(): + data = None # not really modified + else: + self._report_send_file(channel.gateway, modified_rel_path) + channel.send(data) + + def _report_send_file(self, gateway, modified_rel_path): + if self._verbose: + print("%s <= %s" %(gateway, modified_rel_path)) + + def send(self, raises=True): + """ Sends a sourcedir to all added targets. Flag indicates + whether to raise an error or return in case of lack of + targets + """ + if not self._channels: + if raises: + raise IOError("no targets available, maybe you " + "are trying call send() twice?") + return + # normalize a trailing '/' away + self._sourcedir = os.path.dirname(os.path.join(self._sourcedir, 'x')) + # send directory structure and file timestamps/sizes + self._send_directory_structure(self._sourcedir) + + # paths and to_send are only used for doing + # progress-related callbacks + self._paths = {} + self._to_send = {} + + # send modified file to clients + while self._channels: + channel, req = self._receivequeue.get() + if req is None: + self._end_of_channel(channel) + else: + command, data = req + if command == "links": + self._process_link(channel) + elif command == "done": + self._done(channel) + elif command == "ack": + if self._callback: + self._callback("ack", self._paths[data], channel) + elif command == "list_done": + self._list_done(channel) + elif command == "send": + self._send_item(channel, data) + del data + else: + assert "Unknown command %s" % command + + def add_target(self, gateway, destdir, + finishedcallback=None, **options): + """ Adds a remote target specified via a gateway + and a remote destination directory. + """ + for name in options: + assert name in ('delete',) + def itemcallback(req): + self._receivequeue.put((channel, req)) + channel = gateway.remote_exec(execnet.rsync_remote) + channel.reconfigure(py2str_as_py3str=False, py3str_as_py2str=False) + channel.setcallback(itemcallback, endmarker = None) + channel.send((str(destdir), options)) + self._channels[channel] = finishedcallback + + def _broadcast(self, msg): + for channel in self._channels: + channel.send(msg) + + def _send_link(self, linktype, basename, linkpoint): + self._links.append((linktype, basename, linkpoint)) + + def _send_directory(self, path): + # dir: send a list of entries + names = [] + subpaths = [] + for name in os.listdir(path): + p = os.path.join(path, name) + if self.filter(p): + names.append(name) + subpaths.append(p) + mode = os.lstat(path).st_mode + self._broadcast([mode] + names) + for p in subpaths: + self._send_directory_structure(p) + + def _send_link_structure(self, path): + linkpoint = os.readlink(path) + basename = path[len(self._sourcedir) + 1:] + if linkpoint.startswith(self._sourcedir): + self._send_link("linkbase", basename, + linkpoint[len(self._sourcedir) + 1:]) + else: + # relative or absolute link, just send it + self._send_link("link", basename, linkpoint) + self._broadcast(None) + + def _send_directory_structure(self, path): + try: + st = os.lstat(path) + except OSError: + self._broadcast((None, 0, 0)) + return + if stat.S_ISREG(st.st_mode): + # regular file: send a mode/timestamp/size pair + self._broadcast((st.st_mode, st.st_mtime, st.st_size)) + elif stat.S_ISDIR(st.st_mode): + self._send_directory(path) + elif stat.S_ISLNK(st.st_mode): + self._send_link_structure(path) + else: + raise ValueError("cannot sync %r" % (path,)) + diff --git a/remoto/lib/execnet/rsync_remote.py b/remoto/lib/execnet/rsync_remote.py new file mode 100644 index 0000000..eee139c --- /dev/null +++ b/remoto/lib/execnet/rsync_remote.py @@ -0,0 +1,109 @@ +""" +(c) 2006-2009, Armin Rigo, Holger Krekel, Maciej Fijalkowski +""" +def serve_rsync(channel): + import os, stat, shutil + try: + from hashlib import md5 + except ImportError: + from md5 import md5 + destdir, options = channel.receive() + modifiedfiles = [] + + def remove(path): + assert path.startswith(destdir) + try: + os.unlink(path) + except OSError: + # assume it's a dir + shutil.rmtree(path) + + def receive_directory_structure(path, relcomponents): + try: + st = os.lstat(path) + except OSError: + st = None + msg = channel.receive() + if isinstance(msg, list): + if st and not stat.S_ISDIR(st.st_mode): + os.unlink(path) + st = None + if not st: + os.makedirs(path) + mode = msg.pop(0) + if mode: + os.chmod(path, mode) + entrynames = {} + for entryname in msg: + destpath = os.path.join(path, entryname) + receive_directory_structure(destpath, relcomponents + [entryname]) + entrynames[entryname] = True + if options.get('delete'): + for othername in os.listdir(path): + if othername not in entrynames: + otherpath = os.path.join(path, othername) + remove(otherpath) + elif msg is not None: + assert isinstance(msg, tuple) + checksum = None + if st: + if stat.S_ISREG(st.st_mode): + msg_mode, msg_mtime, msg_size = msg + if msg_size != st.st_size: + pass + elif msg_mtime != st.st_mtime: + f = open(path, 'rb') + checksum = md5(f.read()).digest() + f.close() + elif msg_mode and msg_mode != st.st_mode: + os.chmod(path, msg_mode) + return + else: + return # already fine + else: + remove(path) + channel.send(("send", (relcomponents, checksum))) + modifiedfiles.append((path, msg)) + receive_directory_structure(destdir, []) + + STRICT_CHECK = False # seems most useful this way for py.test + channel.send(("list_done", None)) + + for path, (mode, time, size) in modifiedfiles: + data = channel.receive() + channel.send(("ack", path[len(destdir) + 1:])) + if data is not None: + if STRICT_CHECK and len(data) != size: + raise IOError('file modified during rsync: %r' % (path,)) + f = open(path, 'wb') + f.write(data) + f.close() + try: + if mode: + os.chmod(path, mode) + os.utime(path, (time, time)) + except OSError: + pass + del data + channel.send(("links", None)) + + msg = channel.receive() + while msg != 42: + # we get symlink + _type, relpath, linkpoint = msg + path = os.path.join(destdir, relpath) + try: + remove(path) + except OSError: + pass + if _type == "linkbase": + src = os.path.join(destdir, linkpoint) + else: + assert _type == "link", _type + src = linkpoint + os.symlink(src, path) + msg = channel.receive() + channel.send(("done", None)) + +if __name__ == '__channelexec__': + serve_rsync(channel) diff --git a/remoto/lib/execnet/script/__init__.py b/remoto/lib/execnet/script/__init__.py new file mode 100644 index 0000000..792d600 --- /dev/null +++ b/remoto/lib/execnet/script/__init__.py @@ -0,0 +1 @@ +# diff --git a/remoto/lib/execnet/script/loop_socketserver.py b/remoto/lib/execnet/script/loop_socketserver.py new file mode 100644 index 0000000..44896b6 --- /dev/null +++ b/remoto/lib/execnet/script/loop_socketserver.py @@ -0,0 +1,14 @@ + +import os, sys +import subprocess + +if __name__ == '__main__': + directory = os.path.dirname(os.path.abspath(sys.argv[0])) + script = os.path.join(directory, 'socketserver.py') + while 1: + cmdlist = ["python", script] + cmdlist.extend(sys.argv[1:]) + text = "starting subcommand: " + " ".join(cmdlist) + print(text) + process = subprocess.Popen(cmdlist) + process.wait() diff --git a/remoto/lib/execnet/script/quitserver.py b/remoto/lib/execnet/script/quitserver.py new file mode 100644 index 0000000..5b7ebdb --- /dev/null +++ b/remoto/lib/execnet/script/quitserver.py @@ -0,0 +1,16 @@ +""" + + send a "quit" signal to a remote server + +""" + +import sys +import socket + +hostport = sys.argv[1] +host, port = hostport.split(':') +hostport = (host, int(port)) + +sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +sock.connect(hostport) +sock.sendall('"raise KeyboardInterrupt"\n') diff --git a/remoto/lib/execnet/script/shell.py b/remoto/lib/execnet/script/shell.py new file mode 100644 index 0000000..9196f41 --- /dev/null +++ b/remoto/lib/execnet/script/shell.py @@ -0,0 +1,85 @@ +#! /usr/bin/env python +""" +a remote python shell + +for injection into startserver.py +""" +import sys, os, socket, select + +try: + clientsock +except NameError: + print("client side starting") + import sys + host, port = sys.argv[1].split(':') + port = int(port) + myself = open(os.path.abspath(sys.argv[0]), 'rU').read() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((host, port)) + sock.sendall(repr(myself)+'\n') + print("send boot string") + inputlist = [ sock, sys.stdin ] + try: + while 1: + r,w,e = select.select(inputlist, [], []) + if sys.stdin in r: + line = raw_input() + sock.sendall(line + '\n') + if sock in r: + line = sock.recv(4096) + sys.stdout.write(line) + sys.stdout.flush() + except: + import traceback + print(traceback.print_exc()) + + sys.exit(1) + +print("server side starting") +# server side +# +from traceback import print_exc +from threading import Thread + +class promptagent(Thread): + def __init__(self, clientsock): + Thread.__init__(self) + self.clientsock = clientsock + + def run(self): + print("Entering thread prompt loop") + clientfile = self.clientsock.makefile('w') + + filein = self.clientsock.makefile('r') + loc = self.clientsock.getsockname() + + while 1: + try: + clientfile.write('%s %s >>> ' % loc) + clientfile.flush() + line = filein.readline() + if len(line)==0: raise EOFError("nothing") + #print >>sys.stderr,"got line: " + line + if line.strip(): + oldout, olderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = clientfile, clientfile + try: + try: + exec(compile(line + '\n','', 'single')) + except: + print_exc() + finally: + sys.stdout=oldout + sys.stderr=olderr + clientfile.flush() + except EOFError: + e = sys.exc_info()[1] + sys.stderr.write("connection close, prompt thread returns") + break + #print >>sys.stdout, "".join(apply(format_exception,sys.exc_info())) + + self.clientsock.close() + +prompter = promptagent(clientsock) +prompter.start() +print("promptagent - thread started") diff --git a/remoto/lib/execnet/script/socketserver.py b/remoto/lib/execnet/script/socketserver.py new file mode 100644 index 0000000..596597b --- /dev/null +++ b/remoto/lib/execnet/script/socketserver.py @@ -0,0 +1,112 @@ +#! /usr/bin/env python + +""" + start socket based minimal readline exec server + + it can exeuted in 2 modes of operation + + 1. as normal script, that listens for new connections + + 2. via existing_gateway.remote_exec (as imported module) + +""" +# this part of the program only executes on the server side +# + +progname = 'socket_readline_exec_server-1.2' + +import sys, socket, os +try: + import fcntl +except ImportError: + fcntl = None + +debug = 0 + +if debug: # and not os.isatty(sys.stdin.fileno()): + f = open('/tmp/execnet-socket-pyout.log', 'w') + old = sys.stdout, sys.stderr + sys.stdout = sys.stderr = f + +def print_(*args): + print(" ".join(str(arg) for arg in args)) + +if sys.version_info > (3, 0): + exec("""def exec_(source, locs): + exec(source, locs)""") +else: + exec("""def exec_(source, locs): + exec source in locs""") + +def exec_from_one_connection(serversock): + print_(progname, 'Entering Accept loop', serversock.getsockname()) + clientsock,address = serversock.accept() + print_(progname, 'got new connection from %s %s' % address) + clientfile = clientsock.makefile('rb') + print_("reading line") + # rstrip so that we can use \r\n for telnet testing + source = clientfile.readline().rstrip() + clientfile.close() + g = {'clientsock' : clientsock, 'address' : address} + source = eval(source) + if source: + co = compile(source+'\n', source, 'exec') + print_(progname, 'compiled source, executing') + try: + exec_(co, g) + finally: + print_(progname, 'finished executing code') + # background thread might hold a reference to this (!?) + #clientsock.close() + +def bind_and_listen(hostport): + if isinstance(hostport, str): + host, port = hostport.split(':') + hostport = (host, int(port)) + serversock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # set close-on-exec + if hasattr(fcntl, 'FD_CLOEXEC'): + old = fcntl.fcntl(serversock.fileno(), fcntl.F_GETFD) + fcntl.fcntl(serversock.fileno(), fcntl.F_SETFD, old | fcntl.FD_CLOEXEC) + # allow the address to be re-used in a reasonable amount of time + if os.name == 'posix' and sys.platform != 'cygwin': + serversock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + serversock.bind(hostport) + serversock.listen(5) + return serversock + +def startserver(serversock, loop=False): + try: + while 1: + try: + exec_from_one_connection(serversock) + except (KeyboardInterrupt, SystemExit): + raise + except: + if debug: + import traceback + traceback.print_exc() + else: + excinfo = sys.exc_info() + print_("got exception", excinfo[1]) + if not loop: + break + finally: + print_("leaving socketserver execloop") + serversock.shutdown(2) + +if __name__ == '__main__': + import sys + if len(sys.argv)>1: + hostport = sys.argv[1] + else: + hostport = ':8888' + serversock = bind_and_listen(hostport) + startserver(serversock, loop=False) +elif __name__=='__channelexec__': + bindname = channel.receive() + sock = bind_and_listen(bindname) + port = sock.getsockname() + channel.send(port) + startserver(sock) diff --git a/remoto/lib/execnet/script/socketserverservice.py b/remoto/lib/execnet/script/socketserverservice.py new file mode 100644 index 0000000..0d208f8 --- /dev/null +++ b/remoto/lib/execnet/script/socketserverservice.py @@ -0,0 +1,91 @@ +""" +A windows service wrapper for the py.execnet socketserver. + +To use, run: + python socketserverservice.py register + net start ExecNetSocketServer +""" + +import sys +import os +import time +import win32serviceutil +import win32service +import win32event +import win32evtlogutil +import servicemanager +import threading +import socketserver + + +appname = 'ExecNetSocketServer' + + +class SocketServerService(win32serviceutil.ServiceFramework): + _svc_name_ = appname + _svc_display_name_ = "%s" % appname + _svc_deps_ = ["EventLog"] + def __init__(self, args): + # The exe-file has messages for the Event Log Viewer. + # Register the exe-file as event source. + # + # Probably it would be better if this is done at installation time, + # so that it also could be removed if the service is uninstalled. + # Unfortunately it cannot be done in the 'if __name__ == "__main__"' + # block below, because the 'frozen' exe-file does not run this code. + # + win32evtlogutil.AddSourceToRegistry(self._svc_display_name_, + servicemanager.__file__, + "Application") + win32serviceutil.ServiceFramework.__init__(self, args) + self.hWaitStop = win32event.CreateEvent(None, 0, 0, None) + self.WAIT_TIME = 1000 # in milliseconds + + + def SvcStop(self): + self.ReportServiceStatus(win32service.SERVICE_STOP_PENDING) + win32event.SetEvent(self.hWaitStop) + + + def SvcDoRun(self): + # Redirect stdout and stderr to prevent "IOError: [Errno 9] + # Bad file descriptor". Windows services don't have functional + # output streams. + sys.stdout = sys.stderr = open('nul', 'w') + + # Write a 'started' event to the event log... + win32evtlogutil.ReportEvent(self._svc_display_name_, + servicemanager.PYS_SERVICE_STARTED, + 0, # category + servicemanager.EVENTLOG_INFORMATION_TYPE, + (self._svc_name_, '')) + print("Begin: %s" % (self._svc_display_name_)) + + hostport = ':8888' + print('Starting py.execnet SocketServer on %s' % hostport) + serversock = socketserver.bind_and_listen(hostport) + thread = threading.Thread(target=socketserver.startserver, + args=(serversock,), + kwargs={'loop':True}) + thread.setDaemon(True) + thread.start() + + # wait to be stopped or self.WAIT_TIME to pass + while True: + result = win32event.WaitForSingleObject(self.hWaitStop, + self.WAIT_TIME) + if result == win32event.WAIT_OBJECT_0: + break + + # write a 'stopped' event to the event log. + win32evtlogutil.ReportEvent(self._svc_display_name_, + servicemanager.PYS_SERVICE_STOPPED, + 0, # category + servicemanager.EVENTLOG_INFORMATION_TYPE, + (self._svc_name_, '')) + print("End: %s" % appname) + + +if __name__ == '__main__': + # Note that this code will not be run in the 'frozen' exe-file!!! + win32serviceutil.HandleCommandLine(SocketServerService) diff --git a/remoto/lib/execnet/script/xx.py b/remoto/lib/execnet/script/xx.py new file mode 100644 index 0000000..931e4b7 --- /dev/null +++ b/remoto/lib/execnet/script/xx.py @@ -0,0 +1,9 @@ +import rlcompleter2 +rlcompleter2.setup() + +import register, sys +try: + hostport = sys.argv[1] +except: + hostport = ':8888' +gw = register.ServerGateway(hostport) diff --git a/remoto/lib/execnet/threadpool.py b/remoto/lib/execnet/threadpool.py new file mode 100644 index 0000000..812d16a --- /dev/null +++ b/remoto/lib/execnet/threadpool.py @@ -0,0 +1,183 @@ +""" +dispatching execution to threads + +(c) 2009, holger krekel +""" +import threading +import time +import sys + +# py2/py3 compatibility +try: + import queue +except ImportError: + import Queue as queue +if sys.version_info >= (3,0): + exec ("def reraise(cls, val, tb): raise val") +else: + exec ("def reraise(cls, val, tb): raise cls, val, tb") + +ERRORMARKER = object() + +class Reply(object): + """ reply instances provide access to the result + of a function execution that got dispatched + through WorkerPool.dispatch() + """ + _excinfo = None + def __init__(self, task): + self.task = task + self._queue = queue.Queue() + + def _set(self, result): + self._queue.put(result) + + def _setexcinfo(self, excinfo): + self._excinfo = excinfo + self._queue.put(ERRORMARKER) + + def get(self, timeout=None): + """ get the result object from an asynchronous function execution. + if the function execution raised an exception, + then calling get() will reraise that exception + including its traceback. + """ + if self._queue is None: + raise EOFError("reply has already been delivered") + try: + result = self._queue.get(timeout=timeout) + except queue.Empty: + raise IOError("timeout waiting for %r" %(self.task, )) + if result is ERRORMARKER: + self._queue = None + excinfo = self._excinfo + reraise(excinfo[0], excinfo[1], excinfo[2]) + return result + +class WorkerThread(threading.Thread): + def __init__(self, pool): + threading.Thread.__init__(self) + self._queue = queue.Queue() + self._pool = pool + self.setDaemon(1) + + def _run_once(self): + reply = self._queue.get() + if reply is SystemExit: + return False + assert self not in self._pool._ready + task = reply.task + try: + func, args, kwargs = task + result = func(*args, **kwargs) + except (SystemExit, KeyboardInterrupt): + return False + except: + reply._setexcinfo(sys.exc_info()) + else: + reply._set(result) + # at this point, reply, task and all other local variables go away + return True + + def run(self): + try: + while self._run_once(): + self._pool._ready[self] = True + finally: + del self._pool._alive[self] + try: + del self._pool._ready[self] + except KeyError: + pass + + def send(self, task): + reply = Reply(task) + self._queue.put(reply) + return reply + + def stop(self): + self._queue.put(SystemExit) + +class WorkerPool(object): + """ A WorkerPool allows to dispatch function executions + to threads. Each Worker Thread is reused for multiple + function executions. The dispatching operation + takes care to create and dispatch to existing + threads. + + You need to call shutdown() to signal + the WorkerThreads to terminate and join() + in order to wait until all worker threads + have terminated. + """ + _shuttingdown = False + def __init__(self, maxthreads=None): + """ init WorkerPool instance which may + create up to `maxthreads` worker threads. + """ + self.maxthreads = maxthreads + self._ready = {} + self._alive = {} + + def dispatch(self, func, *args, **kwargs): + """ return Reply object for the asynchronous dispatch + of the given func(*args, **kwargs) in a + separate worker thread. + """ + if self._shuttingdown: + raise IOError("WorkerPool is already shutting down") + try: + thread, _ = self._ready.popitem() + except KeyError: # pop from empty list + if self.maxthreads and len(self._alive) >= self.maxthreads: + raise IOError("can't create more than %d threads." % + (self.maxthreads,)) + thread = self._newthread() + return thread.send((func, args, kwargs)) + + def _newthread(self): + thread = WorkerThread(self) + self._alive[thread] = True + thread.start() + return thread + + def shutdown(self): + """ signal all worker threads to terminate. + call join() to wait until all threads termination. + """ + if not self._shuttingdown: + self._shuttingdown = True + for t in list(self._alive): + t.stop() + + def join(self, timeout=None): + """ wait until all worker threads have terminated. """ + current = threading.currentThread() + deadline = delta = None + if timeout is not None: + deadline = time.time() + timeout + for thread in list(self._alive): + if deadline: + delta = deadline - time.time() + if delta <= 0: + raise IOError("timeout while joining threads") + thread.join(timeout=delta) + if thread.isAlive(): + raise IOError("timeout while joining threads") + +if __name__ == '__channelexec__': + maxthreads = channel.receive() + execpool = WorkerPool(maxthreads=maxthreads) + gw = channel.gateway + channel.send("ok") + gw._trace("instantiated thread work pool maxthreads=%s" %(maxthreads,)) + while 1: + gw._trace("waiting for new exec task") + task = gw._execqueue.get() + if task is None: + gw._trace("thread-dispatcher got None, exiting") + execpool.shutdown() + execpool.join() + raise gw._StopExecLoop + gw._trace("dispatching exec task to thread pool") + execpool.dispatch(gw.executetask, task) diff --git a/remoto/lib/execnet/xspec.py b/remoto/lib/execnet/xspec.py new file mode 100644 index 0000000..549966a --- /dev/null +++ b/remoto/lib/execnet/xspec.py @@ -0,0 +1,54 @@ +""" +(c) 2008-2009, holger krekel +""" +import execnet + +class XSpec: + """ Execution Specification: key1=value1//key2=value2 ... + * keys need to be unique within the specification scope + * neither key nor value are allowed to contain "//" + * keys are not allowed to contain "=" + * keys are not allowed to start with underscore + * if no "=value" is given, assume a boolean True value + """ + # XXX allow customization, for only allow specific key names + popen = ssh = socket = python = chdir = nice = dont_write_bytecode = None + + def __init__(self, string): + self._spec = string + self.env = {} + for keyvalue in string.split("//"): + i = keyvalue.find("=") + if i == -1: + key, value = keyvalue, True + else: + key, value = keyvalue[:i], keyvalue[i+1:] + if key[0] == "_": + raise AttributeError("%r not a valid XSpec key" % key) + if key in self.__dict__: + raise ValueError("duplicate key: %r in %r" %(key, string)) + if key.startswith("env:"): + self.env[key[4:]] = value + else: + setattr(self, key, value) + + def __getattr__(self, name): + if name[0] == "_": + raise AttributeError(name) + return None + + def __repr__(self): + return "" %(self._spec,) + def __str__(self): + return self._spec + + def __hash__(self): + return hash(self._spec) + def __eq__(self, other): + return self._spec == getattr(other, '_spec', None) + def __ne__(self, other): + return self._spec != getattr(other, '_spec', None) + + def _samefilesystem(self): + return bool(self.popen and not self.chdir) + diff --git a/remoto/log.py b/remoto/log.py new file mode 100644 index 0000000..aba9538 --- /dev/null +++ b/remoto/log.py @@ -0,0 +1,21 @@ + + +def reporting(conn, result, timeout=None): + timeout = timeout or -1 # a.k.a. wait for ever + log_map = {'debug': conn.logger.debug, 'error': conn.logger.error} + while True: + try: + received = result.receive(timeout) + level_received, message = list(received.items())[0] + log_map[level_received](message.strip('\n')) + except EOFError: + break + except Exception as err: + # the things we need to do here :( + # because execnet magic, we cannot catch this as + # `except TimeoutError` + if err.__class__.__name__ == 'TimeoutError': + msg = 'No data was received after %s seconds, disconnecting...' % timeout + conn.logger.warning(msg) + break + raise diff --git a/remoto/process.py b/remoto/process.py new file mode 100644 index 0000000..38de2f0 --- /dev/null +++ b/remoto/process.py @@ -0,0 +1,73 @@ +from .log import reporting +from .util import admin_command + + +def _remote_run(channel, cmd): + import subprocess + import sys + + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + if process.stderr: + while True: + err = process.stderr.readline() + if err == '' and process.poll() != None: + break + if err != '': + channel.send({'error':err}) + sys.stderr.flush() + if process.stdout: + while True: + out = process.stdout.readline() + if out == '' and process.poll() != None: + break + if out != '': + channel.send({'debug':out}) + sys.stdout.flush() + + +def run(conn, command, exit=False, timeout=None): + """ + A real-time-logging implementation of a remote subprocess.Popen call where + a command is just executed on the remote end and no other handling is done. + + :param conn: A connection oject + :param command: The command to pass in to the remote subprocess.Popen + :param exit: If this call should close the connection at the end + :param timeout: How many seconds to wait after no remote data is received + (defaults to wait for ever) + """ + timeout = timeout or -1 + conn.logger.info('Running command: %s' % ' '.join(admin_command(conn.sudo, command))) + result = conn.execute(_remote_run, cmd=command) + reporting(conn, result, timeout) + if exit: + conn.exit() + + +def _remote_check(channel, cmd): + import subprocess + + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stdout = [line.strip('\n') for line in process.stdout.readlines()] + stderr = [line.strip('\n') for line in process.stderr.readlines()] + channel.send((stdout, stderr, process.wait())) + + +def check(conn, command, exit=False): + """ + Execute a remote command with ``subprocess.Popen`` but report back the + results in a tuple with three items: stdout, stderr, and exit status. + + This helper function *does not* provide any logging as it is the caller's + responsibility to do so. + """ + conn.logger.info('Running command: %s' % ' '.join(admin_command(conn.sudo, command))) + result = conn.execute(_remote_check, cmd=command) + return result.receive() + if exit: + conn.exit() diff --git a/remoto/tests/test_connection.py b/remoto/tests/test_connection.py new file mode 100644 index 0000000..e4040bf --- /dev/null +++ b/remoto/tests/test_connection.py @@ -0,0 +1,81 @@ +from py.test import raises +from remoto import connection + + +class FakeSocket(object): + + def __init__(self, gethostname): + self.gethostname = lambda: gethostname + + +class TestNeedsSsh(object): + + def test_short_hostname_matches(self): + socket = FakeSocket('foo.example.org') + assert connection.needs_ssh('foo', socket) is False + + def test_long_hostname_matches(self): + socket = FakeSocket('foo.example.org') + assert connection.needs_ssh('foo.example.org', socket) is False + + def test_hostname_does_not_match(self): + socket = FakeSocket('foo') + assert connection.needs_ssh('meh', socket) is True + + +class FakeGateway(object): + + def remote_exec(self, module): + pass + + +class TestModuleExecuteArgs(object): + + def setup(self): + self.remote_module = connection.ModuleExecute(FakeGateway(), None) + + def test_single_argument(self): + assert self.remote_module._convert_args(('foo',)) == "'foo'" + + def test_more_than_one_argument(self): + args = ('foo', 'bar', 1) + assert self.remote_module._convert_args(args) == "'foo', 'bar', 1" + + def test_dictionary_as_argument(self): + args = ({'some key': 1},) + assert self.remote_module._convert_args(args) == "{'some key': 1}" + + +class TestModuleExecuteGetAttr(object): + + def setup(self): + self.remote_module = connection.ModuleExecute(FakeGateway(), None) + + def test_raise_attribute_error(self): + with raises(AttributeError) as err: + self.remote_module.foo() + assert err.value.args[0] == 'module None does not have attribute foo' + + +class TestMakeConnectionString(object): + + def test_makes_sudo_python_no_ssh(self): + conn = connection.Connection('localhost', sudo=True, eager=False) + conn_string = conn._make_connection_string('localhost', _needs_ssh=lambda x: False) + assert conn_string == 'python=sudo python' + + def test_makes_sudo_python_with_ssh(self): + conn = connection.Connection('localhost', sudo=True, eager=False) + conn_string = conn._make_connection_string('localhost', _needs_ssh=lambda x: True) + assert conn_string == 'ssh=localhost//python=sudo python' + + def test_makes_python_no_ssh(self): + conn = connection.Connection('localhost', sudo=False, eager=False) + conn_string = conn._make_connection_string('localhost', _needs_ssh=lambda x: False) + assert conn_string == 'python=python' + + def test_makes_python_with_ssh(self): + conn = connection.Connection('localhost', sudo=False, eager=False) + conn_string = conn._make_connection_string('localhost', _needs_ssh=lambda x: True) + assert conn_string == 'ssh=localhost//python=python' + diff --git a/remoto/tests/test_log.py b/remoto/tests/test_log.py new file mode 100644 index 0000000..8aecc90 --- /dev/null +++ b/remoto/tests/test_log.py @@ -0,0 +1,52 @@ +from pytest import raises +from remoto import log +from remoto.exc import TimeoutError +from mock import Mock + + +class TestReporting(object): + + def test_reporting_when_channel_is_empty(self): + conn = Mock() + result = Mock() + result.receive.side_effect = EOFError + log.reporting(conn, result) + + def test_write_debug_statements(self): + conn = Mock() + result = Mock() + result.receive.side_effect = [{'debug': 'a debug message'}, EOFError] + log.reporting(conn, result) + assert conn.logger.debug.called is True + assert conn.logger.info.called is False + + def test_write_info_statements(self): + conn = Mock() + result = Mock() + result.receive.side_effect = [{'error': 'an error message'}, EOFError] + log.reporting(conn, result) + assert conn.logger.debug.called is False + assert conn.logger.error.called is True + + def test_strip_new_lines(self): + conn = Mock() + result = Mock() + result.receive.side_effect = [{'error': 'an error message\n\n'}, EOFError] + log.reporting(conn, result) + message = conn.logger.error.call_args[0][0] + assert message == 'an error message' + + def test_timeout_error(self): + conn = Mock() + result = Mock() + result.receive.side_effect = TimeoutError + log.reporting(conn, result) + message = conn.logger.warning.call_args[0][0] + assert 'No data was received after ' in message + + def test_raises_other_errors(self): + conn = Mock() + result = Mock() + result.receive.side_effect = OSError + with raises(OSError): + log.reporting(conn, result) diff --git a/remoto/tests/test_process.py b/remoto/tests/test_process.py new file mode 100644 index 0000000..1eff8c3 --- /dev/null +++ b/remoto/tests/test_process.py @@ -0,0 +1,3 @@ +# Having imports inlined in the function makes it really complicated to test +# while controlling the environment. Figure out a way to deal with inlined imports +# so testing evolves nicely. diff --git a/remoto/tests/test_util.py b/remoto/tests/test_util.py new file mode 100644 index 0000000..4ecfb7d --- /dev/null +++ b/remoto/tests/test_util.py @@ -0,0 +1,12 @@ +from remoto import util + + +class TestAdminCommand(object): + + def test_prepend_list_if_sudo(self): + result = util.admin_command(True, ['ls']) + assert result == ['sudo', 'ls'] + + def test_skip_prepend_if_not_sudo(self): + result = util.admin_command(False, ['ls']) + assert result == ['ls'] diff --git a/remoto/util.py b/remoto/util.py new file mode 100644 index 0000000..63b1bc0 --- /dev/null +++ b/remoto/util.py @@ -0,0 +1,14 @@ + + +def admin_command(sudo, command): + """ + If sudo is needed, make sure the command is prepended + correctly, otherwise return the command as it came. + + :param sudo: A boolean representing the intention of having a sudo command + (or not) + :param command: A list of the actual command to execute with Popen. + """ + if sudo: + command.insert(0, 'sudo') + return command diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8df6dc3 --- /dev/null +++ b/setup.py @@ -0,0 +1,32 @@ +import re + +module_file = open("remoto/__init__.py").read() +metadata = dict(re.findall("__([a-z]+)__\s*=\s*'([^']+)'", module_file)) +long_description = open('README.rst').read() + +from setuptools import setup, find_packages + +setup( + name = 'remoto', + description = 'Execute remote commands or processes.', + packages = find_packages(), + author = 'Alfredo Deza', + author_email = 'contact [at] deza.pe', + version = metadata['version'], + url = 'http://github.com/alfredodeza/remoto', + license = "MIT", + zip_safe = False, + keywords = "remote, commands, unix, ssh, socket, execute, terminal", + long_description = long_description, + classifiers = [ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Topic :: Utilities', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: POSIX', + 'Programming Language :: Python :: 2.6', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3.3', + ] +) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..b9c7d47 --- /dev/null +++ b/tox.ini @@ -0,0 +1,8 @@ +[tox] +envlist = py26, py27, py33 + +[testenv] +deps = + pytest + mock +commands = py.test -v