Browse Source

Implement tests for bitelab exec board cmd.

This was pretty tricky to get working w/ the pipes and mocking
stdin/stdout, but w/ some pipes it's working..  There is an unclosed
transport that needs to be investigated though.
main
John-Mark Gurney 3 years ago
parent
commit
6973b1d08d
3 changed files with 230 additions and 6 deletions
  1. +228
    -5
      bitelab/__main__.py
  2. +1
    -1
      bitelab/testing.py
  3. +1
    -0
      setup.py

+ 228
- 5
bitelab/__main__.py View File

@@ -27,7 +27,7 @@
#

from httpx import AsyncClient, Auth
from io import StringIO
from io import BytesIO, StringIO, TextIOWrapper
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \
HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
@@ -35,15 +35,19 @@ from unittest.mock import patch, AsyncMock, Mock, mock_open

from . import BiteAuth, Board

import aioconsole.stream
import argparse
import asyncio
import contextlib
import functools
import io
import json
import os
import sys
import unittest
import urllib
import urllib.parse
import websockets
import wsfwd

def check_res_code(res):
if res.status_code == HTTP_401_UNAUTHORIZED:
@@ -98,6 +102,48 @@ def get_sshpubkey(fname):

raise IOError

# how to do stdin/stdout via async:
# https://github.com/vxgmichel/aioconsole/blob/master/aioconsole/stream.py#L130
async def fwd_data(reader, writer):
while True:
data = await reader.read(16384)
if data == b'':
writer.close()
await writer.wait_closed()
return

writer.write(data)

await writer.drain()

async def run_exec(baseurl, authkey, board, args):
url = urllib.parse.urljoin(baseurl, 'board/%s/exec' % urllib.parse.quote(board, safe=''))
stdin, stdout = await aioconsole.stream.get_standard_streams()

async with websockets.connect(url) as ws, wsfwd.WSFWDClient(ws.recv, ws.send) as client:
try:
await client.auth(dict(bearer=authkey))

proc = await client.exec(args=args)

toexec_task = asyncio.create_task(fwd_data(stdin, proc.stdin))
fromexec_task = asyncio.create_task(fwd_data(proc.stdout, stdout))

r = await proc.wait()

await toexec_task
await fromexec_task
sys.exit(r)
except RuntimeError as e:
print('failed to exec: %s' % e.args)

# not a fan of this, shouldn't be needed, but
# how tests are run w/ runAsyncMain, it is
# required here.
sys.stdout.flush()

sys.exit(1)

async def real_main():
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(title='subcommands',
@@ -114,12 +160,22 @@ async def real_main():
parser_set.add_argument('setvars', type=str, nargs='+', help='name of the board or class')
parser_set.add_argument('board', type=str, help='name of the board or class')

parser_exec = subparsers.add_parser('exec', help='run a program in the jail for a board')
parser_exec.add_argument('board', type=str, help='name of the board or class')
parser_exec.add_argument('prog', type=str, help='program to exec')
parser_exec.add_argument('args', type=str, nargs='*', help='arguments for program')

args = parser.parse_args()
#print(repr(args), file=sys.stderr)
#print(repr(args), file=sys.__stderr__)

baseurl = os.environ['BITELAB_URL']
authkey = os.environ['BITELAB_AUTH']

if args.subparser_name == 'exec':
await run_exec(baseurl, authkey, args.board,
[ args.prog ] + args.args)
sys.exit(0) #pragma: no cover

client = AsyncClient(base_url=baseurl)

try:
@@ -172,6 +228,166 @@ def main():
if __name__ == '__main__': #pragma: no cover
main()

class TestExecClient(unittest.IsolatedAsyncioTestCase):
patches = [
patch.dict(os.environ, dict(BITELAB_URL='http://someserver/')),
patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey'))
]

async def asyncSetUp(self):
self.toclient = asyncio.Queue()
self.toserver = asyncio.Queue()

for i in self.patches:
i.start()

async def asyncTearDown(self):
self.assertTrue(self.toclient.empty())
self.assertTrue(self.toserver.empty())

for i in self.patches[::-1]:
i.stop()

@contextlib.contextmanager
def make_pipe(self):
r, w = os.pipe()
with os.fdopen(r, 'rb', buffering=65536) as readfl, os.fdopen(w, 'wb', buffering=65536) as writefl:
yield readfl, writefl

# too lazy to make this async since async file-like objects
# aren't standard yet in Python
def copytask(self, reader, writer, doclose=True):
while True:
data = reader.read(16384)
if not data:
if doclose:
writer.close()
return

writer.write(data)

async def runAsyncMain(self, fun=real_main, stdin=''):
# make stdin bytes
if isinstance(stdin, str):
stdin = stdin.encode()

# make stdout
stdout = io.BytesIO()

# Data path:
# stdin -> stdin_task -> stdinwriter -> pipe -> stdinreader -> sys.stdin -> real_main ->
# sys.stdout -> stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
#
# How things get closed:
# stdin is already "closed", stdin_task will close stdinwriter when eof encountered (doclose).

# create the pipes needed
with self.make_pipe() as (stdinreader, stdinwriter), self.make_pipe() as (stdoutreader, stdoutwriter):
# setup the threads to move data
loop = asyncio.get_running_loop()
stdin_task = loop.run_in_executor(None, self.copytask, io.BytesIO(stdin), stdinwriter)

# do not close stdout, otherwise we cannot obtain the value
stdout_task = loop.run_in_executor(None, self.copytask, stdoutreader, stdout, False)

# insert the pipes
with patch.dict(sys.__dict__, dict(stdin=TextIOWrapper(stdinreader),
stdout=TextIOWrapper(stdoutwriter))):
try:
# run the function
await fun()
#await asyncio.wait_for(fun(), 1)
ret = 0 #pragma: no cover
except SystemExit as e:
ret = e.code

# No one to read anything anymore
stdinwriter.close()
stdinreader.close()

# No one to write anything anymore
stdoutwriter.close()

# make sure all the data has been copied
await asyncio.gather(stdin_task, stdout_task)

stdoutvalue = stdout.getvalue()

return ret, stdoutvalue

def setup_websockets_mock(self, webcon):
conobj = Mock()

webcon().__aenter__.return_value = conobj

conobj.send = self.toserver.put
conobj.recv = self.toclient.get

webcon.reset_mock()

@wsfwd.timeout(2)
async def test_exec_badauth(self):
class TestServer(wsfwd.WSFWDCommon):
async def handle_auth(self, msg):
raise RuntimeError('badauth')

server = TestServer(self.toserver.get, self.toclient.put)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)

ret, stdout = await self.runAsyncMain()

webcon.assert_called_with(urllib.parse.urljoin('http://someserver/', 'board/cora-z7s/exec'))

await server.__aexit__(None, None, None)

await asyncio.sleep(.1)
self.assertEqual(stdout.decode(), 'failed to exec: Got auth error: \'badauth\'\n')

self.assertEqual(ret, 1)

@wsfwd.timeout(2)
async def test_exec(self):
class TestServer(wsfwd.WSFWDCommon):
async def echo_handler(self, stream, msg):
self.sendstream(stream, msg)
await self.drain(stream)

async def handle_auth(self, msg):
assert msg['auth']['bearer'] == 'thisisanapikey'

async def handle_chanclose(self, msg):
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='chanclose', chan=self._stdout_stream))))

async def handle_exec(self, msg):
self._stdout_stream = msg['stdout']
assert msg['args'] == [ 'program', 'arg1', 'arg2' ]
self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler,
msg['stdout']))

# pretend it's done immediately
self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='exit', code=0))))

server = TestServer(self.toserver.get, self.toclient.put)

with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon)

inpdata = bytes(range(0, 255))

ret, stdout = await self.runAsyncMain(stdin=inpdata)

await server.__aexit__(None, None, None)

self.assertEqual(stdout, inpdata)

self.assertEqual(ret, 0)

@patch.dict(os.environ, dict(BITELAB_URL='http://someserver/'))
@patch.dict(os.environ, dict(BITELAB_AUTH='thisisanapikey'))
class TestClient(unittest.TestCase):
@@ -187,10 +403,12 @@ class TestClient(unittest.TestCase):
self.acp = self.ac.return_value.post = AsyncMock()
self.acpr = self.acp.return_value = Mock()

def runMain(self, fun=main):
def runMain(self, fun=main, stdin=''):
try:
stdout = StringIO()
with patch.dict(sys.__dict__, dict(stdout=stdout)):
stdin = StringIO(stdin)
with patch.dict(sys.__dict__, dict(stdout=stdout,
stdin=stdin)):
fun()

ret = 0
@@ -214,6 +432,11 @@ class TestClient(unittest.TestCase):
fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub'))
mock_file.assert_any_call(fname)

with patch('builtins.open') as mock_file:
mock_file.side_effect = IOError()
with self.assertRaises(IOError):
get_sshpubkey(None)

@patch.dict(sys.__dict__, dict(argv=[ '', 'list' ]))
def test_list_failure(self):
ac = self.ac


+ 1
- 1
bitelab/testing.py View File

@@ -32,4 +32,4 @@ from .snmp import TestSNMPPower, TestSNMPWrapper
from .data import TestDatabase
from . import TestBiteLab, TestUnhashLRU, TestAttrs, TestBoardImpl
from . import TestWebSocket, TestLogEvent
from .__main__ import TestClient
from .__main__ import TestClient, TestExecClient

+ 1
- 0
setup.py View File

@@ -25,6 +25,7 @@ setup(
'wsfwd @ git+https://www.funkthat.com/gitea/jmg/wsfwd.git',
'orm',
'ucl',
'aioconsole', # for aioconsole.stream only
'databases[sqlite]',
],
extras_require = {


Loading…
Cancel
Save