- import asyncio
- import functools
- import os
- import unittest
- from Strobe.Strobe import Strobe, KeccakF
- from Strobe.Strobe import AuthenticationFailed
- import lora_comms
- from lora_comms import make_pktbuf
# Strobe domain separation string shared by both ends of the link.
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# Commands are encoded as a single unsigned byte.  The response to a
# command is the command byte, plus any arguments if needed.

# No args: terminate the session; the reply confirms it.
CMD_TERMINATE = 1

# The following commands are queued up, and are acknowledged once queued.
CMD_WAITFOR = 2     # args: (length): wait for length seconds
CMD_RUNFOR = 3      # args: (chan, length): turn on chan for length seconds
class LORANode(object):
    '''Implement a LORANode initiator.

    Talks to a responder over a SyncDatagram.  start() performs the
    handshake; afterwards waitfor/runfor/terminate send commands.  Every
    packet in either direction is a Strobe-encrypted payload followed by
    an 8 byte MAC.'''

    # Length, in bytes, of the MAC that trails every packet.
    _MACLEN = 8

    def __init__(self, syncdatagram, shared=None):
        '''syncdatagram: the SyncDatagram used to exchange packets.
        shared: optional key bytes passed to the Strobe state's key().'''
        self.sd = syncdatagram
        self.st = Strobe(domain, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)

    def _seal(self, data):
        '''Encrypt data and append a MAC, returning the packet to send.'''
        # send_enc must run before send_mac (left-to-right evaluation).
        return self.st.send_enc(data) + self.st.send_mac(self._MACLEN)

    def _open(self, pkt):
        '''Decrypt a received packet (payload + trailing MAC), verify the
        MAC and return the payload.  MAC verification errors from Strobe
        (presumably AuthenticationFailed) propagate to the caller.'''
        payload = self.st.recv_enc(pkt[:-self._MACLEN])
        self.st.recv_mac(pkt[-self._MACLEN:])
        return payload

    async def start(self):
        '''Perform the session handshake.

        Sends 16 random bytes followed by b'reqreset', authenticates the
        reply, ratchets the Strobe state, then sends b'confirm' and
        expects b'confirmed' back.

        Raises RuntimeError if the confirmation payload is not exactly
        b'confirmed'.'''
        msg = self._seal(os.urandom(16) + b'reqreset')
        resp = await self.sd.sendtillrecv(msg, 1)
        # The reply's plaintext is not inspected; decrypting it and
        # checking its MAC still advances the shared transcript.
        self._open(resp)
        self.st.ratchet()

        resp = await self.sd.sendtillrecv(self._seal(b'confirm'), 1)
        pkt = self._open(resp)
        if pkt != b'confirmed':
            raise RuntimeError('handshake failed, got: %r' % (pkt,))

    @staticmethod
    def _encodeargs(*args):
        '''Encode each integer argument as 4 little-endian bytes and
        concatenate them.'''
        return b''.join(i.to_bytes(4, byteorder='little') for i in args)

    async def _sendcmd(self, cmd, *args):
        '''Send command byte cmd with integer args, and check that the
        reply echoes the command byte.

        Raises RuntimeError if the reply's first byte differs.'''
        cmdbyte = cmd.to_bytes(1, byteorder='little')
        pkt = await self.sd.sendtillrecv(
            self._seal(cmdbyte + self._encodeargs(*args)), 1)
        resp = self._open(pkt)
        if resp[0:1] != cmdbyte:
            raise RuntimeError('response does not match, got: %s, expected: %s' % (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Queue a wait of length seconds on the remote node.'''
        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Queue turning on chan for length seconds on the remote node.'''
        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def terminate(self):
        '''Terminate the session.'''
        return await self._sendcmd(CMD_TERMINATE)
class SyncDatagram(object):
    '''Base interface for a simpler, synchronous-style datagram API.

    Subclasses implement recv and send; sendtillrecv is built on them.'''

    def __init__(self): #pragma: no cover
        pass

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram. If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise a
        TimeoutError exception.'''
        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send a datagram.'''
        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data, resending it every freq seconds
        until a datagram is received, and return the received datagram.

        There is no overall timeout: this retries indefinitely.  recv may
        report a per-attempt timeout with either the builtin TimeoutError
        or asyncio.TimeoutError (a distinct class before Python 3.11, and
        what asyncio.wait_for raises there).  Any other exception from
        send or recv propagates.'''
        while True:
            await self.send(data)
            try:
                return await self.recv(freq)
            except (TimeoutError, asyncio.TimeoutError):
                # Nothing arrived within freq seconds: resend.
                pass
class MockSyncDatagram(SyncDatagram):
    '''A testing version of SyncDatagram.

    Subclasses define the coroutine method runner, which implements the
    other side of the sequence.  Inside runner, await self.get() to wait
    for the next datagram passed to this object's send(), or await
    self.put(data) to queue data for this object's recv() to return.

    runner is started as a task from __init__, so instances must be
    created while an event loop is running.'''

    def __init__(self):
        self.sendq = asyncio.Queue()
        self.recvq = asyncio.Queue()
        # Bind the endpoints runner uses BEFORE creating its task: with an
        # eager task factory (Python 3.12+) runner starts executing inside
        # create_task itself.
        self.get = self.sendq.get
        self.put = self.recvq.put
        # Pre-set so __del__ is safe if create_task raises.
        self.task = None
        self.task = asyncio.create_task(self.runner())

    async def drain(self):
        '''Wait for the runner task to finish and return its result,
        re-raising any exception it raised.'''
        return await self.task

    async def runner(self): #pragma: no cover
        raise NotImplementedError

    async def recv(self, timeout=None):
        '''Return the next datagram runner queued with self.put.

        timeout is ignored; this mock waits indefinitely.'''
        return await self.recvq.get()

    async def send(self, data):
        '''Queue data for runner to receive with self.get.'''
        return await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        # Cancel a still-pending runner so it is not left dangling.
        if self.task is not None and not self.task.done():
            self.task.cancel()
class TestSyncData(unittest.IsolatedAsyncioTestCase):
    async def test_syncsendtillrecv(self):
        '''sendtillrecv resends after a recv timeout and returns the
        first datagram that arrives.'''

        class ScriptedSync(SyncDatagram):
            # recv outcomes, in order: one timeout, then a datagram.
            def __init__(self):
                self.sendq = []
                self.resp = [ TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                # sendtillrecv must forward its freq as the timeout.
                assert timeout == 1
                nxt = self.resp.pop(0)
                if not isinstance(nxt, Exception):
                    return nxt
                raise nxt

            async def send(self, data):
                self.sendq.append(data)

        sync = ScriptedSync()
        got = await sync.sendtillrecv(b'foo', 1)

        self.assertEqual(got, b'a')
        # One send before the timeout, one resend before the reply.
        self.assertEqual(sync.sendq, [ b'foo', b'foo' ])
def timeout(timeout):
    '''Return a decorator that bounds an async function's run time.

    Each call of the decorated coroutine function is awaited through
    asyncio.wait_for with a limit of timeout seconds, so an overrunning
    call is cancelled and wait_for's timeout exception propagates.'''
    def decorate(coro_fn):
        async def bounded(*args, **kwargs):
            pending = coro_fn(*args, **kwargs)
            return await asyncio.wait_for(pending, timeout)
        # Preserve the wrapped function's name/docstring (used by test
        # runners when reporting).
        return functools.wraps(coro_fn)(bounded)
    return decorate
class TestLORANode(unittest.IsolatedAsyncioTestCase):
    '''End-to-end tests of LORANode against two responders: one written
    with the Python Strobe library and the C code in lora_comms.'''

    @timeout(2)
    async def test_lora(self):
        '''Run handshake, waitfor, runfor and terminate against a
        Python Strobe responder.'''
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def runner(self):
                # Responder Strobe state.  Its operations must mirror the
                # initiator's in exactly the same order, or MACs fail.
                rst = Strobe(domain, F=KeccakF(800))
                rst.key(shared_key)

                # start handshake: random bytes + b'reqreset'
                r = await self.get()
                pkt = rst.recv_enc(r[:-8])
                rst.recv_mac(r[-8:])
                assert pkt.endswith(b'reqreset')
                await self.put(rst.send_enc(os.urandom(16)) +
                    rst.send_mac(8))
                rst.ratchet()

                # confirmation
                r = await self.get()
                c = rst.recv_enc(r[:-8])
                rst.recv_mac(r[-8:])
                assert c == b'confirm'
                await self.put(rst.send_enc(b'confirmed') +
                    rst.send_mac(8))

                # waitfor(30): one 4-byte little-endian argument
                r = await self.get()
                cmd = rst.recv_enc(r[:-8])
                rst.recv_mac(r[-8:])
                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:], byteorder='little') == 30
                await self.put(rst.send_enc(cmd[0:1]) +
                    rst.send_mac(8))

                # runfor(1, 50): two 4-byte little-endian arguments
                r = await self.get()
                cmd = rst.recv_enc(r[:-8])
                rst.recv_mac(r[-8:])
                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5], byteorder='little') == 1
                assert int.from_bytes(cmd[5:], byteorder='little') == 50
                await self.put(rst.send_enc(cmd[0:1]) +
                    rst.send_mac(8))

                # terminate: no arguments
                r = await self.get()
                cmd = rst.recv_enc(r[:-8])
                rst.recv_mac(r[-8:])
                assert cmd[0] == CMD_TERMINATE
                await self.put(rst.send_enc(cmd[0:1]) +
                    rst.send_mac(8))

        tsd = TestSD()
        node = LORANode(tsd, shared=shared_key)

        await node.start()
        await node.waitfor(30)
        await node.runfor(1, 50)
        await node.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

    @timeout(2)
    async def test_ccode(self):
        '''Run the same session against the C implementation exposed
        through lora_comms.'''
        # CCodeSD.runner's own self is the datagram, so keep the TestCase.
        _self = self

        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments, in order
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
        ]

        def procmsg(msg, outbuf):
            '''Message callback handed to the C code: check the command
            against the next expected one and reply with its command
            byte.'''
            msgbuf = msg._from()

            cmd = msgbuf[0]
            # Arguments follow the command byte, 4 little-endian bytes each.
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            # Guard the empty list so an unexpected extra message reports
            # 'cmd not found' rather than an IndexError.
            if exptmsgs and exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                # Expected reply lengths, one per initiator packet; they
                # are consistent with payload + 8-byte MAC: 16-byte
                # handshake reply, b'confirmed', then 1-byte command echoes.
                for expectlen in [ 24, 17, 9, 9, 9 ]:
                    # get message
                    gb = await self.get()
                    r = make_pktbuf(gb)

                    outbytes = bytearray(64)
                    outbuf = make_pktbuf(outbytes)

                    # process the test message
                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen,
                        outbuf.pktlen)

                    # pass the reply back
                    await self.put(outbytes[:outbuf.pktlen])

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        node = LORANode(tsd, shared=shared_key)

        # Send various messages
        await node.start()
        await node.waitfor(30)
        await node.runfor(1, 50)
        await node.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)