diff --git a/bitelab/__main__.py b/bitelab/__main__.py index 95adc71..59a49bb 100644 --- a/bitelab/__main__.py +++ b/bitelab/__main__.py @@ -27,7 +27,7 @@ # from httpx import AsyncClient, Auth -from io import StringIO +from io import BytesIO, StringIO, TextIOWrapper from starlette.status import HTTP_200_OK from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \ HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT @@ -35,15 +35,19 @@ from unittest.mock import patch, AsyncMock, Mock, mock_open from . import BiteAuth, Board +import aioconsole.stream import argparse import asyncio import contextlib +import functools import io import json import os import sys import unittest -import urllib +import urllib.parse +import websockets +import wsfwd def check_res_code(res): if res.status_code == HTTP_401_UNAUTHORIZED: @@ -98,6 +102,48 @@ def get_sshpubkey(fname): raise IOError +# how to do stdin/stdout via async: +# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130 +async def fwd_data(reader, writer): + while True: + data = await reader.read(16384) + if data == b'': + writer.close() + await writer.wait_closed() + return + + writer.write(data) + + await writer.drain() + +async def run_exec(baseurl, authkey, board, args): + url = urllib.parse.urljoin(baseurl, 'board/%s/exec' % urllib.parse.quote(board, safe='')) + stdin, stdout = await aioconsole.stream.get_standard_streams() + + async with websockets.connect(url) as ws, wsfwd.WSFWDClient(ws.recv, ws.send) as client: + try: + await client.auth(dict(bearer=authkey)) + + proc = await client.exec(args=args) + + toexec_task = asyncio.create_task(fwd_data(stdin, proc.stdin)) + fromexec_task = asyncio.create_task(fwd_data(proc.stdout, stdout)) + + r = await proc.wait() + + await toexec_task + await fromexec_task + sys.exit(r) + except RuntimeError as e: + print('failed to exec: %s' % e.args) + + # not a fan of this, shouldn't be needed, but + # how tests are run w/ runAsyncMain, it is + # required here. + sys.stdout.flush() + + sys.exit(1) + async def real_main(): parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(title='subcommands', @@ -114,12 +160,22 @@ async def real_main(): parser_set.add_argument('setvars', type=str, nargs='+', help='name of the board or class') parser_set.add_argument('board', type=str, help='name of the board or class') + parser_exec = subparsers.add_parser('exec', help='run a program in the jail for a board') + parser_exec.add_argument('board', type=str, help='name of the board or class') + parser_exec.add_argument('prog', type=str, help='program to exec') + parser_exec.add_argument('args', type=str, nargs='*', help='arguments for program') + args = parser.parse_args() - #print(repr(args), file=sys.stderr) + #print(repr(args), file=sys.__stderr__) baseurl = os.environ['BITELAB_URL'] authkey = os.environ['BITELAB_AUTH'] + if args.subparser_name == 'exec': + await run_exec(baseurl, authkey, args.board, + [ args.prog ] + args.args) + sys.exit(0) #pragma: no cover + client = AsyncClient(base_url=baseurl) try: @@ -172,6 +228,166 @@ def main(): if __name__ == '__main__': #pragma: no cover main() +class TestExecClient(unittest.IsolatedAsyncioTestCase): + patches = [ + patch.dict(os.environ, dict(BITELAB_URL='http://someserver/')), + patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey')) + ] + + async def asyncSetUp(self): + self.toclient = asyncio.Queue() + self.toserver = asyncio.Queue() + + for i in self.patches: + i.start() + + async def asyncTearDown(self): + self.assertTrue(self.toclient.empty()) + self.assertTrue(self.toserver.empty()) + + for i in self.patches[::-1]: + i.stop() + + @contextlib.contextmanager + def make_pipe(self): + r, w = os.pipe() + with os.fdopen(r, 'rb', buffering=65536) as readfl, os.fdopen(w, 'wb', buffering=65536) as writefl: + yield readfl, writefl + + # too lazy to make this async since async file-like objects + # aren't standard yet in Python + def copytask(self, reader, writer, doclose=True): + while True: + data = reader.read(16384) + if not data: + if doclose: + writer.close() + return + + writer.write(data) + + async def runAsyncMain(self, fun=real_main, stdin=''): + # make stdin bytes + if isinstance(stdin, str): + stdin = stdin.encode() + + # make stdout + stdout = io.BytesIO() + + # Data path: + # stdin -> stdin_task -> stdinwriter -> pipe -> stdinreader -> sys.stdin -> real_main -> + # sys.stdout -> stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout + # + # How things get closed: + # stdin is already "closed", stdin_task will close stdinwriter when eof encountered (doclose). + + # create the pipes needed + with self.make_pipe() as (stdinreader, stdinwriter), self.make_pipe() as (stdoutreader, stdoutwriter): + # setup the threads to move data + loop = asyncio.get_running_loop() + stdin_task = loop.run_in_executor(None, self.copytask, io.BytesIO(stdin), stdinwriter) + + # do not close stdout, otherwise we cannot obtain the value + stdout_task = loop.run_in_executor(None, self.copytask, stdoutreader, stdout, False) + + # insert the pipes + with patch.dict(sys.__dict__, dict(stdin=TextIOWrapper(stdinreader), + stdout=TextIOWrapper(stdoutwriter))): + try: + # run the function + await fun() + #await asyncio.wait_for(fun(), 1) + ret = 0 #pragma: no cover + except SystemExit as e: + ret = e.code + + # No one to read anything anymore + stdinwriter.close() + stdinreader.close() + + # No one to write anything anymore + stdoutwriter.close() + + # make sure all the data has been copied + await asyncio.gather(stdin_task, stdout_task) + + stdoutvalue = stdout.getvalue() + + return ret, stdoutvalue + + def setup_websockets_mock(self, webcon): + conobj = Mock() + + webcon().__aenter__.return_value = conobj + + conobj.send = self.toserver.put + conobj.recv = self.toclient.get + + webcon.reset_mock() + + @wsfwd.timeout(2) + async def test_exec_badauth(self): + class TestServer(wsfwd.WSFWDCommon): + async def handle_auth(self, msg): + raise RuntimeError('badauth') + + server = TestServer(self.toserver.get, self.toclient.put) + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \ + patch('websockets.connect') as webcon: + self.setup_websockets_mock(webcon) + + ret, stdout = await self.runAsyncMain() + + webcon.assert_called_with(urllib.parse.urljoin('http://someserver/', 'board/cora-z7s/exec')) + + await server.__aexit__(None, None, None) + + await asyncio.sleep(.1) + self.assertEqual(stdout.decode(), 'failed to exec: Got auth error: \'badauth\'\n') + + self.assertEqual(ret, 1) + + @wsfwd.timeout(2) + async def test_exec(self): + class TestServer(wsfwd.WSFWDCommon): + async def echo_handler(self, stream, msg): + self.sendstream(stream, msg) + await self.drain(stream) + + async def handle_auth(self, msg): + assert msg['auth']['bearer'] == 'thisisanapikey' + + async def handle_chanclose(self, msg): + self.add_tasks(asyncio.create_task(self.sendcmd( + dict(cmd='chanclose', chan=self._stdout_stream)))) + + async def handle_exec(self, msg): + self._stdout_stream = msg['stdout'] + assert msg['args'] == [ 'program', 'arg1', 'arg2' ] + self.add_stream_handler(msg['stdin'], + functools.partial(self.echo_handler, + msg['stdout'])) + + # pretend it's done immediately + self.add_tasks(asyncio.create_task(self.sendcmd( + dict(cmd='exit', code=0)))) + + server = TestServer(self.toserver.get, self.toclient.put) + + with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \ + patch('websockets.connect') as webcon: + self.setup_websockets_mock(webcon) + + inpdata = bytes(range(0, 255)) + + ret, stdout = await self.runAsyncMain(stdin=inpdata) + + await server.__aexit__(None, None, None) + + self.assertEqual(stdout, inpdata) + + self.assertEqual(ret, 0) + @patch.dict(os.environ, dict(BITELAB_URL='http://someserver/')) @patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey')) class TestClient(unittest.TestCase): @@ -187,10 +403,12 @@ class TestClient(unittest.TestCase): self.acp = self.ac.return_value.post = AsyncMock() self.acpr = self.acp.return_value = Mock() - def runMain(self, fun=main): + def runMain(self, fun=main, stdin=''): try: stdout = StringIO() - with patch.dict(sys.__dict__, dict(stdout=stdout)): + stdin = StringIO(stdin) + with patch.dict(sys.__dict__, dict(stdout=stdout, + stdin=stdin)): fun() ret = 0 @@ -214,6 +432,11 @@ class TestClient(unittest.TestCase): fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub')) mock_file.assert_any_call(fname) + with patch('builtins.open') as mock_file: + mock_file.side_effect = IOError() + with self.assertRaises(IOError): + get_sshpubkey(None) + @patch.dict(sys.__dict__, dict(argv=[ '', 'list' ])) def test_list_failure(self): ac = self.ac diff --git a/bitelab/testing.py b/bitelab/testing.py index abfd15c..3d8f00c 100644 --- a/bitelab/testing.py +++ b/bitelab/testing.py @@ -32,4 +32,4 @@ from .snmp import TestSNMPPower, TestSNMPWrapper from .data import TestDatabase from . import TestBiteLab, TestUnhashLRU, TestAttrs, TestBoardImpl from . import TestWebSocket, TestLogEvent -from .__main__ import TestClient +from .__main__ import TestClient, TestExecClient diff --git a/setup.py b/setup.py index 816cb25..7ea11cd 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ setup( 'wsfwd @ git+https://www.funkthat.com/gitea/jmg/wsfwd.git', 'orm', 'ucl', + 'aioconsole', # for aioconsole.stream only 'databases[sqlite]', ], extras_require = {