--- /dev/null
+#!/bin/python3
+
+# Forward stdin to a subcommand. If EOF is read from stdin or
+# stdin/stdout/stderr are closed or hungup, then give the command "timeout"
+# seconds to complete before it is killed.
+#
+# The command is run in a separate process group. This is mostly to simplify
+# killing the set of processes (if well-behaving). You can configure that with
+# --setpgrp switch.
+
+# usage: stdin-killer [-h] [--timeout TIMEOUT] [--debug DEBUG] [--signal SIGNAL] [--verbose] [--setpgrp {no,self,child}] command [arguments ...]
+#
+# wait for stdin EOF then kill forked subcommand
+#
+# positional arguments:
+# command command to execute
+# arguments arguments to command
+#
+# options:
+# -h, --help show this help message and exit
+# --timeout TIMEOUT time to wait for forked subcommand to willing terminate
+# --debug DEBUG debug file
+# --signal SIGNAL signal to send
+# --verbose increase debugging
+# --setpgrp {no,self,child}
+# create process group
+
+
+import argparse
+import fcntl
+import logging
+import os
+import select
+import signal
+import struct
+import subprocess
+import sys
+import time
+
+NAME = "stdin-killer"
+
+log = logging.getLogger(NAME)
+PAGE_SIZE = 4096
+
+POLL_HANGUP = select.POLLHUP | select.POLLRDHUP | select.POLLERR
+
+
+def handle_event(poll, buffer, fd, event, p):
+ if sigfdr == fd:
+ b = os.read(sigfdr, 1)
+ (signum,) = struct.unpack("B", b)
+ log.debug("got signal %d", signum)
+ try:
+ p.wait(timeout=0)
+ return True
+ except subprocess.TimeoutExpired:
+ pass
+ elif 0 == fd:
+ if event & POLL_HANGUP:
+ log.debug("peer closed connection, waiting for process exit")
+ poll.unregister(0)
+ sys.stdin.close()
+ if len(buffer) == 0 and p.stdin is not None:
+ p.stdin.close()
+ p.stdin = None
+ return True
+ elif event & select.POLLIN:
+ b = os.read(0, PAGE_SIZE)
+ if b == b"":
+ log.debug("read EOF")
+ poll.unregister(0)
+ sys.stdin.close()
+ if len(buffer) == 0:
+ p.stdin.close()
+ return True
+ if p.stdin is not None:
+ buffer += b
+ # ignore further POLLIN until buffer is written to p.stdin
+ poll.register(0, POLL_HANGUP)
+ poll.register(p.stdin.fileno(), select.POLLOUT)
+ elif p.stdin is not None and p.stdin.fileno() == fd:
+ assert event & select.POLLOUT
+ b = buffer[:PAGE_SIZE]
+ log.debug("sending %d bytes to process", len(b))
+ try:
+ n = p.stdin.write(b)
+ p.stdin.flush()
+ log.debug("wrote %d bytes", n)
+ buffer = buffer[n:]
+ poll.register(0, select.POLLIN | POLL_HANGUP)
+ poll.unregister(p.stdin.fileno())
+ except BrokenPipeError:
+ log.debug("got SIGPIPE")
+ poll.unregister(p.stdin.fileno())
+ p.stdin.close()
+ p.stdin = None
+ return True
+ except BlockingIOError:
+ poll.register(p.stdin.fileno(), select.POLLOUT | POLL_HANGUP)
+ elif 1 == fd:
+ assert event & POLL_HANGUP
+ log.debug("stdout pipe has closed")
+ poll.unregister(1)
+ return True
+ elif 2 == fd:
+ assert event & POLL_HANGUP
+ log.debug("stderr pipe has closed")
+ poll.unregister(2)
+ return True
+ else:
+ assert False
+ return False
+
+
+def listen_for_events(sigfdr, p, timeout):
+ poll = select.poll()
+ # listen for data on stdin
+ poll.register(0, select.POLLIN | POLL_HANGUP)
+ # listen for stdout/stderr to be closed, if they are closed then my parent
+ # is gone and I should expire the command and myself.
+ poll.register(1, POLL_HANGUP)
+ poll.register(2, POLL_HANGUP)
+ # for SIGCHLD
+ poll.register(sigfdr, select.POLLIN)
+ buffer = bytearray()
+ expired = 0.0
+ while True:
+ if expired > 0.0:
+ since = time.monotonic() - expired
+ wait = int((timeout - since) * 1000.0)
+ if wait <= 0:
+ return
+ else:
+ wait = 5000
+ log.debug("polling for %d milliseconds", wait)
+ events = poll.poll(wait)
+ for fd, event in events:
+ log.debug("event: (%d, %d)", fd, event)
+ if handle_event(poll, buffer, fd, event, p):
+ if p.returncode is not None:
+ return
+ if expired == 0.0:
+ expired = time.monotonic()
+ log.info(
+ "expiration expected; waiting %d seconds for command to complete",
+ NS.timeout,
+ )
+
+
+if __name__ == "__main__":
+ signal.signal(signal.SIGPIPE, signal.SIG_IGN)
+ (sigfdr, sigfdw) = os.pipe2(os.O_NONBLOCK | os.O_CLOEXEC)
+ signal.set_wakeup_fd(sigfdw)
+
+ def do_nothing(signum, frame):
+ pass
+
+ signal.signal(signal.SIGCHLD, do_nothing)
+
+ P = argparse.ArgumentParser(
+ description="wait for stdin EOF then kill forked subcommand"
+ )
+ P.add_argument(
+ "--timeout",
+ action="store",
+ default=5,
+ help="time to wait for forked subcommand to willing terminate",
+ type=int,
+ )
+ P.add_argument("--debug", action="store", help="debug file", type=str)
+ P.add_argument(
+ "--signal",
+ action="store",
+ help="signal to send",
+ type=int,
+ default=signal.SIGKILL,
+ )
+ P.add_argument("--verbose", action="store_true", help="increase debugging")
+ P.add_argument(
+ "--setpgrp",
+ action="store",
+ choices=["no", "self", "child"],
+ default="self",
+ help="create process group",
+ )
+ P.add_argument(
+ "cmd", metavar="command", type=str, nargs=1, help="command to execute"
+ )
+ P.add_argument(
+ "args", metavar="arguments", type=str, nargs="*", help="arguments to command"
+ )
+ NS = P.parse_args()
+
+ logargs = {}
+ if NS.debug is not None:
+ logargs["filename"] = NS.debug
+ else:
+ logargs["stream"] = sys.stderr
+ if NS.verbose:
+ logargs["level"] = logging.DEBUG
+ else:
+ logargs["level"] = logging.INFO
+ logargs["format"] = f"%(asctime)s {NAME} %(levelname)s: %(message)s"
+ logargs["datefmt"] = "%Y-%m-%dT%H:%M:%S"
+ logging.basicConfig(**logargs)
+
+ cargs = NS.cmd + NS.args
+ popen_kwargs = {
+ "stdin": subprocess.PIPE,
+ }
+
+ if NS.setpgrp == "self":
+ os.setpgrp()
+ pgrp = os.getpgrp()
+ elif NS.setpgrp == "child":
+ popen_kwargs["preexec_fn"] = os.setpgrp
+ pgrp = None
+ elif NS.setpgrp == "no":
+ pgrp = 0
+ else:
+ assert False
+
+ log.debug("executing %s", cargs)
+ p = subprocess.Popen(cargs, **popen_kwargs)
+ if pgrp is None:
+ pgrp = p.pid
+ flags = fcntl.fcntl(p.stdin.fileno(), fcntl.F_GETFL)
+ fcntl.fcntl(p.stdin.fileno(), fcntl.F_SETFL, flags | os.O_NONBLOCK)
+
+ listen_for_events(sigfdr, p, NS.timeout)
+
+ if p.returncode is None:
+ log.error("timeout expired: sending signal %d to command and myself", NS.signal)
+ if pgrp == 0:
+ os.kill(p.pid, NS.signal)
+ else:
+ os.killpg(pgrp, NS.signal) # should kill me too
+ os.kill(os.getpid(), NS.signal) # to exit abnormally with same signal
+ log.error("signal did not cause termination, sending myself SIGKILL")
+ os.kill(os.getpid(), signal.SIGKILL) # failsafe
+ rc = p.returncode
+ log.debug("rc = %d", rc)
+ assert rc is not None
+ if rc < 0:
+ log.error("command terminated with signal %d: sending same signal to myself!", -rc)
+ os.kill(os.getpid(), -rc) # kill myself with the same signal
+ log.error("signal did not cause termination, sending myself SIGKILL")
+ os.kill(os.getpid(), signal.SIGKILL) # failsafe
+ else:
+ log.info("command exited with status %d: exiting normally with same code!", rc)
+ sys.exit(rc)