|
- import asyncio
- import functools
- import os
- import unittest
-
- from Strobe.Strobe import Strobe, KeccakF
- from Strobe.Strobe import AuthenticationFailed
-
- import lora_comms
- from lora_comms import make_pktbuf
-
# Strobe protocol/domain string.  Both ends of a session must pass the same
# value to Strobe() for their states to agree.
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# Command codes.  Each command is encoded as one unsigned byte.  The reply to
# a command echoes that byte, followed by any arguments if needed.

# No args: terminate the session; the reply confirms it.
CMD_TERMINATE = 1

# The following commands are queued up, and are acknowledged when queued.
CMD_WAITFOR = 2     # arg: (length): waits for length seconds
CMD_RUNFOR = 3      # arg: (chan, length): turns on chan for length seconds
-
class LORANode(object):
    '''Implement a LORANode initiator.

    Drives a Strobe-protected session over a datagram transport (any
    object with an async ``sendtillrecv(data, freq)`` method, such as a
    SyncDatagram).  The session is set up with start().  Commands are
    then issued with waitfor(), runfor() and terminate().'''

    # Length in bytes of the MAC appended to every message, in both
    # directions.
    _MACLEN = 8

    def __init__(self, syncdatagram, shared=None):
        '''syncdatagram: transport used to exchange datagrams.
        shared: optional shared key; when given, the Strobe state is
        keyed with it.'''

        self.sd = syncdatagram
        self.st = Strobe(domain, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)

    def _recvmsg(self, pkt, enclen=None):
        '''Decrypt and authenticate pkt, a received message made of
        ciphertext followed by a _MACLEN byte MAC.

        If enclen is given, the first enclen bytes are the ciphertext
        and everything after them is the MAC.  Otherwise, all but the
        last _MACLEN bytes are the ciphertext.  Returns the decrypted
        plaintext.

        Raises RuntimeError, before touching the Strobe state, if pkt
        is too short to carry a full-length MAC.  Slicing a short packet
        would otherwise hand recv_mac a truncated (possibly empty) MAC.'''

        if enclen is None:
            enclen = len(pkt) - self._MACLEN

        if enclen < 0 or len(pkt) - enclen < self._MACLEN:
            raise RuntimeError('packet too short: %d bytes' % len(pkt))

        data = self.st.recv_enc(pkt[:enclen])
        self.st.recv_mac(pkt[enclen:])

        return data

    async def start(self):
        '''Run the session handshake.

        Sends 16 random bytes + b'reqreset' and authenticates the 16 byte
        reply.  Then ratchets, sends b'confirm', and expects b'confirmed'
        back.

        Raises RuntimeError if the handshake is not confirmed, or if a
        reply is too short to carry a full MAC.'''

        # Strobe operations are order sensitive: encrypt, then MAC.
        msg = self.st.send_enc(os.urandom(16) + b'reqreset')
        msg += self.st.send_mac(self._MACLEN)

        resp = await self.sd.sendtillrecv(msg, 1)

        # The 16 decrypted bytes are discarded; the reply only has to
        # authenticate.
        self._recvmsg(resp, 16)

        self.st.ratchet()

        msg = self.st.send_enc(b'confirm')
        msg += self.st.send_mac(self._MACLEN)

        resp = await self.sd.sendtillrecv(msg, 1)

        pkt = self._recvmsg(resp, len(b'confirmed'))

        if pkt != b'confirmed':
            raise RuntimeError('handshake not confirmed, got: %s' %
                repr(pkt))

    @staticmethod
    def _encodeargs(*args):
        '''Encode each integer argument as a 4 byte little-endian
        unsigned value and concatenate them.

        Raises OverflowError for values outside 0 .. 2**32 - 1.'''

        return b''.join(i.to_bytes(4, byteorder='little') for i in args)

    async def _sendcmd(self, cmd, *args):
        '''Send command byte cmd with integer args and wait for the reply.

        Raises RuntimeError if the reply is malformed or does not echo
        cmd.'''

        cmdbyte = cmd.to_bytes(1, byteorder='little')

        msg = self.st.send_enc(cmdbyte + self._encodeargs(*args))
        msg += self.st.send_mac(self._MACLEN)

        pkt = await self.sd.sendtillrecv(msg, 1)

        resp = self._recvmsg(pkt)

        if resp[0:1] != cmdbyte:
            raise RuntimeError('response does not match, got: %s, expected: %s' % (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Queue a wait of length seconds on the remote node.'''

        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Queue turning on chan for length seconds on the remote node.'''

        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def terminate(self):
        '''Terminate the session; the reply confirms it.'''

        return await self._sendcmd(CMD_TERMINATE)
-
class SyncDatagram(object):
    '''Base interface for a more simple synchronous interface.

    Subclasses implement recv() and send().  sendtillrecv() builds a
    send-and-retry loop on top of them.'''

    def __init__(self): #pragma: no cover
        pass

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram. If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise a
        TimeoutError exception.'''

        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send a datagram.'''

        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data every freq seconds until a datagram
        is received, and return the received datagram.

        There is no overall timeout: this retries until recv() returns.
        Both the builtin TimeoutError and asyncio.TimeoutError mean
        "nothing received yet".  asyncio.wait_for raises the latter, and
        before Python 3.11 it is a distinct class.  Any other exception
        from send() or recv() propagates.'''

        while True:
            await self.send(data)
            try:
                return await self.recv(freq)
            except (TimeoutError, asyncio.TimeoutError):
                pass
-
class MockSyncDatagram(SyncDatagram):
    '''In-memory SyncDatagram for tests.

    Subclasses supply an async runner() that plays the remote end of the
    conversation.  Inside runner():
    - ``await self.get()`` yields the next datagram the code under test
      sent.
    - ``await self.put(data)`` makes data the next thing recv() returns.

    runner() is scheduled with asyncio.create_task() on construction, so
    instances must be created while an event loop is running.'''

    def __init__(self):
        # Datagrams passed to send(); the runner reads them through get().
        self.sendq = asyncio.Queue()
        # Datagrams the runner put(); recv() hands them out.
        self.recvq = asyncio.Queue()

        # Bound before create_task() so __del__ can still check it if
        # scheduling the runner fails.
        self.task = None
        self.task = asyncio.create_task(self.runner())

        self.get, self.put = self.sendq.get, self.recvq.put

    async def drain(self):
        '''Wait for the runner task to finish up and return its result.'''

        runner_task = self.task
        return await runner_task

    async def runner(self): #pragma: no cover
        raise NotImplementedError

    async def recv(self, timeout=None):
        '''Return the next datagram queued by the runner.

        timeout is accepted for interface compatibility but ignored; this
        waits until something is queued.'''

        datagram = await self.recvq.get()
        return datagram

    async def send(self, data):
        '''Queue data for the runner to get().'''

        await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        # Cancel a runner that is still pending.
        task = self.task
        if task is None or task.done():
            return
        task.cancel()
-
class TestSyncData(unittest.IsolatedAsyncioTestCase):
    '''Tests for the generic SyncDatagram helpers.'''

    async def test_syncsendtillrecv(self):
        class MySync(SyncDatagram):
            '''Stub whose first recv() times out and second returns b'a'.'''

            def __init__(self):
                # Every datagram handed to send(), in order.
                self.sent = []
                # Scripted recv() results; exception instances are raised.
                self.script = [ TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                assert timeout == 1
                step = self.script.pop(0)
                if not isinstance(step, Exception):
                    return step

                raise step

            async def send(self, data):
                self.sent.append(data)

        stub = MySync()

        got = await stub.sendtillrecv(b'foo', 1)

        # The first attempt timed out, so the datagram went out twice.
        self.assertEqual(got, b'a')
        self.assertEqual(stub.sent, [ b'foo' ] * 2)
-
def timeout(timeout):
    '''Return a decorator that limits an async function to ``timeout``
    seconds per call.

    Each call of the decorated coroutine function is awaited through
    asyncio.wait_for().  Running past the limit cancels it and raises
    asyncio.TimeoutError.  The wrapper copies the wrapped function's
    metadata (name, docstring) via functools.wraps.'''

    def decorate(coro_fn):
        @functools.wraps(coro_fn)
        async def bounded(*args, **kwargs):
            pending = coro_fn(*args, **kwargs)
            return await asyncio.wait_for(pending, timeout)

        return bounded

    return decorate
-
class TestLORANode(unittest.IsolatedAsyncioTestCase):
    '''End-to-end tests of LORANode against two responders:
    - test_lora: a responder written directly with Strobe.
    - test_ccode: the responder implemented by lora_comms (ctypes-wrapped
      code; assumed to be the C implementation, per the test name).'''

    @timeout(2)
    async def test_lora(self):
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def runner(self):
                # Responder side: a second Strobe keyed with the same
                # shared_key.  It performs the matching recv/send
                # operation for each LORANode operation, in the same order.
                l = Strobe(domain, F=KeccakF(800))

                l.key(shared_key)

                # start handshake
                r = await self.get()

                # Every message is ciphertext followed by an 8 byte MAC.
                pkt = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert pkt.endswith(b'reqreset')

                # 16 random bytes + 8 byte MAC: the 24 byte reply that
                # LORANode.start() splits at offset 16.
                await self.put(l.send_enc(os.urandom(16)) +
                    l.send_mac(8))

                # Ratchet at the same point LORANode.start() does.
                l.ratchet()

                r = await self.get()
                c = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert c == b'confirm'

                await self.put(l.send_enc(b'confirmed') +
                    l.send_mac(8))

                # Commands: one command byte followed by 4 byte
                # little-endian arguments; the reply echoes the byte.
                r = await self.get()
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:], byteorder='little') == 30

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

                r = await self.get()
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5], byteorder='little') == 1
                assert int.from_bytes(cmd[5:], byteorder='little') == 50

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

                r = await self.get()
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_TERMINATE

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

        tsd = TestSD()
        l = LORANode(tsd, shared=shared_key)

        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        # The runner's asserts surface here if any of them failed.
        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

    @timeout(2)
    async def test_ccode(self):
        # Inside CCodeSD.runner() below, self is the datagram mock, so keep
        # a reference to the TestCase for its assertions.
        _self = self
        # NOTE(review): pointer and sizeof are imported but unused here.
        from ctypes import pointer, sizeof, c_uint8

        # seed the RNG (fixed seed, passed as a ctypes uint8 array)
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments, in the
        # order LORANode sends them below.
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_TERMINATE, [ ]),
        ]
        # Command handler registered with comms_init().  msg holds the
        # decrypted command: one byte followed by 4 byte little-endian
        # args.  It must match the next expected message; the reply is
        # the echoed command byte.  (presumably called by comms_process
        # once per new command -- confirm against lora_comms)
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            #print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf))
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function.  cb stays referenced for the rest of
        # the test; ctypes callbacks must outlive any use from C.
        # (assumes process_msgfunc_t is a ctypes CFUNCTYPE -- confirm)
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                # Expected reply lengths, in order:
                # - 24: reqreset reply (16 bytes + 8 byte MAC, as
                #   LORANode.start() splits it)
                # - 17: b'confirmed' (9 bytes + MAC)
                # - 9 each: the three echoed command bytes (1 byte + MAC)
                for expectlen in [ 24, 17, 9, 9, 9 ]:
                    # get message
                    gb = await self.get()
                    r = make_pktbuf(gb)

                    outbytes = bytearray(64)
                    outbuf = make_pktbuf(outbytes)

                    # process the test message
                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen,
                        outbuf.pktlen)

                    # save what was originally replied
                    origmsg = outbuf._from()

                    # pretend that the reply didn't make it: feed the
                    # same request again, as a sendtillrecv() retransmit
                    # would, and require a byte-identical reply.
                    r = make_pktbuf(gb)
                    outbuf = make_pktbuf(outbytes)

                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure that the reply matches previous
                    _self.assertEqual(origmsg, outbuf._from())

                    # pass the reply back
                    await self.put(outbytes[:outbuf.pktlen])

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
|