| @@ -1,9 +1,10 @@ | |||
| from typing import Optional | |||
| from typing import Optional, Dict, Any | |||
| from functools import lru_cache, wraps | |||
| from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request | |||
| from fastapi.security import OAuth2PasswordBearer | |||
| from httpx import AsyncClient, Auth | |||
| from pydantic import BaseModel | |||
| from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED | |||
| from . import config | |||
| @@ -28,9 +29,31 @@ def new_parse_socket_addr(domain, addr): | |||
| tcp_server.parse_socket_addr = new_parse_socket_addr | |||
| class BoardClassInfo(BaseModel): | |||
| clsname: str | |||
| arch: str | |||
| class BoardImpl: | |||
| def __init__(self, name, cls): | |||
| self.name = name | |||
| self.brdclass = cls | |||
| @property | |||
| def attrs(self): | |||
| return {} | |||
| class Board(BaseModel): | |||
| name: str | |||
| brdclass: str | |||
| attrs: Dict[str, Any] | |||
| class Config: | |||
| orm_mode = True | |||
| class BoardManager(object): | |||
| board_class_info = { | |||
| 'cora-z7s': { | |||
| 'clsname': 'cora-z7s', | |||
| 'arch': 'arm64-aarch64', | |||
| }, | |||
| } | |||
| @@ -39,9 +62,7 @@ class BoardManager(object): | |||
| # <abbreviated class>-<num> | |||
| # | |||
| boards = { | |||
| 'cora-1': { | |||
| 'class': 'cora-z7s', | |||
| } | |||
| 'cora-1': BoardImpl('cora-1', 'cora-z7s'), | |||
| } | |||
| def __init__(self, settings): | |||
| @@ -118,16 +139,19 @@ def board_priority(request: Request): | |||
| scope = request.scope | |||
| return scope['server'] | |||
| @router.get('/board_classes') | |||
| async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): | |||
| @router.get('/board_classes', response_model=Dict[str, BoardClassInfo]) | |||
| async def foo(user: str = Depends(lookup_user), | |||
| brdmgr: BoardManager = Depends(get_boardmanager)): | |||
| return brdmgr.classes() | |||
| @router.get('/board_info') | |||
| async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): | |||
| @router.get('/board_info',response_model=Dict[str, Board]) | |||
| async def foo(user: str = Depends(lookup_user), | |||
| brdmgr: BoardManager = Depends(get_boardmanager)): | |||
| return brdmgr.boards | |||
| @router.get('/') | |||
| async def foo(board_prio: dict = Depends(board_priority), settings: config.Settings = Depends(get_settings)): | |||
| async def foo(board_prio: dict = Depends(board_priority), | |||
| settings: config.Settings = Depends(get_settings)): | |||
| return { 'foo': 'bar', 'board': board_prio } | |||
| def getApp(): | |||
| @@ -184,7 +208,8 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||
| # setup test database | |||
| self.dbtempfile = tempfile.NamedTemporaryFile() | |||
| self.database = data.databases.Database('sqlite:///' + self.dbtempfile.name) | |||
| self.database = data.databases.Database('sqlite:///' + | |||
| self.dbtempfile.name) | |||
| self.data = data.make_orm(self.database) | |||
| await _setup_data(self.data) | |||
| @@ -192,7 +217,8 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||
| # setup settings | |||
| self.settings = config.Settings(db_file=self.dbtempfile.name) | |||
| self.app.dependency_overrides[get_settings] = self.get_settings_override | |||
| self.app.dependency_overrides[get_settings] = \ | |||
| self.get_settings_override | |||
| self.app.dependency_overrides[get_data] = self.get_data_override | |||
| self.client = AsyncClient(app=self.app, base_url='http://testserver') | |||
| @@ -205,29 +231,36 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||
| async def test_basic(self): | |||
| res = await self.client.get('/') | |||
| self.assertNotEqual(res.status_code, HTTP_404_NOT_FOUND) | |||
| self.assertEqual(res.json(), { 'foo': 'bar', 'board': [ 'testserver', None ] }) | |||
| async def test_notauth(self): | |||
| # test that simple accesses are denied | |||
| res = await self.client.get('/board_classes') | |||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | |||
| res = await self.client.get('/board_info') | |||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | |||
| res = await self.client.get('/board_classes', auth=BiteAuth('badapikey')) | |||
| # test that invalid api keys are denied | |||
| res = await self.client.get('/board_classes', | |||
| auth=BiteAuth('badapikey')) | |||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | |||
| async def test_classes(self): | |||
| res = await self.client.get('/board_classes', auth=BiteAuth('thisisanapikey')) | |||
| res = await self.client.get('/board_classes', | |||
| auth=BiteAuth('thisisanapikey')) | |||
| self.assertEqual(res.status_code, HTTP_200_OK) | |||
| self.assertEqual(res.json(), { 'cora-z7s': { 'arch': 'arm64-aarch64', } }) | |||
| self.assertEqual(res.json(), { 'cora-z7s': BoardClassInfo(**{ | |||
| 'arch': 'arm64-aarch64', 'clsname': 'cora-z7s', }) }) | |||
| async def test_board_info(self): | |||
| res = await self.client.get('/board_info', auth=BiteAuth('thisisanapikey')) | |||
| res = await self.client.get('/board_info', | |||
| auth=BiteAuth('thisisanapikey')) | |||
| self.assertEqual(res.status_code, HTTP_200_OK) | |||
| info = { | |||
| 'cora-1': { | |||
| 'class': 'cora-z7s', | |||
| 'name': 'cora-1', | |||
| 'brdclass': 'cora-z7s', | |||
| 'attrs': {}, | |||
| }, | |||
| } | |||
| self.assertEqual(res.json(), info) | |||
| @@ -237,7 +270,8 @@ class TestData(unittest.IsolatedAsyncioTestCase): | |||
| # setup temporary directory | |||
| self.dbtempfile = tempfile.NamedTemporaryFile() | |||
| self.database = data.databases.Database('sqlite:///' + self.dbtempfile.name) | |||
| self.database = data.databases.Database('sqlite:///' + | |||
| self.dbtempfile.name) | |||
| self.data = data.make_orm(self.database) | |||
| def tearDown(self): | |||
| @@ -249,5 +283,7 @@ class TestData(unittest.IsolatedAsyncioTestCase): | |||
| data = self.data | |||
| self.assertEqual(await data.APIKey.objects.all(), []) | |||
| await _setup_data(data) | |||
| self.assertEqual((await data.APIKey.objects.get(key='thisisanapikey')).user, 'foo') | |||
| self.assertEqual((await data.APIKey.objects.get(key='anotherlongapikey')).user, 'bar') | |||
| self.assertEqual((await data.APIKey.objects.get( | |||
| key='thisisanapikey')).user, 'foo') | |||
| self.assertEqual((await data.APIKey.objects.get( | |||
| key='anotherlongapikey')).user, 'bar') | |||