From fd750987cd59768cc6eaaae54e699ced54331090 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Tue, 24 Nov 2020 18:00:26 -0800 Subject: [PATCH] have the CLI handle auth errors... enforce proper return values... properly close the httpx client... marking get_data no cover was bad, had a bug in this, need to get coverage.. --- bitelab/__init__.py | 90 ++++++++++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 25 deletions(-) diff --git a/bitelab/__init__.py b/bitelab/__init__.py index 9764128..0b585e8 100644 --- a/bitelab/__init__.py +++ b/bitelab/__init__.py @@ -238,10 +238,10 @@ def get_settings(): # pragma: no cover # how to get coverage for this? @unhashable_lru() -def get_data(settings: config.Settings = Depends(get_settings)): # pragma: no cover +def get_data(settings: config.Settings = Depends(get_settings)): #print(repr(settings)) database = data.databases.Database('sqlite:///' + settings.db_file) - d = make_orm(self.database) + d = make_orm(database) return d @unhashable_lru() @@ -411,23 +411,32 @@ async def real_main(): client = AsyncClient(base_url=baseurl) - if sys.argv[1] == 'list': - res = await client.get('board/classes', auth=BiteAuth(authkey)) - - print('Classes:') - for i in res.json(): - print('\t' + i) - elif sys.argv[1] == 'reserve': - res = await client.get('board/%s/reserve' % - urllib.parse.quote(sys.argv[2], safe=''), - auth=BiteAuth(authkey)) - - brd = Board.parse_obj(res.json()) - print('Name:\t%s' % brd.name) - print('Class:\t%s' % brd.brdclass) - print('Attributes:') - for i in brd.attrs: - print('\t%s\t%s' % (i, brd.attrs[i])) + try: + if sys.argv[1] == 'list': + res = await client.get('board/classes', auth=BiteAuth(authkey)) + + if res.status_code == HTTP_401_UNAUTHORIZED: + print('Invalid authentication credentials.') + sys.exit(1) + + print('Classes:') + for i in res.json(): + print('\t' + i) + + res.close() + elif sys.argv[1] == 'reserve': + res = await client.get('board/%s/reserve' % + urllib.parse.quote(sys.argv[2], safe=''), + auth=BiteAuth(authkey)) + + brd = Board.parse_obj(res.json()) + print('Name:\t%s' % brd.name) + print('Class:\t%s' % brd.brdclass) + print('Attributes:') + for i in brd.attrs: + print('\t%s\t%s' % (i, brd.attrs[i])) + finally: + await client.aclose() def main(): asyncio.run(real_main()) @@ -770,14 +779,43 @@ class TestClient(unittest.TestCase): self.addCleanup(self.ac_patcher.stop) self.acg = self.ac.return_value.get = AsyncMock() + self.acaclose = self.ac.return_value.aclose = AsyncMock() self.acgr = self.acg.return_value = Mock() def runMain(self): - stdout = StringIO() - with patch.dict(sys.__dict__, dict(stdout=stdout)): - main() + try: + stdout = StringIO() + with patch.dict(sys.__dict__, dict(stdout=stdout)): + main() + + ret = 0 + except SystemExit as e: + ret = e.code + + return ret, stdout.getvalue() + + @patch.dict(sys.__dict__, dict(argv=[ '', 'list' ])) + def test_list_failure(self): + ac = self.ac + acg = self.acg + acg.return_value.status_code = HTTP_401_UNAUTHORIZED + acg.return_value.json.return_value = { + 'detail': 'Invalid authentication credentials' + } - return stdout.getvalue() + ret, stdout = self.runMain() + + output = '''Invalid authentication credentials. +''' + + self.assertEqual(ret, 1) + self.assertEqual(stdout, output) + + ac.assert_called_with(base_url='http://someserver/') + + acg.assert_called_with('board/classes', auth=BiteAuth('thisisanapikey')) + + # XXX - add error cases for UI @patch.dict(sys.__dict__, dict(argv=[ '', 'list' ])) def test_list(self): @@ -787,12 +825,13 @@ class TestClient(unittest.TestCase): acg.return_value.json.return_value = { 'cora-z7s': { 'arch': 'arm-armv7', 'clsname': 'cora-z7s', }} - stdout = self.runMain() + ret, stdout = self.runMain() output = '''Classes: cora-z7s ''' + self.assertEqual(ret, 0) self.assertEqual(stdout, output) ac.assert_called_with(base_url='http://someserver/') @@ -813,7 +852,7 @@ class TestClient(unittest.TestCase): 'ip': '172.20.20.5', }).dict() - stdout = self.runMain() + ret, stdout = self.runMain() output = '''Name:\tcora-1 Class:\tcora-z7s @@ -822,6 +861,7 @@ Attributes: \tip\t172.20.20.5 ''' + self.assertEqual(ret, 0) self.assertEqual(stdout, output) ac.assert_called_with(base_url='http://someserver/')