# Copyright 2021 John-Mark Gurney.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

import asyncio
import functools
import os
import unittest

from Strobe.Strobe import Strobe, KeccakF
from Strobe.Strobe import AuthenticationFailed

import lora_comms
from lora_comms import make_pktbuf

domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# Response to command will be the CMD and any arguments if needed.
# The command is encoded as an unsigned byte
CMD_TERMINATE = 1   # no args: terminate the session, reply confirms

# The following commands are queued up, but will be acknowledged when queued
CMD_WAITFOR = 2     # arg: (length): waits for length seconds
CMD_RUNFOR = 3      # arg: (chan, length): turns on chan for length seconds


class LORANode(object):
    '''Implement a LORANode initiator.

    Every message is encrypted and authenticated with a Strobe
    instance (send_enc followed by a MAC_LEN byte send_mac), and is
    handed to a SyncDatagram like object for delivery.'''

    # Number of bytes of MAC appended to every packet.
    MAC_LEN = 8

    def __init__(self, syncdatagram, shared=None):
        '''syncdatagram is the transport; only its sendtillrecv
        coroutine is used.  If shared is not None, it is used to key
        the Strobe state.'''

        self.sd = syncdatagram
        self.st = Strobe(domain, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)

    async def start(self):
        '''Run the session handshake.

        Sends 16 random bytes followed by b'reqreset', ratchets the
        Strobe state once a valid reply arrives, then sends
        b'confirm' and requires the reply to be b'confirmed'.

        Raises RuntimeError if the confirmation reply is anything
        else.'''

        # The payload of the first reply is not used, it only needs
        # to authenticate.
        await self.sendrecvvalid(os.urandom(16) + b'reqreset')

        self.st.ratchet()

        pkt = await self.sendrecvvalid(b'confirm')

        if pkt != b'confirmed':
            raise RuntimeError(
                'handshake failed, got: %s, expected: %s' %
                (repr(pkt), repr(b'confirmed')))

    async def sendrecvvalid(self, msg):
        '''Encrypt and send msg, resending it until a reply that
        authenticates is received.  Returns the decrypted reply
        (without the MAC).'''

        msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)

        # Snapshot taken after sending, so that a failed receive can
        # be rolled back w/o having to re-encrypt the message.
        origstate = self.st.copy()

        while True:
            resp = await self.sd.sendtillrecv(msg, 1)
            #_debprint('got:', resp)

            # A reply shorter than the MAC cannot be valid.  Reject it
            # before touching the Strobe state instead of verifying a
            # truncated (possibly empty) MAC.
            if len(resp) < self.MAC_LEN:
                continue

            try:
                decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
                self.st.recv_mac(resp[-self.MAC_LEN:])
            except AuthenticationFailed:
                # didn't get a valid packet, restore
                # state and retry

                #_debprint('failed')
                self.st.set_state_from(origstate)
            else:
                #_debprint('got rep:', repr(resp), repr(decmsg))
                return decmsg

    @staticmethod
    def _encodeargs(*args):
        '''Return the concatenation of each of args encoded as a
        4 byte little endian unsigned integer.'''

        return b''.join(i.to_bytes(4, byteorder='little') for i in args)

    async def _sendcmd(self, cmd, *args):
        '''Send the command cmd (encoded as a single byte) followed
        by its encoded args, and wait for the reply.

        Raises RuntimeError if the first byte of the reply is not the
        command byte that was sent.'''

        cmdbyte = cmd.to_bytes(1, byteorder='little')
        resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args))

        if resp[0:1] != cmdbyte:
            raise RuntimeError(
                'response does not match, got: %s, expected: %s' %
                (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Send CMD_WAITFOR w/ length as its argument.'''

        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Send CMD_RUNFOR w/ chan and length as its arguments.'''

        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def terminate(self):
        '''Send CMD_TERMINATE.'''

        return await self._sendcmd(CMD_TERMINATE)


class SyncDatagram(object):
    '''Base interface for a more simple synchronous interface.'''

    def __init__(self): #pragma: no cover
        pass

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram.  If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise an
        TimeoutError exception.'''

        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send a datagram.'''

        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data, every freq seconds until a
        datagram is received, and return the received datagram.

        freq is passed to recv as its timeout.  Each time recv times
        out, data is sent again.  There is no overall timeout, this
        only returns once a datagram has been received.'''

        while True:
            #_debprint('sending:', repr(data))
            await self.send(data)
            try:
                return await self.recv(freq)
            except (TimeoutError, asyncio.TimeoutError):
                # asyncio.TimeoutError only became an alias of the
                # builtin in Python 3.11.  Catch both so that a recv
                # built on asyncio.wait_for also triggers a resend.
                pass


class MockSyncDatagram(SyncDatagram):
    '''A testing version of SyncDatagram.  Define a method runner which
    implements part of the sequence.  In the function, await on either
    self.get, to wait for the other side to send something, or await
    self.put w/ data to send.

    The runner is started via asyncio.create_task when the object is
    created.'''

    def __init__(self):
        # Set first so that __del__ works even if the rest fails.
        self.task = None

        self.sendq = asyncio.Queue()
        self.recvq = asyncio.Queue()

        # Bind the helpers before the runner gets scheduled.
        self.get = self.sendq.get
        self.put = self.recvq.put

        self.task = asyncio.create_task(self.runner())

    async def drain(self):
        '''Wait for the runner thread to finish up.'''

        return await self.task

    async def runner(self): #pragma: no cover
        raise NotImplementedError

    async def recv(self, timeout=None):
        '''Return the next datagram queued via put.  timeout is
        ignored.'''

        return await self.recvq.get()

    async def send(self, data):
        '''Queue data so that the runner can fetch it via get.'''

        return await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        if self.task is not None and not self.task.done():
            self.task.cancel()


class TestSyncData(unittest.IsolatedAsyncioTestCase):
    async def test_syncsendtillrecv(self):
        '''sendtillrecv resends when recv times out.'''

        class MySync(SyncDatagram):
            def __init__(self):
                self.sendq = []
                # first recv times out, second returns a datagram
                self.resp = [ TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                assert timeout == 1
                r = self.resp.pop(0)
                if isinstance(r, Exception):
                    raise r

                return r

            async def send(self, data):
                self.sendq.append(data)

        ms = MySync()

        r = await ms.sendtillrecv(b'foo', 1)

        self.assertEqual(r, b'a')
        self.assertEqual(ms.sendq, [ b'foo', b'foo' ])


def timeout(timeout):
    '''Decorator that makes the wrapped coroutine function fail, via
    asyncio.wait_for, if it takes longer than timeout seconds.'''

    def timeout_wrapper(fun):
        @functools.wraps(fun)
        async def wrapper(*args, **kwargs):
            return await asyncio.wait_for(fun(*args, **kwargs),
                timeout)

        return wrapper

    return timeout_wrapper


def _debprint(*args): # pragma: no cover
    '''Debugging aid: print args prefixed w/ the file name and line
    number of the caller.'''

    import traceback, sys, os.path
    st = traceback.extract_stack(limit=2)[0]

    sep = ''
    if args:
        sep = ':'

    print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
        *args)
    sys.stdout.flush()


class TestLORANode(unittest.IsolatedAsyncioTestCase):
    @timeout(2)
    async def test_lora(self):
        '''Run LORANode against a responder implemented in Python.'''

        _self = self
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def sendgettest(self, msg):
                '''Send the message, but make sure that if a
                bad message is sent afterward, that it replies
                w/ the same previous message.
                '''

                await self.put(msg)
                resp = await self.get()

                await self.put(b'bogusmsg' * 5)

                resp2 = await self.get()

                _self.assertEqual(resp, resp2)

                return resp

            async def runner(self):
                l = Strobe(domain, F=KeccakF(800))

                l.key(shared_key)

                # start handshake
                r = await self.get()

                pkt = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                _self.assertTrue(pkt.endswith(b'reqreset'))

                # make sure junk gets ignored
                await self.put(b'sdlfkj')

                # and that the packet remains the same
                _self.assertEqual(r, await self.get())

                # an empty datagram must be ignored as well
                await self.put(b'')
                _self.assertEqual(r, await self.get())

                # and a couple more times
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())
                await self.put(b'0' * 32)
                _self.assertEqual(r, await self.get())

                # send the response
                await self.put(l.send_enc(os.urandom(16)) +
                    l.send_mac(8))

                # require no more back tracking at this point
                l.ratchet()

                # get the confirmation message
                r = await self.get()

                # test the resend capabilities
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())

                # decode confirmation message
                c = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                # assert that we got it
                _self.assertEqual(c, b'confirm')

                # send confirmed reply
                r = await self.sendgettest(l.send_enc(
                    b'confirmed') + l.send_mac(8))

                # test and decode remaining command messages
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                _self.assertEqual(cmd[0], CMD_WAITFOR)
                _self.assertEqual(
                    int.from_bytes(cmd[1:], byteorder='little'), 30)

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                _self.assertEqual(cmd[0], CMD_RUNFOR)
                _self.assertEqual(
                    int.from_bytes(cmd[1:5], byteorder='little'), 1)
                _self.assertEqual(
                    int.from_bytes(cmd[5:], byteorder='little'), 50)

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                _self.assertEqual(cmd[0], CMD_TERMINATE)

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

        tsd = TestSD()
        l = LORANode(tsd, shared=shared_key)

        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())
        #_debprint('done')

    @timeout(2)
    async def test_ccode(self):
        '''Run LORANode against the responder provided by the
        lora_comms module.'''

        _self = self
        from ctypes import pointer, sizeof, c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_TERMINATE, [ ]),
        ]

        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            #print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf))
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            # An unexpected extra message is reported the same way as
            # a mismatched one, instead of raising an IndexError.
            if exptmsgs and exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                for expectlen in [ 24, 17, 9, 9, 9 ]:
                    # get message
                    gb = await self.get()
                    r = make_pktbuf(gb)

                    outbytes = bytearray(64)
                    outbuf = make_pktbuf(outbytes)

                    # process the test message
                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen,
                        outbuf.pktlen)

                    # save what was originally replied
                    origmsg = outbuf._from()

                    # pretend that the reply didn't make it
                    r = make_pktbuf(gb)
                    outbuf = make_pktbuf(outbytes)

                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure that the reply matches
                    # the previous
                    _self.assertEqual(origmsg,
                        outbuf._from())

                    # pass the reply back
                    await self.put(outbytes[:outbuf.pktlen])

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
        #_debprint('done')