@@ -31,8 +31,10 @@ from dataclasses import dataclass
from functools import lru_cache, wraps
from functools import lru_cache, wraps
from io import StringIO
from io import StringIO
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request
from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException
from fastapi import Path, Request
from fastapi.security import OAuth2PasswordBearer
from fastapi.security import OAuth2PasswordBearer
from fastapi.websockets import WebSocket
from httpx import AsyncClient, Auth
from httpx import AsyncClient, Auth
from starlette.responses import JSONResponse
from starlette.responses import JSONResponse
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_200_OK
@@ -40,6 +42,11 @@ from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \
HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from unittest.mock import create_autospec, patch, AsyncMock, Mock, PropertyMock
from unittest.mock import create_autospec, patch, AsyncMock, Mock, PropertyMock
from wsfwd import WSFWDServer, WSFWDClient, timeout, _tbprinter
# For WebSocket testing
from hypercorn.config import Config
from hypercorn.asyncio import serve
from . import config
from . import config
from .data import *
from .data import *
@@ -53,6 +60,7 @@ import json
import logging
import logging
import orm
import orm
import os
import os
import shutil
import socket
import socket
import sqlite3
import sqlite3
import subprocess
import subprocess
@@ -62,6 +70,7 @@ import time
import ucl
import ucl
import unittest
import unittest
import urllib
import urllib
import websockets
# fix up parse_socket_addr for hypercorn
# fix up parse_socket_addr for hypercorn
from hypercorn.utils import parse_socket_addr
from hypercorn.utils import parse_socket_addr
@@ -438,6 +447,111 @@ async def reserve_board(board_id_or_class,
return brd
return brd
class HandleExec(WSFWDServer):
def __init__(self, *args, board_id, data, **kwargs):
super().__init__(*args, **kwargs)
self._board_id = board_id
self._data = data
self._auth_user = None
self._did_exec = False
self._finish_handler = asyncio.Event()
async def handle_auth(self, msg):
try:
user = await lookup_user(msg['auth']['bearer'],
self._data)
except Exception:
raise RuntimeError('invalid token')
self._auth_user = user
async def shutdown(self):
pass
async def process_stdin(self, data):
stdin = self._proc.stdin
stdin.write(data)
await stdin.drain()
async def process_stdout(self):
stdout = self._proc.stdout
stream = self._stdout_stream
try:
while True:
data = await stdout.read(16384)
if not data:
break
self.sendstream(stream, data)
await self.drain(stream)
finally:
await self.sendcmd(dict(cmd='chanclose', chan=stream))
async def process_proc_wait(self):
# Wait for process to exit
code = await self._proc.wait()
await self.sendcmd(dict(cmd='exit', code=code))
# Make sure that all stdout is sent
await self._stdout_task
await self._stdin_event.wait()
self._finish_handler.set()
async def handle_chanclose(self, msg):
self.clear_stream_handler(self._stdin_stream)
self._proc.stdin.close()
await self._proc.stdin.wait_closed()
self._stdin_event.set()
async def handle_exec(self, msg):
if self._did_exec:
raise RuntimeError('already did exec')
if self._auth_user is None:
raise RuntimeError('not authenticated')
self._proc = await asyncio.create_subprocess_exec(
'jexec', self._board_id, *msg['args'],
stdin=subprocess.PIPE, stdout=subprocess.PIPE)
self._did_exec = True
self._stdin_stream = msg['stdin']
self._stdout_stream = msg['stdout']
# handle stdin
self._stdin_event = asyncio.Event()
self.add_stream_handler(msg['stdin'], self.process_stdin)
# handle stdout
self._stdout_task = asyncio.create_task(self.process_stdout())
# handle process exit
self._proc_wait_task = asyncio.create_task(self.process_proc_wait())
async def get_finish_handler(self):
return await self._finish_handler.wait()
@router.websocket("/board/{board_id}/exec")
async def board_exec_ws(
board_id,
websocket: WebSocket,
brdmgr: BoardManager = Depends(get_boardmanager),
settings: config.Settings = Depends(get_settings),
data: data.DataWrapper = Depends(get_data)):
await websocket.accept()
try:
async with HandleExec(websocket.receive_bytes,
websocket.send_bytes, data=data,
board_id=board_id) as server:
await server.get_finish_handler()
finally:
await websocket.close()
@router.post('/board/{board_id}/release', response_model=Union[Board, Error])
@router.post('/board/{board_id}/release', response_model=Union[Board, Error])
async def release_board(board_id, user: str = Depends(lookup_user),
async def release_board(board_id, user: str = Depends(lookup_user),
brdmgr: BoardManager = Depends(get_boardmanager),
brdmgr: BoardManager = Depends(get_boardmanager),
@@ -558,16 +672,7 @@ class TestUnhashLRU(unittest.TestCase):
# does not return the same object as the first cache
# does not return the same object as the first cache
self.assertIsNot(cachefun(lsta), cachefun2(lsta))
self.assertIsNot(cachefun(lsta), cachefun2(lsta))
# Per RFC 5737 (https://tools.ietf.org/html/rfc5737):
# The blocks 192.0.2.0/24 (TEST-NET-1), 198.51.100.0/24 (TEST-NET-2),
# and 203.0.113.0/24 (TEST-NET-3) are provided for use in
# documentation.
# Note: this will not work under python before 3.8 before
# IsolatedAsyncioTestCase was added. The tearDown has to happen
# with the event loop running, otherwise the task and other things
# do not get cleaned up properly.
class TestBiteLab(unittest.IsolatedAsyncioTestCase):
class TestCommon(unittest.IsolatedAsyncioTestCase):
def get_settings_override(self):
def get_settings_override(self):
return self.settings
return self.settings
@@ -599,14 +704,132 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase):
self.app.dependency_overrides[get_settings] = \
self.app.dependency_overrides[get_settings] = \
self.get_settings_override
self.get_settings_override
self.app.dependency_overrides[get_data] = self.get_data_override
self.app.dependency_overrides[get_data] = self.get_data_override
self.app.dependency_overrides[get_boardmanager] = self.get_boardmanager_override
self.app.dependency_overrides[get_boardmanager] = \
self.get_boardmanager_override
# This is a different class then the other tests, as at the time of
# writing, there is no async WebSocket client that will talk directly
# to an ASGI server. The websockets client library can talk to a unix
# domain socket, so that is used.
class TestWebSocket(TestCommon):
async def asyncSetUp(self):
await super().asyncSetUp()
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d
self.shutdown_event = asyncio.Event()
self.socketpath = os.path.join(self.basetempdir, 'wstest.sock')
config = Config()
config.graceful_timeout = .01
config.bind = [ 'unix:' + self.socketpath ]
config.loglevel = 'ERROR'
self.serv_task = asyncio.create_task(serve(self.app, config,
shutdown_trigger=self.shutdown_event.wait))
# get the unix domain socket connected
# need a startup_trigger
await asyncio.sleep(.01)
async def asyncTearDown(self):
self.app = None
self.shutdown_event.set()
await self.serv_task
shutil.rmtree(self.basetempdir)
self.basetempdir = None
@patch('asyncio.create_subprocess_exec')
@timeout(2)
async def test_exec_sshd(self, cse):
def wrapper(corofun):
async def foo(*args, **kwargs):
r = await corofun(*args, **kwargs)
#print('foo:', repr(corofun), repr((args, kwargs)), repr(r))
return r
return foo
async with websockets.connect('ws://foo/board/cora-1/exec',
path=self.socketpath) as websocket, \
WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client:
mstdout = AsyncMock()
cmdargs = [ 'sshd', '-i' ]
# that w/o auth, it fails
with self.assertRaises(RuntimeError):
await client.exec(cmdargs, stdin=1, stdout=2)
# that and invalid token fails
with self.assertRaises(RuntimeError):
await client.auth(dict(bearer='invalidtoken'))
# that a valid auth token works
await client.auth(dict(bearer='thisisanapikey'))
# XXX - enforce board reservation and correct user
echodata = b'somedata'
wrap_subprocess_exec(cse, stdout=echodata, retcode=0)
client.add_stream_handler(2, mstdout)
proc = await client.exec([ 'sshd', '-i' ], stdin=1, stdout=2)
with self.assertRaises(RuntimeError):
await client.exec([ 'sshd', '-i' ], stdin=1, stdout=2)
stdin, stdout = proc.stdin, proc.stdout
stdin.write(echodata)
await stdin.drain()
# that we get our data
self.assertEqual(await stdout.read(len(echodata)), echodata)
# and that there is no more
self.assertEqual(await stdout.read(len(echodata)), b'')
# and we are truly at EOF
self.assertTrue(stdout.at_eof())
stdin.close()
await stdin.wait_closed()
await proc.wait()
cse.assert_called_with('jexec', 'cora-1', *cmdargs,
stdin=subprocess.PIPE, stdout=subprocess.PIPE)
# spin things, not sure best way to handle this
await asyncio.sleep(.01)
cse.return_value.stdin.close.assert_called_with()
# Per RFC 5737 (https://tools.ietf.org/html/rfc5737):
# The blocks 192.0.2.0/24 (TEST-NET-1), 198.51.100.0/24 (TEST-NET-2),
# and 203.0.113.0/24 (TEST-NET-3) are provided for use in
# documentation.
# Note: this will not work under python before 3.8 before
# IsolatedAsyncioTestCase was added. The tearDown has to happen
# with the event loop running, otherwise the task and other things
# do not get cleaned up properly.
class TestBiteLab(TestCommon):
async def asyncSetUp(self):
await super().asyncSetUp()
self.client = AsyncClient(app=self.app,
self.client = AsyncClient(app=self.app,
base_url='http://testserver')
base_url='http://testserver')
def tearDown(self):
async def asyncT earDown(self):
self.app = None
self.app = None
asyncio.run(self.client.aclose())
await self.client.aclose( )
self.client = None
self.client = None
async def test_basic(self):
async def test_basic(self):