|
|
@@ -0,0 +1,183 @@ |
|
|
|
from typing import Optional |
|
|
|
from functools import lru_cache, wraps |
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, FastAPI, Request |
|
|
|
from fastapi.security import OAuth2PasswordBearer |
|
|
|
from httpx import AsyncClient, Auth |
|
|
|
from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED |
|
|
|
|
|
|
|
from . import config |
|
|
|
|
|
|
|
import asyncio |
|
|
|
import gc |
|
|
|
import socket |
|
|
|
import sys |
|
|
|
import unittest |
|
|
|
|
|
|
|
# fix up parse_socket_addr for hypercorn |
|
|
|
from hypercorn.utils import parse_socket_addr |
|
|
|
from hypercorn.asyncio import tcp_server |
|
|
|
def new_parse_socket_addr(domain, addr): |
|
|
|
if domain == socket.AF_UNIX: |
|
|
|
return (addr, -1) |
|
|
|
|
|
|
|
return parse_socket_addr(domain, addr) |
|
|
|
|
|
|
|
tcp_server.parse_socket_addr = new_parse_socket_addr |
|
|
|
|
|
|
|
class BoardManager(object): |
|
|
|
board_classes = [ 'cora-z7s' ] |
|
|
|
|
|
|
|
def __init__(self, settings): |
|
|
|
self._settings = settings |
|
|
|
|
|
|
|
def classes(self): |
|
|
|
return self.board_classes |
|
|
|
|
|
|
|
def unhashable_lru(): |
|
|
|
def newwrapper(fun): |
|
|
|
cache = {} |
|
|
|
|
|
|
|
@wraps(fun) |
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
idargs = tuple(id(x) for x in args) |
|
|
|
idkwargs = tuple(sorted((k, id(v)) for k, v in |
|
|
|
kwargs.items())) |
|
|
|
k = (idargs, idkwargs) |
|
|
|
if k in cache: |
|
|
|
realargs, realkwargs, res = cache[k] |
|
|
|
if all(x is y for x, y in zip(args, |
|
|
|
realargs)) and all(realkwargs[x] is |
|
|
|
kwargs[x] for x in realkwargs): |
|
|
|
return res |
|
|
|
|
|
|
|
res = fun(*args, **kwargs) |
|
|
|
cache[k] = (args, kwargs, res) |
|
|
|
|
|
|
|
return res |
|
|
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
return newwrapper |
|
|
|
|
|
|
|
class BiteAuth(Auth): |
|
|
|
def __init__(self, token): |
|
|
|
self.token = token |
|
|
|
|
|
|
|
def auth_flow(self, request): |
|
|
|
request.headers['Authorization'] = 'Bearer ' + self.token |
|
|
|
yield request |
|
|
|
|
|
|
|
@lru_cache() |
|
|
|
def get_settings(): |
|
|
|
return config.Settings() |
|
|
|
|
|
|
|
@unhashable_lru() |
|
|
|
def get_boardmanager(settings: config.Settings = Depends(get_settings)): |
|
|
|
return BoardManager(settings) |
|
|
|
|
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl='/nonexistent') |
|
|
|
|
|
|
|
def lookup_user(token: str = Depends(oauth2_scheme), settings: config.Settings = Depends(get_settings)): |
|
|
|
try: |
|
|
|
return settings.apikeytouser(token) |
|
|
|
except KeyError: |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail='Invalid authentication credentials', |
|
|
|
headers={'WWW-Authenticate': 'Bearer'}, |
|
|
|
) |
|
|
|
|
|
|
|
router = APIRouter() |
|
|
|
|
|
|
|
def board_priority(request: Request): |
|
|
|
# Get the board, if any, from the connection |
|
|
|
scope = request.scope |
|
|
|
return scope['server'] |
|
|
|
|
|
|
|
@router.get('/board_classes') |
|
|
|
async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): |
|
|
|
return brdmgr.classes() |
|
|
|
|
|
|
|
@router.get('/board_info') |
|
|
|
async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): |
|
|
|
return brdmgr.classes() |
|
|
|
|
|
|
|
@router.get('/') |
|
|
|
async def foo(board_prio: dict = Depends(board_priority), settings: config.Settings = Depends(get_settings)): |
|
|
|
return { 'foo': 'bar', 'board': board_prio } |
|
|
|
|
|
|
|
def getApp(): |
|
|
|
app = FastAPI() |
|
|
|
app.include_router(router) |
|
|
|
|
|
|
|
return app |
|
|
|
|
|
|
|
# uvicorn can't call the above function, while hypercorn can |
|
|
|
#app = getApp() |
|
|
|
|
|
|
|
class TestUnhashLRU(unittest.TestCase): |
|
|
|
def test_unhashlru(self): |
|
|
|
lsta = [] |
|
|
|
lstb = [] |
|
|
|
|
|
|
|
# that a wrapped function |
|
|
|
cachefun = unhashable_lru()(lambda x: object()) |
|
|
|
|
|
|
|
# handles unhashable objects |
|
|
|
resa = cachefun(lsta) |
|
|
|
resb = cachefun(lstb) |
|
|
|
|
|
|
|
# that they return the same object again |
|
|
|
self.assertIs(resa, cachefun(lsta)) |
|
|
|
self.assertIs(resb, cachefun(lstb)) |
|
|
|
|
|
|
|
# that the object returned is not the same |
|
|
|
self.assertIsNot(cachefun(lsta), cachefun(lstb)) |
|
|
|
|
|
|
|
# that a second wrapped funcion |
|
|
|
cachefun2 = unhashable_lru()(lambda x: object()) |
|
|
|
|
|
|
|
# does not return the same object as the first cache |
|
|
|
self.assertIsNot(cachefun(lsta), cachefun2(lsta)) |
|
|
|
|
|
|
|
# Note: this will not work under python before 3.8 before |
|
|
|
# IsolatedAsyncioTestCase was added. The tearDown has to happen |
|
|
|
# with the event loop running, otherwise the task and other things |
|
|
|
# do not get cleaned up properly. |
|
|
|
class TestBiteLab(unittest.IsolatedAsyncioTestCase): |
|
|
|
async def get_settings_override(self): |
|
|
|
# Note: this gets run on each request. |
|
|
|
return config.Settings(apikeyfile="fixtures/api_keys") |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
self.app = getApp() |
|
|
|
self.app.dependency_overrides[get_settings] = self.get_settings_override |
|
|
|
self.client = AsyncClient(app=self.app, base_url='http://testserver') |
|
|
|
|
|
|
|
def tearDown(self): |
|
|
|
self.app = None |
|
|
|
asyncio.run(self.client.aclose()) |
|
|
|
self.client = None |
|
|
|
|
|
|
|
async def test_config(self): |
|
|
|
settings = await self.get_settings_override() |
|
|
|
self.assertEqual(settings.apikeytouser('thisisanapikey'), 'foo') |
|
|
|
self.assertEqual(settings.apikeytouser('anotherlongapikey'), 'bar') |
|
|
|
|
|
|
|
async def test_basic(self): |
|
|
|
res = await self.client.get('/') |
|
|
|
self.assertNotEqual(res.status_code, HTTP_404_NOT_FOUND) |
|
|
|
self.assertEqual(res.json(), { 'foo': 'bar', 'board': [ 'testserver', None ] }) |
|
|
|
|
|
|
|
async def test_notauth(self): |
|
|
|
res = await self.client.get('/board_classes') |
|
|
|
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) |
|
|
|
|
|
|
|
res = await self.client.get('/board_info') |
|
|
|
self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) |
|
|
|
|
|
|
|
async def test_classes(self): |
|
|
|
res = await self.client.get('/board_classes', auth=BiteAuth('thisisanapikey')) |
|
|
|
self.assertEqual(res.status_code, HTTP_200_OK) |
|
|
|
self.assertEqual(res.json(), [ 'cora-z7s' ]) |