| @@ -1,9 +1,10 @@ | |||||
| from typing import Optional | |||||
| from typing import Optional, Dict, Any | |||||
| from functools import lru_cache, wraps | from functools import lru_cache, wraps | ||||
| from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request | from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request | ||||
| from fastapi.security import OAuth2PasswordBearer | from fastapi.security import OAuth2PasswordBearer | ||||
| from httpx import AsyncClient, Auth | from httpx import AsyncClient, Auth | ||||
| from pydantic import BaseModel | |||||
| from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED | from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND, HTTP_401_UNAUTHORIZED | ||||
| from . import config | from . import config | ||||
| @@ -28,9 +29,31 @@ def new_parse_socket_addr(domain, addr): | |||||
| tcp_server.parse_socket_addr = new_parse_socket_addr | tcp_server.parse_socket_addr = new_parse_socket_addr | ||||
| class BoardClassInfo(BaseModel): | |||||
| clsname: str | |||||
| arch: str | |||||
| class BoardImpl: | |||||
| def __init__(self, name, cls): | |||||
| self.name = name | |||||
| self.brdclass = cls | |||||
| @property | |||||
| def attrs(self): | |||||
| return {} | |||||
| class Board(BaseModel): | |||||
| name: str | |||||
| brdclass: str | |||||
| attrs: Dict[str, Any] | |||||
| class Config: | |||||
| orm_mode = True | |||||
| class BoardManager(object): | class BoardManager(object): | ||||
| board_class_info = { | board_class_info = { | ||||
| 'cora-z7s': { | 'cora-z7s': { | ||||
| 'clsname': 'cora-z7s', | |||||
| 'arch': 'arm64-aarch64', | 'arch': 'arm64-aarch64', | ||||
| }, | }, | ||||
| } | } | ||||
| @@ -39,9 +62,7 @@ class BoardManager(object): | |||||
| # <abbreviated class>-<num> | # <abbreviated class>-<num> | ||||
| # | # | ||||
| boards = { | boards = { | ||||
| 'cora-1': { | |||||
| 'class': 'cora-z7s', | |||||
| } | |||||
| 'cora-1': BoardImpl('cora-1', 'cora-z7s'), | |||||
| } | } | ||||
| def __init__(self, settings): | def __init__(self, settings): | ||||
| @@ -118,16 +139,19 @@ def board_priority(request: Request): | |||||
| scope = request.scope | scope = request.scope | ||||
| return scope['server'] | return scope['server'] | ||||
| @router.get('/board_classes') | |||||
| async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): | |||||
| @router.get('/board_classes', response_model=Dict[str, BoardClassInfo]) | |||||
| async def foo(user: str = Depends(lookup_user), | |||||
| brdmgr: BoardManager = Depends(get_boardmanager)): | |||||
| return brdmgr.classes() | return brdmgr.classes() | ||||
| @router.get('/board_info') | |||||
| async def foo(user: str = Depends(lookup_user), brdmgr: BoardManager = Depends(get_boardmanager)): | |||||
| @router.get('/board_info',response_model=Dict[str, Board]) | |||||
| async def foo(user: str = Depends(lookup_user), | |||||
| brdmgr: BoardManager = Depends(get_boardmanager)): | |||||
| return brdmgr.boards | return brdmgr.boards | ||||
| @router.get('/') | @router.get('/') | ||||
| async def foo(board_prio: dict = Depends(board_priority), settings: config.Settings = Depends(get_settings)): | |||||
| async def foo(board_prio: dict = Depends(board_priority), | |||||
| settings: config.Settings = Depends(get_settings)): | |||||
| return { 'foo': 'bar', 'board': board_prio } | return { 'foo': 'bar', 'board': board_prio } | ||||
| def getApp(): | def getApp(): | ||||
| @@ -184,7 +208,8 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||||
| # setup test database | # setup test database | ||||
| self.dbtempfile = tempfile.NamedTemporaryFile() | self.dbtempfile = tempfile.NamedTemporaryFile() | ||||
| self.database = data.databases.Database('sqlite:///' + self.dbtempfile.name) | |||||
| self.database = data.databases.Database('sqlite:///' + | |||||
| self.dbtempfile.name) | |||||
| self.data = data.make_orm(self.database) | self.data = data.make_orm(self.database) | ||||
| await _setup_data(self.data) | await _setup_data(self.data) | ||||
| @@ -192,7 +217,8 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||||
| # setup settings | # setup settings | ||||
| self.settings = config.Settings(db_file=self.dbtempfile.name) | self.settings = config.Settings(db_file=self.dbtempfile.name) | ||||
| self.app.dependency_overrides[get_settings] = self.get_settings_override | |||||
| self.app.dependency_overrides[get_settings] = \ | |||||
| self.get_settings_override | |||||
| self.app.dependency_overrides[get_data] = self.get_data_override | self.app.dependency_overrides[get_data] = self.get_data_override | ||||
| self.client = AsyncClient(app=self.app, base_url='http://testserver') | self.client = AsyncClient(app=self.app, base_url='http://testserver') | ||||
| @@ -205,29 +231,36 @@ class TestBiteLab(unittest.IsolatedAsyncioTestCase): | |||||
| async def test_basic(self): | async def test_basic(self): | ||||
| res = await self.client.get('/') | res = await self.client.get('/') | ||||
| self.assertNotEqual(res.status_code, HTTP_404_NOT_FOUND) | self.assertNotEqual(res.status_code, HTTP_404_NOT_FOUND) | ||||
| self.assertEqual(res.json(), { 'foo': 'bar', 'board': [ 'testserver', None ] }) | |||||
| async def test_notauth(self): | async def test_notauth(self): | ||||
| # test that simple accesses are denied | |||||
| res = await self.client.get('/board_classes') | res = await self.client.get('/board_classes') | ||||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | ||||
| res = await self.client.get('/board_info') | res = await self.client.get('/board_info') | ||||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | ||||
| res = await self.client.get('/board_classes', auth=BiteAuth('badapikey')) | |||||
| # test that invalid api keys are denied | |||||
| res = await self.client.get('/board_classes', | |||||
| auth=BiteAuth('badapikey')) | |||||
| self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | self.assertEqual(res.status_code, HTTP_401_UNAUTHORIZED) | ||||
| async def test_classes(self): | async def test_classes(self): | ||||
| res = await self.client.get('/board_classes', auth=BiteAuth('thisisanapikey')) | |||||
| res = await self.client.get('/board_classes', | |||||
| auth=BiteAuth('thisisanapikey')) | |||||
| self.assertEqual(res.status_code, HTTP_200_OK) | self.assertEqual(res.status_code, HTTP_200_OK) | ||||
| self.assertEqual(res.json(), { 'cora-z7s': { 'arch': 'arm64-aarch64', } }) | |||||
| self.assertEqual(res.json(), { 'cora-z7s': BoardClassInfo(**{ | |||||
| 'arch': 'arm64-aarch64', 'clsname': 'cora-z7s', }) }) | |||||
| async def test_board_info(self): | async def test_board_info(self): | ||||
| res = await self.client.get('/board_info', auth=BiteAuth('thisisanapikey')) | |||||
| res = await self.client.get('/board_info', | |||||
| auth=BiteAuth('thisisanapikey')) | |||||
| self.assertEqual(res.status_code, HTTP_200_OK) | self.assertEqual(res.status_code, HTTP_200_OK) | ||||
| info = { | info = { | ||||
| 'cora-1': { | 'cora-1': { | ||||
| 'class': 'cora-z7s', | |||||
| 'name': 'cora-1', | |||||
| 'brdclass': 'cora-z7s', | |||||
| 'attrs': {}, | |||||
| }, | }, | ||||
| } | } | ||||
| self.assertEqual(res.json(), info) | self.assertEqual(res.json(), info) | ||||
| @@ -237,7 +270,8 @@ class TestData(unittest.IsolatedAsyncioTestCase): | |||||
| # setup temporary directory | # setup temporary directory | ||||
| self.dbtempfile = tempfile.NamedTemporaryFile() | self.dbtempfile = tempfile.NamedTemporaryFile() | ||||
| self.database = data.databases.Database('sqlite:///' + self.dbtempfile.name) | |||||
| self.database = data.databases.Database('sqlite:///' + | |||||
| self.dbtempfile.name) | |||||
| self.data = data.make_orm(self.database) | self.data = data.make_orm(self.database) | ||||
| def tearDown(self): | def tearDown(self): | ||||
| @@ -249,5 +283,7 @@ class TestData(unittest.IsolatedAsyncioTestCase): | |||||
| data = self.data | data = self.data | ||||
| self.assertEqual(await data.APIKey.objects.all(), []) | self.assertEqual(await data.APIKey.objects.all(), []) | ||||
| await _setup_data(data) | await _setup_data(data) | ||||
| self.assertEqual((await data.APIKey.objects.get(key='thisisanapikey')).user, 'foo') | |||||
| self.assertEqual((await data.APIKey.objects.get(key='anotherlongapikey')).user, 'bar') | |||||
| self.assertEqual((await data.APIKey.objects.get( | |||||
| key='thisisanapikey')).user, 'foo') | |||||
| self.assertEqual((await data.APIKey.objects.get( | |||||
| key='anotherlongapikey')).user, 'bar') | |||||