diff --git a/bitelab/__main__.py b/bitelab/__main__.py index 827bfee..95a2343 100644 --- a/bitelab/__main__.py +++ b/bitelab/__main__.py @@ -31,7 +31,7 @@ from io import StringIO from starlette.status import HTTP_200_OK from starlette.status import HTTP_400_BAD_REQUEST, HTTP_401_UNAUTHORIZED, \ HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT -from unittest.mock import patch, AsyncMock, Mock +from unittest.mock import patch, AsyncMock, Mock, mock_open from . import BiteAuth, Board @@ -82,7 +82,16 @@ def output_board(brd): print('\t%s\t%s' % (i, repr(brd.attrs[i]))) def get_sshpubkey(fname): - raise OSError + if fname is not None: + with open(fname) as fp: + return fp.read() + + for i in ('id_ed25519', 'id_rsa'): + fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub')) + with contextlib.suppress(IOError), open(fname) as fp: + return fp.read() + + raise IOError async def real_main(): baseurl = os.environ['BITELAB_URL'] @@ -122,7 +131,7 @@ async def real_main(): res.close() elif args.subparser_name in ('reserve', 'release'): kwargs = _httpxargs.copy() - with contextlib.suppress(OSError): + with contextlib.suppress(IOError): kwargs['json'] = dict(sshpubkey=get_sshpubkey(args.i)) res = await client.post('board/%s/%s' % (urllib.parse.quote(args.board, safe=''), @@ -182,6 +191,21 @@ class TestClient(unittest.TestCase): return ret, stdout.getvalue() + def test_sshpubkey(self): + fname = 'fname' + rdata = 'foo' + with patch('builtins.open', mock_open(read_data=rdata)) as mock_file: + self.assertEqual(get_sshpubkey(fname), rdata) + mock_file.assert_called_with(fname) + + with patch('builtins.open') as mock_file: + mock_file.side_effect = [ IOError(), mock_open(read_data=rdata)() ] + self.assertEqual(get_sshpubkey(None), rdata) + + for i in ('id_ed25519', 'id_rsa'): + fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub')) + mock_file.assert_any_call(fname) + @patch.dict(sys.__dict__, dict(argv=[ '', 'list' ])) def test_list_failure(self): ac = self.ac @@ -232,7 +256,7 @@ class TestClient(unittest.TestCase): @patch('bitelab.__main__.get_sshpubkey') @patch.dict(sys.__dict__, dict(argv=[ '', 'reserve', 'cora-z7s' ])) def test_reserve(self, gspk): - gspk.side_effect = OSError() + gspk.side_effect = IOError() ac = self.ac acp = self.acp acp.return_value.status_code = HTTP_200_OK @@ -298,8 +322,10 @@ Attributes: json=dict(sshpubkey=keydata), auth=BiteAuth('thisisanapikey'), **_httpxargs) + @patch('bitelab.__main__.get_sshpubkey') @patch.dict(sys.__dict__, dict(argv=[ '', 'release', 'cora-z7s' ])) - def test_release(self): + def test_release(self, gspk): + gspk.side_effect = IOError ac = self.ac acp = self.acp acp.return_value.status_code = HTTP_200_OK