Browse Source

wrap lines at 80 chars..

really need to get code formatting added to the pipeline..
main
John-Mark Gurney 3 years ago
parent
commit
fe76b5a9b9
1 changed files with 88 additions and 47 deletions
  1. +88
    -47
      bitelab/__main__.py

+ 88
- 47
bitelab/__main__.py View File

@@ -55,10 +55,12 @@ def check_res_code(res):
sys.exit(1) sys.exit(1)
elif res.status_code != HTTP_200_OK: elif res.status_code != HTTP_200_OK:
try: try:
print('Got status: %d, json: %s' % (res.status_code, res.json()))
print('Got status: %d, json: %s' % (res.status_code,
res.json()))
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
# body is JSON # body is JSON
print('Got status: %d, body: %s' % (res.status_code, repr(res.text)))
print('Got status: %d, body: %s' % (res.status_code,
repr(res.text)))
sys.exit(1) sys.exit(1)


def makebool(s): def makebool(s):
@@ -125,18 +127,22 @@ async def fwd_data(reader, writer):
await writer.drain() await writer.drain()


async def run_exec(baseurl, authkey, board, args): async def run_exec(baseurl, authkey, board, args):
url = urllib.parse.urljoin(baseurl, 'board/%s/exec' % urllib.parse.quote(board, safe=''))
url = urllib.parse.urljoin(baseurl, 'board/%s/exec' %
urllib.parse.quote(board, safe=''))
url = convert_to_ws(url) url = convert_to_ws(url)
stdin, stdout = await aioconsole.stream.get_standard_streams() stdin, stdout = await aioconsole.stream.get_standard_streams()


async with websockets.connect(url) as ws, wsfwd.WSFWDClient(ws.recv, ws.send) as client:
async with websockets.connect(url) as ws, wsfwd.WSFWDClient(ws.recv,
ws.send) as client:
try: try:
await client.auth(dict(bearer=authkey)) await client.auth(dict(bearer=authkey))


proc = await client.exec(args=args) proc = await client.exec(args=args)


toexec_task = asyncio.create_task(fwd_data(stdin, proc.stdin))
fromexec_task = asyncio.create_task(fwd_data(proc.stdout, stdout))
toexec_task = asyncio.create_task(fwd_data(stdin,
proc.stdin))
fromexec_task = asyncio.create_task(fwd_data(
proc.stdout, stdout))


r = await proc.wait() r = await proc.wait()


@@ -159,20 +165,30 @@ async def real_main():
dest='subparser_name', dest='subparser_name',
description='valid subcommands', help='additional help') description='valid subcommands', help='additional help')


parse_list = subparsers.add_parser('list', help='list available board classes')

parser_reserve = subparsers.add_parser('reserve', aliases=[ 'release' ], help='reserve/release a board')
parser_reserve.add_argument('-i', metavar='identity_file', type=str, help='file name for ssh public key')
parser_reserve.add_argument('board', type=str, help='name of the board or class')

parser_set = subparsers.add_parser('set', help='set attributes on a board')
parser_set.add_argument('setvars', type=str, nargs='+', help='name of the board or class')
parser_set.add_argument('board', type=str, help='name of the board or class')

parser_exec = subparsers.add_parser('exec', help='run a program in the jail for a board')
parser_exec.add_argument('board', type=str, help='name of the board or class')
parse_list = subparsers.add_parser('list',
help='list available board classes')

parser_reserve = subparsers.add_parser('reserve', aliases=[ 'release' ],
help='reserve/release a board')
parser_reserve.add_argument('-i', metavar='identity_file', type=str,
help='file name for ssh public key')
parser_reserve.add_argument('board', type=str,
help='name of the board or class')

parser_set = subparsers.add_parser('set',
help='set attributes on a board')
parser_set.add_argument('setvars', type=str, nargs='+',
help='name of the board or class')
parser_set.add_argument('board', type=str,
help='name of the board or class')

parser_exec = subparsers.add_parser('exec',
help='run a program in the jail for a board')
parser_exec.add_argument('board', type=str,
help='name of the board or class')
parser_exec.add_argument('prog', type=str, help='program to exec') parser_exec.add_argument('prog', type=str, help='program to exec')
parser_exec.add_argument('args', type=str, nargs='*', help='arguments for program')
parser_exec.add_argument('args', type=str, nargs='*',
help='arguments for program')


args = parser.parse_args() args = parser.parse_args()
#print(repr(args), file=sys.__stderr__) #print(repr(args), file=sys.__stderr__)
@@ -202,7 +218,8 @@ async def real_main():
elif args.subparser_name in ('reserve', 'release'): elif args.subparser_name in ('reserve', 'release'):
kwargs = _httpxargs.copy() kwargs = _httpxargs.copy()
with contextlib.suppress(IOError): with contextlib.suppress(IOError):
kwargs['json'] = dict(sshpubkey=get_sshpubkey(args.i))
kwargs['json'] = dict(sshpubkey=get_sshpubkey(
args.i))
res = await client.post('board/%s/%s' % res = await client.post('board/%s/%s' %
(urllib.parse.quote(args.board, safe=''), (urllib.parse.quote(args.board, safe=''),
args.subparser_name), args.subparser_name),
@@ -260,7 +277,8 @@ class TestExecClient(unittest.IsolatedAsyncioTestCase):
@contextlib.contextmanager @contextlib.contextmanager
def make_pipe(self): def make_pipe(self):
r, w = os.pipe() r, w = os.pipe()
with os.fdopen(r, 'rb', buffering=65536) as readfl, os.fdopen(w, 'wb', buffering=65536) as writefl:
with os.fdopen(r, 'rb', buffering=65536) as readfl, \
os.fdopen(w, 'wb', buffering=65536) as writefl:
yield readfl, writefl yield readfl, writefl


# too lazy to make this async since async file-like objects # too lazy to make this async since async file-like objects
@@ -284,23 +302,30 @@ class TestExecClient(unittest.IsolatedAsyncioTestCase):
stdout = io.BytesIO() stdout = io.BytesIO()


# Data path: # Data path:
# stdin -> stdin_task -> stdinwriter -> pipe -> stdinreader -> sys.stdin -> real_main ->
# sys.stdout -> stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
# stdin -> stdin_task -> stdinwriter -> pipe ->
# stdinreader -> sys.stdin -> real_main -> sys.stdout ->
# stdoutwriter -> pipe -> stdoutreader -> stdoud_task -> stdout
# #
# How things get closed: # How things get closed:
# stdin is already "closed", stdin_task will close stdinwriter when eof encountered (doclose).
# stdin is already "closed", stdin_task will close
# stdinwriter when eof encountered (doclose).


# create the pipes needed # create the pipes needed
with self.make_pipe() as (stdinreader, stdinwriter), self.make_pipe() as (stdoutreader, stdoutwriter):
with self.make_pipe() as (stdinreader, stdinwriter), \
self.make_pipe() as (stdoutreader, stdoutwriter):
# setup the threads to move data # setup the threads to move data
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
stdin_task = loop.run_in_executor(None, self.copytask, io.BytesIO(stdin), stdinwriter)
stdin_task = loop.run_in_executor(None, self.copytask,
io.BytesIO(stdin), stdinwriter)


# do not close stdout, otherwise we cannot obtain the value
stdout_task = loop.run_in_executor(None, self.copytask, stdoutreader, stdout, False)
# do not close stdout, otherwise we cannot obtain
# the value
stdout_task = loop.run_in_executor(None, self.copytask,
stdoutreader, stdout, False)


# insert the pipes # insert the pipes
with patch.dict(sys.__dict__, dict(stdin=TextIOWrapper(stdinreader),
with patch.dict(sys.__dict__,
dict(stdin=TextIOWrapper(stdinreader),
stdout=TextIOWrapper(stdoutwriter))): stdout=TextIOWrapper(stdoutwriter))):
try: try:
# run the function # run the function
@@ -341,18 +366,21 @@ class TestExecClient(unittest.IsolatedAsyncioTestCase):
raise RuntimeError('badauth') raise RuntimeError('badauth')


server = TestServer(self.toserver.get, self.toclient.put) server = TestServer(self.toserver.get, self.toclient.put)
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec',
'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon: patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon) self.setup_websockets_mock(webcon)


ret, stdout = await self.runAsyncMain() ret, stdout = await self.runAsyncMain()


webcon.assert_called_with(urllib.parse.urljoin('ws://someserver/', 'board/cora-z7s/exec'))
webcon.assert_called_with(urllib.parse.urljoin(
'ws://someserver/', 'board/cora-z7s/exec'))


await server.__aexit__(None, None, None) await server.__aexit__(None, None, None)


await asyncio.sleep(.1) await asyncio.sleep(.1)
self.assertEqual(stdout.decode(), 'failed to exec: Got auth error: \'badauth\'\n')
self.assertEqual(stdout.decode(),
'failed to exec: Got auth error: \'badauth\'\n')


self.assertEqual(ret, 1) self.assertEqual(ret, 1)


@@ -368,11 +396,13 @@ class TestExecClient(unittest.IsolatedAsyncioTestCase):


async def handle_chanclose(self, msg): async def handle_chanclose(self, msg):
self.add_tasks(asyncio.create_task(self.sendcmd( self.add_tasks(asyncio.create_task(self.sendcmd(
dict(cmd='chanclose', chan=self._stdout_stream))))
dict(cmd='chanclose',
chan=self._stdout_stream))))


async def handle_exec(self, msg): async def handle_exec(self, msg):
self._stdout_stream = msg['stdout'] self._stdout_stream = msg['stdout']
assert msg['args'] == [ 'program', 'arg1', 'arg2' ]
assert msg['args'] == [ 'program', 'arg1',
'arg2' ]
self.add_stream_handler(msg['stdin'], self.add_stream_handler(msg['stdin'],
functools.partial(self.echo_handler, functools.partial(self.echo_handler,
msg['stdout'])) msg['stdout']))
@@ -383,7 +413,8 @@ class TestExecClient(unittest.IsolatedAsyncioTestCase):


server = TestServer(self.toserver.get, self.toclient.put) server = TestServer(self.toserver.get, self.toclient.put)


with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec', 'cora-z7s', 'program', 'arg1', 'arg2' ])), \
with patch.dict(sys.__dict__, dict(argv=[ 'rand', 'exec',
'cora-z7s', 'program', 'arg1', 'arg2' ])), \
patch('websockets.connect') as webcon: patch('websockets.connect') as webcon:
self.setup_websockets_mock(webcon) self.setup_websockets_mock(webcon)


@@ -435,21 +466,25 @@ class TestClient(unittest.TestCase):
] ]


for orig, new in testpairs: for orig, new in testpairs:
self.assertEqual(convert_to_ws(orig), new, 'failed to convert: %s' % repr(orig))
self.assertEqual(convert_to_ws(orig), new,
'failed to convert: %s' % repr(orig))


def test_sshpubkey(self): def test_sshpubkey(self):
fname = 'fname' fname = 'fname'
rdata = 'foo' rdata = 'foo'
with patch('builtins.open', mock_open(read_data=rdata)) as mock_file:
with patch('builtins.open', mock_open(read_data=rdata)) as \
mock_file:
self.assertEqual(get_sshpubkey(fname), rdata) self.assertEqual(get_sshpubkey(fname), rdata)
mock_file.assert_called_with(fname) mock_file.assert_called_with(fname)


with patch('builtins.open') as mock_file: with patch('builtins.open') as mock_file:
mock_file.side_effect = [ IOError(), mock_open(read_data=rdata)() ]
mock_file.side_effect = [ IOError(),
mock_open(read_data=rdata)() ]
self.assertEqual(get_sshpubkey(None), rdata) self.assertEqual(get_sshpubkey(None), rdata)


for i in ('id_ed25519', 'id_rsa'): for i in ('id_ed25519', 'id_rsa'):
fname = os.path.expanduser(os.path.join('~', '.ssh', i + '.pub'))
fname = os.path.expanduser(os.path.join('~',
'.ssh', i + '.pub'))
mock_file.assert_any_call(fname) mock_file.assert_any_call(fname)


with patch('builtins.open') as mock_file: with patch('builtins.open') as mock_file:
@@ -477,7 +512,8 @@ class TestClient(unittest.TestCase):


ac.assert_called_with(base_url='http://someserver/') ac.assert_called_with(base_url='http://someserver/')


acg.assert_called_with('board/classes', auth=BiteAuth('thisisanapikey'), **_httpxargs)
acg.assert_called_with('board/classes',
auth=BiteAuth('thisisanapikey'), **_httpxargs)


# XXX - add error cases for UI # XXX - add error cases for UI


@@ -500,7 +536,8 @@ class TestClient(unittest.TestCase):


ac.assert_called_with(base_url='http://someserver/') ac.assert_called_with(base_url='http://someserver/')


acg.assert_called_with('board/classes', auth=BiteAuth('thisisanapikey'), **_httpxargs)
acg.assert_called_with('board/classes',
auth=BiteAuth('thisisanapikey'), **_httpxargs)


# XXX - add error cases for UI # XXX - add error cases for UI


@@ -542,7 +579,8 @@ Attributes:
auth=BiteAuth('thisisanapikey'), **_httpxargs) auth=BiteAuth('thisisanapikey'), **_httpxargs)


@patch('bitelab.__main__.get_sshpubkey') @patch('bitelab.__main__.get_sshpubkey')
@patch.dict(sys.__dict__, dict(argv=[ '', 'reserve', '-i', 'bogusfilename', 'cora-z7s' ]))
@patch.dict(sys.__dict__, dict(argv=[ '', 'reserve', '-i',
'bogusfilename', 'cora-z7s' ]))
def test_reserve_ssh(self, gspk): def test_reserve_ssh(self, gspk):
ac = self.ac ac = self.ac
acp = self.acp acp = self.acp
@@ -606,8 +644,10 @@ Attributes:
acp.assert_called_with('board/cora-z7s/release', acp.assert_called_with('board/cora-z7s/release',
auth=BiteAuth('thisisanapikey'), **_httpxargs) auth=BiteAuth('thisisanapikey'), **_httpxargs)


@patch('bitelab.__main__._typeconv', dict(power=makebool, other=makebool))
@patch.dict(sys.__dict__, dict(argv=[ '', 'set', 'power=on', 'other=off', 'cora-z7s' ]))
@patch('bitelab.__main__._typeconv', dict(power=makebool,
other=makebool))
@patch.dict(sys.__dict__, dict(argv=[ '', 'set', 'power=on',
'other=off', 'cora-z7s' ]))
def test_set(self): def test_set(self):
ac = self.ac ac = self.ac
acp = self.acp acp = self.acp
@@ -634,8 +674,8 @@ Attributes:
ac.assert_called_with(base_url='http://someserver/') ac.assert_called_with(base_url='http://someserver/')


acp.assert_called_with('board/cora-z7s/attrs', acp.assert_called_with('board/cora-z7s/attrs',
auth=BiteAuth('thisisanapikey'), json=dict(power=True, other=False),
**_httpxargs)
auth=BiteAuth('thisisanapikey'), json=dict(power=True,
other=False), **_httpxargs)


def test_make_attrs(self): def test_make_attrs(self):
self.assertEqual(make_attrs('power=on'), dict(power=True)) self.assertEqual(make_attrs('power=on'), dict(power=True))
@@ -650,7 +690,8 @@ Attributes:
def test_check_res_code(self): def test_check_res_code(self):
res = Mock() res = Mock()
res.status_code = HTTP_404_NOT_FOUND res.status_code = HTTP_404_NOT_FOUND
res.json.side_effect = json.decoder.JSONDecodeError('foo', 'bar', 1)
res.json.side_effect = json.decoder.JSONDecodeError('foo',
'bar', 1)
res.text = 'body' res.text = 'body'


ret, output = self.runMain(lambda: check_res_code(res)) ret, output = self.runMain(lambda: check_res_code(res))


Loading…
Cancel
Save