@@ -27,7 +27,7 @@
#
from httpx import AsyncClient, Auth
from io import StringIO
from io import BytesIO, StringIO, TextIOWrapper
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \
HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
@@ -35,15 +35,19 @@ from unittest.mock import patch, AsyncMock, Mock, mock_open
from . import BiteAuth, Board
import aioconsole.stream
import argparse
import asyncio
import contextlib
import functools
import io
import json
import os
import sys
import unittest
import urllib
import urllib.parse
import websockets
import wsfwd
def check_res_code(res):
if res.status_code == HTTP_401_UNAUTHORIZED:
@@ -98,6 +102,48 @@ def get_sshpubkey(fname):
raise IOError
# how to do stdin/stdout via async:
# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130
async def fwd_data(reader, writer):
while True:
data = await reader.read(16384)
if data == b'':
writer.close()
await writer.wait_closed()
return
writer.write(data)
await writer.drain()
async def run_exec(baseurl, authkey, board, args):
url = urllib.parse.urljoin(baseurl, 'board/%s/exec' % urllib.parse.quote(board, safe=''))
stdin, stdout = await aioconsole.stream.get_standard_streams()
async with websockets.connect(url) as ws, wsfwd.WSFWDClient(ws.recv, ws.send) as client:
try:
await client.auth(dict(bearer=authkey))
proc = await client.exec(args=args)
toexec_task = asyncio.create_task(fwd_data(stdin, proc.stdin))
fromexec_task = asyncio.create_task(fwd_data(proc.stdout, stdout))
r = await proc.wait()
await toexec_task
await fromexec_task
sys.exit(r)
except RuntimeError as e:
print('failed to exec: %s' % e.args)
# not a fan of this, shouldn't be needed, but
# how tests are run w/ runAsyncMain, it is
# required here.
sys.stdout.flush()
sys.exit(1)
async def real_main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(title='subcommands',
@@ -114,12 +160,22 @@ async def real_main():
parser_set.add_argument('setvars', type=str, nargs='+', help='name of the board or class')
parser_set.add_argument('board', type=str, help='name of the board or class')
parser_exec = subparsers.add_parser('exec', help='run a program in the jail for a board')
parser_exec.add_argument('board', type=str, help='name of the board or class')
parser_exec.add_argument('prog', type=str, help='program to exec')
parser_exec.add_argument('args', type=str, nargs='*', help='arguments for program')
args = parser.parse_args()
#print(repr(args), file=sys.stderr)
#print(repr(args), file=sys.__ stderr__ )
baseurl = os.environ['BITELAB_URL']
authkey = os.environ['BITELAB_AUTH']
if args.subparser_name == 'exec':
await run_exec(baseurl, authkey, args.board,
[ args.prog ] + args.args)
sys.exit(0) #pragma: no cover
client = AsyncClient(base_url=baseurl)
try:
@@ -172,6 +228,166 @@ def main():
if __name__ == '__main__': #pragma: no cover
main()
class TestExecClient(unittest.IsolatedAsyncioTestCase):
patches = [
patch.dict(os.environ, dict(BITELAB_URL='http://someserver/')),
patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey'))
]
async def asyncSetUp(self):
self.toclient = asyncio.Queue()
self.toserver = asyncio.Queue()
for i in self.patches:
i.start()
async def asyncTearDown(self):
self.assertTrue(self.toclient.empty())
self.assertTrue(self.toserver.empty())
for i in self.patches[::-1]:
i.stop()
@contextlib.contextmanager
def make_pipe(self):
r, w = os.pipe()
with os.fdopen(r, 'rb', buffering=65536) as readfl, os.fdopen(w, 'wb', buffering=65536) as writefl:
yield readfl, writefl
# too lazy to make this async since async file-like objects
# aren't standard yet in Python
def copytask(self, reader, writer, doclose=True):
while True:
data = reader.read(16384)
if not data:
if doclose:
writer.close()
return
writer.write(data)
async def runAsyncMain(self, fun=real_main, stdin=''):
# make stdin bytes
if isinstance(stdin, str):
stdin = stdin.encode()
# make stdout
stdout = io.BytesIO()
# Data path:
# stdin -> stdin_task -> stdinwriter -> pipe -> stdinreader -> sys.stdin -> real_main ->
# sys.stdout -> stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
#
# How things get closed:
# stdin is already "closed", stdin_task will close stdinwriter when eof encountered (doclose).
# create the pipes needed
with self.make_pipe() as (stdinreader, stdinwriter), self.make_pipe() as (stdoutreader, stdoutwriter):
# setup the threads to move data
loop = asyncio.get_running_loop()
stdin_task = loop.run_in_executor(None, self.copytask, io.BytesIO(stdin), stdinwriter)
# do not close stdout, otherwise we cannot obtain the value
stdout_task = loop.run_in_executor(None, self.copytask, stdoutreader, stdout, False)
# insert the pipes
with patch.dict(sys.__dict__, dict(stdin=TextIOWrapper(stdinreader),
stdout=TextIOWrapper(stdoutwriter))):
try:
# run the function
await fun()
#await asyncio.wait_for(fun(), 1)
ret = 0 #pragma: no cover
except SystemExit as e:
ret = e.code
# No one to read anything anymore
stdinwriter.close()
stdinreader.close()
# No one to write anything anymore
stdoutwriter.close()
# make sure all the data has been copied
await asyncio.gather(stdin_task, stdout_task)
stdoutvalue = stdout.getvalue()
return ret, stdoutvalue
def setup_websockets_mock(self, webcon):
conobj = Mock()
webcon().__aenter__.return_value = conobj
conobj.send = self.toserver.put
conobj.recv = self.toclient.get
webcon.reset_mock()
@wsfwd.timeout(2)
async def test_exec_badauth(self):
class TestServer(wsfwd.WSFWDCommon):
async def handle_auth(self, msg):
raise RuntimeError('badauth')
server = TestServer(self.toserver.get, self.toclient.put)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)
ret, stdout = await self.runAsyncMain()
webcon.assert_called_with(urllib.parse.urljoin('http://someserver/', 'board/cora-z7s/exec'))
await server.__aexit__(None, None, None)
await asyncio.sleep(.1)
self.assertEqual(stdout.decode(), 'failed to exec: Got auth error: \'badauth\'\n')
self.assertEqual(ret, 1)
@wsfwd.timeout(2)
async def test_exec(self):
class TestServer(wsfwd.WSFWDCommon):
async def echo_handler(self, stream, msg):
self.sendstream(stream, msg)
await self.drain(stream)
async def handle_auth(self, msg):
assert msg['auth']['bearer'] == 'thisisanapikey'
async def handle_chanclose(self, msg):
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='chanclose', chan=self._stdout_stream))))
async def handle_exec(self, msg):
self._stdout_stream = msg['stdout']
assert msg['args'] == [ 'program', 'arg1', 'arg2' ]
self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler,
msg['stdout']))
# pretend it's done immediately
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='exit', code=0))))
server = TestServer(self.toserver.get, self.toclient.put)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)
inpdata = bytes(range(0, 255))
ret, stdout = await self.runAsyncMain(stdin=inpdata)
await server.__aexit__(None, None, None)
self.assertEqual(stdout, inpdata)
self.assertEqual(ret, 0)
@patch.dict(os.environ, dict(BITELAB_URL='http://someserver/'))
@patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey'))
class TestClient(unittest.TestCase):
@@ -187,10 +403,12 @@ class TestClient(unittest.TestCase):
self.acp = self.ac.return_value.post = AsyncMock()
self.acpr = self.acp.return_value = Mock()
def runMain(self, fun=main):
def runMain(self, fun=main, stdin='' ):
try:
stdout = StringIO()
with patch.dict(sys.__dict__, dict(stdout=stdout)):
stdin = StringIO(stdin)
with patch.dict(sys.__dict__, dict(stdout=stdout,
stdin=stdin)):
fun()
ret = 0
@@ -214,6 +432,11 @@ class TestClient(unittest.TestCase):
fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub'))
mock_file.assert_any_call(fname)
with patch('builtins.open') as mock_file:
mock_file.side_effect = IOError()
with self.assertRaises(IOError):
get_sshpubkey(None)
@patch.dict(sys.__dict__, dict(argv=[ '', 'list' ]))
def test_list_failure(self):
ac = self.ac