diff --git a/lora.py b/lora.py index e369093..01da465 100644 --- a/lora.py +++ b/lora.py @@ -46,6 +46,8 @@ CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds class LORANode(object): '''Implement a LORANode initiator.''' + MAC_LEN = 8 + def __init__(self, syncdatagram, shared=None): self.sd = syncdatagram self.st = Strobe(domain, F=KeccakF(800)) @@ -53,25 +55,38 @@ class LORANode(object): self.st.key(shared) async def start(self): - msg = self.st.send_enc(os.urandom(16) + b'reqreset') + \ - self.st.send_mac(8) - - resp = await self.sd.sendtillrecv(msg, 1) - - self.st.recv_enc(resp[:16]) - self.st.recv_mac(resp[16:]) + resp = await self.sendrecvvalid(os.urandom(16) + b'reqreset') self.st.ratchet() - resp = await self.sd.sendtillrecv( - self.st.send_enc(b'confirm') + self.st.send_mac(8), 1) - - pkt = self.st.recv_enc(resp[:9]) - self.st.recv_mac(resp[9:]) + pkt = await self.sendrecvvalid(b'confirm') if pkt != b'confirmed': raise RuntimeError + async def sendrecvvalid(self, msg): + msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN) + + origstate = self.st.copy() + + while True: + resp = await self.sd.sendtillrecv(msg, 1) + #_debprint('got:', resp) + + try: + decmsg = self.st.recv_enc(resp[:-self.MAC_LEN]) + self.st.recv_mac(resp[-self.MAC_LEN:]) + break + except AuthenticationFailed: + # didn't get a valid packet, restore + # state and retry + + #_debprint('failed') + self.st.set_state_from(origstate) + + #_debprint('got rep:', repr(resp), repr(decmsg)) + return decmsg + @staticmethod def _encodeargs(*args): r = [] @@ -82,15 +97,12 @@ class LORANode(object): async def _sendcmd(self, cmd, *args): cmdbyte = cmd.to_bytes(1, byteorder='little') - pkt = await self.sd.sendtillrecv( - self.st.send_enc(cmdbyte + - self._encodeargs(*args)) + self.st.send_mac(8), 1) - - resp = self.st.recv_enc(pkt[:-8]) - self.st.recv_mac(pkt[-8:]) + resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args)) if resp[0:1] != cmdbyte: - raise RuntimeError('response does not match, got: %s, expected: %s' % (repr(resp[0:1]), repr(cmdbyte))) + raise RuntimeError( + 'response does not match, got: %s, expected: %s' % + (repr(resp[0:1]), repr(cmdbyte))) async def waitfor(self, length): return await self._sendcmd(CMD_WAITFOR, length) @@ -109,8 +121,8 @@ class SyncDatagram(object): async def recv(self, timeout=None): #pragma: no cover '''Receive a datagram. If timeout is not None, wait that many - seconds, and if nothing is received in that time, raise an TimeoutError - exception.''' + seconds, and if nothing is received in that time, raise an + TimeoutError exception.''' raise NotImplementedError @@ -125,6 +137,7 @@ class SyncDatagram(object): then raise an TimeoutError exception.''' while True: + #_debprint('sending:', repr(data)) await self.send(data) try: return await self.recv(freq) @@ -200,12 +213,42 @@ def timeout(timeout): return timeout_wrapper +def _debprint(*args): # pragma: no cover + import traceback, sys, os.path + st = traceback.extract_stack(limit=2)[0] + + sep = '' + if args: + sep = ':' + + print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep), + *args) + sys.stdout.flush() + class TestLORANode(unittest.IsolatedAsyncioTestCase): @timeout(2) async def test_lora(self): + _self = self shared_key = os.urandom(32) class TestSD(MockSyncDatagram): + async def sendgettest(self, msg): + '''Send the message, but make sure that if a + bad message is sent afterward, that it replies + w/ the same previous message. + ''' + + await self.put(msg) + resp = await self.get() + + await self.put(b'bogusmsg' * 5) + + resp2 = await self.get() + + _self.assertEqual(resp, resp2) + + return resp + async def runner(self): l = Strobe(domain, F=KeccakF(800)) @@ -219,42 +262,66 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): assert pkt.endswith(b'reqreset') + # make sure junk gets ignored + await self.put(b'sdlfkj') + + # and that the packet remains the same + _self.assertEqual(r, await self.get()) + + # and a couple more times + await self.put(b'0' * 24) + _self.assertEqual(r, await self.get()) + await self.put(b'0' * 32) + _self.assertEqual(r, await self.get()) + + # send the response await self.put(l.send_enc(os.urandom(16)) + l.send_mac(8)) + # require no more back tracking at this point l.ratchet() + # get the confirmation message r = await self.get() + + # test the resend capabilities + await self.put(b'0' * 24) + _self.assertEqual(r, await self.get()) + + # decode confirmation message c = l.recv_enc(r[:-8]) l.recv_mac(r[-8:]) - assert c == b'confirm' + # assert that we got it + _self.assertEqual(c, b'confirm') - await self.put(l.send_enc(b'confirmed') + - l.send_mac(8)) + # send confirmed reply + r = await self.sendgettest(l.send_enc( + b'confirmed') + l.send_mac(8)) - r = await self.get() + # test and decode remaining command messages cmd = l.recv_enc(r[:-8]) l.recv_mac(r[-8:]) assert cmd[0] == CMD_WAITFOR - assert int.from_bytes(cmd[1:], byteorder='little') == 30 + assert int.from_bytes(cmd[1:], + byteorder='little') == 30 - await self.put(l.send_enc(cmd[0:1]) + - l.send_mac(8)) + r = await self.sendgettest(l.send_enc( + cmd[0:1]) + l.send_mac(8)) - r = await self.get() cmd = l.recv_enc(r[:-8]) l.recv_mac(r[-8:]) assert cmd[0] == CMD_RUNFOR - assert int.from_bytes(cmd[1:5], byteorder='little') == 1 - assert int.from_bytes(cmd[5:], byteorder='little') == 50 + assert int.from_bytes(cmd[1:5], + byteorder='little') == 1 + assert int.from_bytes(cmd[5:], + byteorder='little') == 50 - await self.put(l.send_enc(cmd[0:1]) + - l.send_mac(8)) + r = await self.sendgettest(l.send_enc( + cmd[0:1]) + l.send_mac(8)) - r = await self.get() cmd = l.recv_enc(r[:-8]) l.recv_mac(r[-8:]) @@ -279,6 +346,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): # Make sure all messages have been processed self.assertTrue(tsd.sendq.empty()) self.assertTrue(tsd.recvq.empty()) + #_debprint('done') @timeout(2) async def test_ccode(self): @@ -345,8 +413,10 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): lora_comms.comms_process(commstate, r, outbuf) - # make sure that the reply matches previous - _self.assertEqual(origmsg, outbuf._from()) + # make sure that the reply matches + # the previous + _self.assertEqual(origmsg, + outbuf._from()) # pass the reply back await self.put(outbytes[:outbuf.pktlen]) @@ -379,3 +449,4 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): # Make sure all the expected messages have been # processed. self.assertFalse(exptmsgs) + #_debprint('done')