|
- # Copyright 2021 John-Mark Gurney.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
- # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
- # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- # SUCH DAMAGE.
- #
-
- import asyncio
- import functools
- import os
- import unittest
-
- from Strobe.Strobe import Strobe, KeccakF
- from Strobe.Strobe import AuthenticationFailed
-
- import lora_comms
- from lora_comms import make_pktbuf
-
# Strobe domain string used to initialize the session state (see LORANode).
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# The response to a command will be the command byte (CMD) and any
# arguments, if needed.  The command is encoded as an unsigned byte;
# each argument is sent as a 4 byte little-endian unsigned integer.
CMD_TERMINATE = 1 # no args: terminate the session, reply confirms

# The following commands are queued, and are acknowledged when queued.
CMD_WAITFOR = 2 # arg: (length): waits for length seconds
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds
-
class LORANode(object):
    '''Implement a LORANode initiator.

    Every message sent is encrypted and MAC'd with a Strobe session
    (KeccakF-800 based), and every reply must carry a valid MAC_LEN
    byte MAC before it is accepted.  Messages are exchanged over a
    SyncDatagram.'''

    # Length, in bytes, of the MAC appended to every packet.
    MAC_LEN = 8

    def __init__(self, syncdatagram, shared=None):
        '''syncdatagram: the SyncDatagram used to exchange packets.
        shared: optional pre-shared key; when not None it is keyed
        into the Strobe state before anything is sent.'''

        self.sd = syncdatagram
        self.st = Strobe(domain, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)

    async def start(self):
        '''Perform the session handshake.

        Sends 16 random bytes followed by b'reqreset', ratchets the
        Strobe state once a valid reply arrives, then sends b'confirm'
        and expects b'confirmed' back.

        Raises RuntimeError if the confirmation reply is wrong.'''

        # Only the reply's authenticity matters here, not its content.
        await self.sendrecvvalid(os.urandom(16) + b'reqreset')

        # No more back tracking of the Strobe state past this point.
        self.st.ratchet()

        pkt = await self.sendrecvvalid(b'confirm')

        if pkt != b'confirmed':
            raise RuntimeError(
                'handshake not confirmed, got: %s' % repr(pkt))

    async def sendrecvvalid(self, msg):
        '''Encrypt and MAC msg, then send it (resending every second)
        until a reply that authenticates is received.  Returns the
        decrypted reply.

        Replies that are shorter than MAC_LEN, or that fail MAC
        verification, are discarded and msg is sent again.'''

        msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)

        # State after sending; a bad reply is rolled back to here so
        # the next reply is processed from the same point.
        origstate = self.st.copy()

        while True:
            resp = await self.sd.sendtillrecv(msg, 1)

            # A reply shorter than the MAC would make resp[-MAC_LEN:]
            # the whole (short) packet, so recv_mac would check fewer
            # than MAC_LEN bytes -- or none at all for an empty packet.
            # Reject it before it touches the Strobe state.
            if len(resp) < self.MAC_LEN:
                continue

            try:
                decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
                self.st.recv_mac(resp[-self.MAC_LEN:])
            except AuthenticationFailed:
                # didn't get a valid packet, restore state and retry
                self.st.set_state_from(origstate)
                continue

            return decmsg

    @staticmethod
    def _encodeargs(*args):
        '''Encode each integer in args as a 4 byte little-endian
        unsigned value and return their concatenation.'''

        return b''.join(i.to_bytes(4, byteorder='little') for i in args)

    async def _sendcmd(self, cmd, *args):
        '''Send command byte cmd with its encoded args, and wait for
        the reply.

        Raises RuntimeError if the reply's first byte is not cmd.'''

        cmdbyte = cmd.to_bytes(1, byteorder='little')
        resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args))

        if resp[0:1] != cmdbyte:
            raise RuntimeError(
                'response does not match, got: %s, expected: %s' %
                (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Queue a wait of length seconds on the remote node.'''

        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Queue turning on chan for length seconds on the remote node.'''

        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def terminate(self):
        '''Terminate the session.'''

        return await self._sendcmd(CMD_TERMINATE)
-
class SyncDatagram(object):
    '''Base interface for a more simple synchronous interface.

    Subclasses implement recv and send; sendtillrecv is built on
    top of them.'''

    def __init__(self): #pragma: no cover
        pass

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram.  If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise a
        TimeoutError exception.'''

        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send a datagram.'''

        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data, and resend it every freq seconds
        until a datagram is received, which is returned.

        There is no overall deadline: this keeps resending until
        something arrives.  A recv timeout may be reported either as
        the builtin TimeoutError or as asyncio.TimeoutError (a distinct
        class before Python 3.11, raised e.g. by asyncio.wait_for);
        both mean "nothing received yet".'''

        while True:
            await self.send(data)
            try:
                return await self.recv(freq)
            except (TimeoutError, asyncio.TimeoutError):
                pass
-
class MockSyncDatagram(SyncDatagram):
    '''A testing version of SyncDatagram.

    Subclasses define a coroutine method runner which plays the other
    side of the conversation.  Inside runner, await self.get() to wait
    for the code under test to send something, or await self.put(data)
    to deliver data to it.  runner is scheduled as an asyncio task when
    the object is constructed.'''

    def __init__(self):
        # send() feeds sendq (read by runner through get); runner fills
        # recvq through put, which recv() drains.
        self.sendq = asyncio.Queue()
        self.recvq = asyncio.Queue()

        # Defined before create_task so __del__ can always inspect it.
        self.task = None

        self.get, self.put = self.sendq.get, self.recvq.put
        self.task = asyncio.create_task(self.runner())

    async def drain(self):
        '''Wait for the runner task to finish and return its result.'''

        result = await self.task
        return result

    async def runner(self): #pragma: no cover
        raise NotImplementedError

    async def recv(self, timeout=None):
        # timeout is ignored: this blocks until runner puts a datagram.
        datagram = await self.recvq.get()
        return datagram

    async def send(self, data):
        await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        pending = self.task
        if pending is None or pending.done():
            return
        pending.cancel()
-
class TestSyncData(unittest.IsolatedAsyncioTestCase):
    '''Tests for the generic SyncDatagram helpers.'''

    async def test_syncsendtillrecv(self):
        '''sendtillrecv resends after a timeout and returns the first
        datagram received.'''

        class MySync(SyncDatagram):
            def __init__(self):
                self.sendq = []
                # scripted recv outcomes: a timeout, then a datagram
                self.outcomes = [ TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                assert timeout == 1
                item = self.outcomes.pop(0)
                if not isinstance(item, Exception):
                    return item

                raise item

            async def send(self, data):
                self.sendq.append(data)

        sync = MySync()

        got = await sync.sendtillrecv(b'foo', 1)

        self.assertEqual(got, b'a')
        # one send for the timed out attempt, one for the successful one
        self.assertEqual(sync.sendq, [ b'foo', b'foo' ])
-
def timeout(timeout):
    '''Decorator factory: wrap a coroutine function so that each call
    is run under asyncio.wait_for with a limit of timeout seconds, and
    raises asyncio.TimeoutError if it takes longer.'''

    def decorate(coro_fn):
        @functools.wraps(coro_fn)
        async def bounded(*args, **kwargs):
            pending = coro_fn(*args, **kwargs)
            return await asyncio.wait_for(pending, timeout)

        return bounded

    return decorate
-
- def _debprint(*args): # pragma: no cover
- import traceback, sys, os.path
- st = traceback.extract_stack(limit=2)[0]
-
- sep = ''
- if args:
- sep = ':'
-
- print('%s:%d%s' % (os.path.basename(st.filename), st.lineno, sep),
- *args)
- sys.stdout.flush()
-
class TestLORANode(unittest.IsolatedAsyncioTestCase):
    '''End-to-end tests of LORANode: first against a scripted pure
    Python Strobe peer, then against the C code in lora_comms.'''

    @timeout(2)
    async def test_lora(self):
        '''Run a full session (handshake, waitfor, runfor, terminate)
        against a scripted Python peer, checking that junk replies are
        ignored and cause the same packet to be resent.'''

        _self = self
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def sendgettest(self, msg):
                '''Send the message, but make sure that if a
                bad message is sent afterward, that it replies
                w/ the same previous message.
                '''

                await self.put(msg)
                resp = await self.get()

                await self.put(b'bogusmsg' * 5)

                resp2 = await self.get()

                _self.assertEqual(resp, resp2)

                return resp

            async def runner(self):
                l = Strobe(domain, F=KeccakF(800))

                l.key(shared_key)

                # start handshake
                r = await self.get()

                pkt = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert pkt.endswith(b'reqreset')

                # make sure junk gets ignored
                await self.put(b'sdlfkj')

                # and that the packet remains the same
                _self.assertEqual(r, await self.get())

                # and a couple more times
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())
                await self.put(b'0' * 32)
                _self.assertEqual(r, await self.get())

                # send the response
                await self.put(l.send_enc(os.urandom(16)) +
                    l.send_mac(8))

                # require no more back tracking at this point
                l.ratchet()

                # get the confirmation message
                r = await self.get()

                # test the resend capabilities
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())

                # decode confirmation message
                c = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                # assert that we got it
                _self.assertEqual(c, b'confirm')

                # send confirmed reply
                r = await self.sendgettest(l.send_enc(
                    b'confirmed') + l.send_mac(8))

                # test and decode remaining command messages
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:],
                    byteorder='little') == 30

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5],
                    byteorder='little') == 1
                assert int.from_bytes(cmd[5:],
                    byteorder='little') == 50

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_TERMINATE

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

        tsd = TestSD()
        l = LORANode(tsd, shared=shared_key)

        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

    @timeout(2)
    async def test_ccode(self):
        '''Run a full session between LORANode and the C code in
        lora_comms, checking each reply's length and that processing
        a duplicate request yields an identical reply.'''

        _self = self
        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            '''Callback from the C code for each decrypted command:
            check it against the next expected message, and reply
            with a single byte echoing the command.'''

            # NOTE(review): if process_msgfunc_t is a ctypes CFUNCTYPE,
            # an exception raised here is printed and ignored rather
            # than propagated; the failure then surfaces through the
            # final exptmsgs check instead.
            msgbuf = msg._from()
            cmd = msgbuf[0]
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function; cb must stay referenced for as
        # long as the C code may call it
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                # Expected reply lengths, each including the 8 byte MAC:
                # presumably the 16 byte handshake reply, b'confirmed',
                # then three 1 byte command echoes.
                for expectlen in [ 24, 17, 9, 9, 9 ]:
                    # get message
                    gb = await self.get()
                    r = make_pktbuf(gb)

                    outbytes = bytearray(64)
                    outbuf = make_pktbuf(outbytes)

                    # process the test message
                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen,
                        outbuf.pktlen)

                    # save what was originally replied
                    origmsg = outbuf._from()

                    # pretend that the reply didn't make it
                    r = make_pktbuf(gb)
                    outbuf = make_pktbuf(outbytes)

                    lora_comms.comms_process(commstate, r,
                        outbuf)

                    # make sure that the reply matches
                    # the previous
                    _self.assertEqual(origmsg,
                        outbuf._from())

                    # pass the reply back
                    await self.put(outbytes[:outbuf.pktlen])

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
|