diff --git a/NOTES.md b/NOTES.md index d465a15..be4700a 100644 --- a/NOTES.md +++ b/NOTES.md @@ -2,3 +2,16 @@ Issues to address ================= If app crashes, won't sync w/ board status in database. + + +CONNECT proxy +============= + +RFC definition: https://tools.ietf.org/html/rfc2817#section-5.2 + +FastAPI StreamingResponse: https://fastapi.tiangolo.com/advanced/custom-response/#streamingresponse + +Streaming Requests: https://github.com/tiangolo/fastapi/issues/58 + +Brief tests showed that this may not work reliably. I think WebSockets is +the best answer for this. diff --git a/bitelab/__init__.py b/bitelab/__init__.py index dcac91c..8efb6ef 100644 --- a/bitelab/__init__.py +++ b/bitelab/__init__.py @@ -31,8 +31,10 @@ from dataclasses import dataclass from functools import lru_cache, wraps from io import StringIO -from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException, Request +from fastapi import APIRouter, Body, Depends, FastAPI, HTTPException +from fastapi import Path, Request from fastapi.security import OAuth2PasswordBearer +from fastapi.websockets import WebSocket from httpx import AsyncClient, Auth from starlette.responses import JSONResponse from starlette.status import HTTP_200_OK @@ -40,6 +42,11 @@ from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \ HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from unittest.mock import create_autospec, patch, AsyncMock, Mock, PropertyMock +from wsfwd import WSFWDServer, WSFWDClient, timeout, _tbprinter + +# For WebSocket testing +from hypercorn.config import Config +from hypercorn.asyncio import serve from . import config from .data import * @@ -53,6 +60,7 @@ import json import logging import orm import os +import shutil import socket import sqlite3 import subprocess @@ -62,6 +70,7 @@ import time import ucl import unittest import urllib +import websockets # fix up parse_socket_addr for hypercorn from hypercorn.utils import parse_socket_addr @@ -438,6 +447,111 @@ async def reserve_board(board_id_or_class, return brd +class HandleExec(WSFWDServer): + def __init__(self, *args, board_id, data, **kwargs): + super().__init__(*args, **kwargs) + + self._board_id = board_id + self._data = data + self._auth_user = None + self._did_exec = False + + self._finish_handler = asyncio.Event() + + async def handle_auth(self, msg): + try: + user = await lookup_user(msg['auth']['bearer'], + self._data) + except Exception: + raise RuntimeError('invalid token') + + self._auth_user = user + + async def shutdown(self): + pass + + async def process_stdin(self, data): + stdin = self._proc.stdin + stdin.write(data) + await stdin.drain() + + async def process_stdout(self): + stdout = self._proc.stdout + stream = self._stdout_stream + + try: + while True: + data = await stdout.read(16384) + if not data: + break + self.sendstream(stream, data) + await self.drain(stream) + finally: + await self.sendcmd(dict(cmd='chanclose', chan=stream)) + + async def process_proc_wait(self): + # Wait for process to exit + code = await self._proc.wait() + await self.sendcmd(dict(cmd='exit', code=code)) + + # Make sure that all stdout is sent + await self._stdout_task + + await self._stdin_event.wait() + + self._finish_handler.set() + + async def handle_chanclose(self, msg): + self.clear_stream_handler(self._stdin_stream) + self._proc.stdin.close() + await self._proc.stdin.wait_closed() + self._stdin_event.set() + + async def handle_exec(self, msg): + if self._did_exec: + raise RuntimeError('already did exec') + + if self._auth_user is None: + raise RuntimeError('not authenticated') + + self._proc = await asyncio.create_subprocess_exec( + 'jexec', self._board_id, *msg['args'], + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + + self._did_exec = True + + self._stdin_stream = msg['stdin'] + self._stdout_stream = msg['stdout'] + + # handle stdin + self._stdin_event = asyncio.Event() + self.add_stream_handler(msg['stdin'], self.process_stdin) + + # handle stdout + self._stdout_task = asyncio.create_task(self.process_stdout()) + + # handle process exit + self._proc_wait_task = asyncio.create_task(self.process_proc_wait()) + + async def get_finish_handler(self): + return await self._finish_handler.wait() + +@router.websocket("/board/{board_id}/exec") +async def board_exec_ws( + board_id, + websocket: WebSocket, + brdmgr: BoardManager = Depends(get_boardmanager), + settings: config.Settings = Depends(get_settings), + data: data.DataWrapper = Depends(get_data)): + await websocket.accept() + try: + async with HandleExec(websocket.receive_bytes, + websocket.send_bytes, data=data, + board_id=board_id) as server: + await server.get_finish_handler() + finally: + await websocket.close() + @router.post('/board/{board_id}/release', response_model=Union[Board, Error]) async def release_board(board_id, user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager), @@ -558,16 +672,7 @@ class TestUnhashLRU(unittest.TestCase): # does not return the same object as the first cache self.assertIsNot(cachefun(lsta), cachefun2(lsta)) -# Per RFC 5737 (https://tools.ietf.org/html/rfc5737): -# The blocks 192.0.2.0/24 (TEST-NET-1), 198.51.100.0/24 (TEST-NET-2), -# and 203.0.113.0/24 (TEST-NET-3) are provided for use in -# documentation. - -# Note: this will not work under python before 3.8 before -# IsolatedAsyncioTestCase was added. The tearDown has to happen -# with the event loop running, otherwise the task and other things -# do not get cleaned up properly. -class TestBiteLab(unittest.IsolatedAsyncioTestCase): +class TestCommon(unittest.IsolatedAsyncioTestCase): def get_settings_override(self): return self.settings @@ -599,14 +704,132 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): self.app.dependency_overrides[get_settings] = \ self.get_settings_override self.app.dependency_overrides[get_data] = self.get_data_override - self.app.dependency_overrides[get_boardmanager] = self.get_boardmanager_override + self.app.dependency_overrides[get_boardmanager] = \ + self.get_boardmanager_override + +# This is a different class then the other tests, as at the time of +# writing, there is no async WebSocket client that will talk directly +# to an ASGI server. The websockets client library can talk to a unix +# domain socket, so that is used. +class TestWebSocket(TestCommon): + async def asyncSetUp(self): + + await super().asyncSetUp() + + d = os.path.realpath(tempfile.mkdtemp()) + self.basetempdir = d + + self.shutdown_event = asyncio.Event() + + self.socketpath = os.path.join(self.basetempdir, 'wstest.sock') + + config = Config() + config.graceful_timeout = .01 + config.bind = [ 'unix:' + self.socketpath ] + config.loglevel = 'ERROR' + + self.serv_task = asyncio.create_task(serve(self.app, config, + shutdown_trigger=self.shutdown_event.wait)) + + # get the unix domain socket connected + # need a startup_trigger + await asyncio.sleep(.01) + + async def asyncTearDown(self): + self.app = None + + self.shutdown_event.set() + + await self.serv_task + + shutil.rmtree(self.basetempdir) + self.basetempdir = None + + @patch('asyncio.create_subprocess_exec') + @timeout(2) + async def test_exec_sshd(self, cse): + def wrapper(corofun): + async def foo(*args, **kwargs): + r = await corofun(*args, **kwargs) + #print('foo:', repr(corofun), repr((args, kwargs)), repr(r)) + return r + + return foo + + async with websockets.connect('ws://foo/board/cora-1/exec', + path=self.socketpath) as websocket, \ + WSFWDClient(wrapper(websocket.recv), wrapper(websocket.send)) as client: + mstdout = AsyncMock() + + cmdargs = [ 'sshd', '-i' ] + # that w/o auth, it fails + with self.assertRaises(RuntimeError): + await client.exec(cmdargs, stdin=1, stdout=2) + + # that and invalid token fails + with self.assertRaises(RuntimeError): + await client.auth(dict(bearer='invalidtoken')) + + # that a valid auth token works + await client.auth(dict(bearer='thisisanapikey')) + + # XXX - enforce board reservation and correct user + + echodata = b'somedata' + wrap_subprocess_exec(cse, stdout=echodata, retcode=0) + + client.add_stream_handler(2, mstdout) + proc = await client.exec([ 'sshd', '-i' ], stdin=1, stdout=2) + + with self.assertRaises(RuntimeError): + await client.exec([ 'sshd', '-i' ], stdin=1, stdout=2) + + stdin, stdout = proc.stdin, proc.stdout + + stdin.write(echodata) + await stdin.drain() + + # that we get our data + self.assertEqual(await stdout.read(len(echodata)), echodata) + + # and that there is no more + self.assertEqual(await stdout.read(len(echodata)), b'') + + # and we are truly at EOF + self.assertTrue(stdout.at_eof()) + + stdin.close() + await stdin.wait_closed() + + await proc.wait() + + cse.assert_called_with('jexec', 'cora-1', *cmdargs, + stdin=subprocess.PIPE, stdout=subprocess.PIPE) + + # spin things, not sure best way to handle this + await asyncio.sleep(.01) + + cse.return_value.stdin.close.assert_called_with() + +# Per RFC 5737 (https://tools.ietf.org/html/rfc5737): +# The blocks 192.0.2.0/24 (TEST-NET-1), 198.51.100.0/24 (TEST-NET-2), +# and 203.0.113.0/24 (TEST-NET-3) are provided for use in +# documentation. + +# Note: this will not work under python before 3.8 before +# IsolatedAsyncioTestCase was added. The tearDown has to happen +# with the event loop running, otherwise the task and other things +# do not get cleaned up properly. +class TestBiteLab(TestCommon): + async def asyncSetUp(self): + await super().asyncSetUp() self.client = AsyncClient(app=self.app, base_url='http://testserver') - def tearDown(self): + async def asyncTearDown(self): self.app = None - asyncio.run(self.client.aclose()) + await self.client.aclose() self.client = None async def test_basic(self): diff --git a/bitelab/testing.py b/bitelab/testing.py index 84cb45a..abfd15c 100644 --- a/bitelab/testing.py +++ b/bitelab/testing.py @@ -30,5 +30,6 @@ from .snmp import TestSNMPPower, TestSNMPWrapper from .data import TestDatabase -from . import TestBiteLab, TestUnhashLRU, TestAttrs, TestBoardImpl, TestLogEvent +from . import TestBiteLab, TestUnhashLRU, TestAttrs, TestBoardImpl +from . import TestWebSocket, TestLogEvent from .__main__ import TestClient diff --git a/setup.py b/setup.py index 967b30c..816cb25 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,14 @@ setup( 'httpx', 'hypercorn', # option, for server only? 'pydantic[dotenv]', - 'aiokq @ git+https://www.funkthat.com/gitea/jmg/aiokq.git', + 'wsfwd @ git+https://www.funkthat.com/gitea/jmg/wsfwd.git', 'orm', 'ucl', 'databases[sqlite]', ], extras_require = { - 'dev': [ 'coverage' ], + # requests needed for fastpi.testclient.TestClient + 'dev': [ 'coverage', 'requests' ], }, entry_points={ 'console_scripts': [