|
- # Copyright 2021 John-Mark Gurney.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
- # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
- # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- # SUCH DAMAGE.
- #
-
- import asyncio
- import contextlib
- import functools
- import itertools
- import os
- import unittest
-
- from Strobe.Strobe import Strobe, KeccakF
- from Strobe.Strobe import AuthenticationFailed
-
- import lora_comms
- from lora_comms import make_pktbuf
- import multicast
- from util import *
-
# Byte string handed to Strobe() as its first argument wherever a Strobe
# state is created in this file.
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# Response to command will be the CMD and any arguments if needed.
# The command is encoded as an unsigned byte
CMD_TERMINATE = 1 # no args: terminate the session, reply confirms

# The following commands are queued up, but will be acknowledged when queued
CMD_WAITFOR = 2 # arg: (length): waits for length seconds
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds
CMD_PING = 4 # arg: (): a no op command
-
class LORANode(object):
    '''Implement a LORANode initiator.

    The node talks to its peer through the object passed in as
    syncdatagram, which must provide an awaitable
    sendtillrecv(data, freq).  Every message sent is encrypted with
    the Strobe state in self.st and followed by a MAC_LEN byte MAC;
    every reply is expected in the same format.
    '''

    # number of MAC bytes that trail every message
    MAC_LEN = 8

    def __init__(self, syncdatagram, shared=None):
        '''Create a node that communicates via syncdatagram.

        If shared is not None, it is used to key the Strobe state.'''

        self.sd = syncdatagram
        self.st = Strobe(domain, F=KeccakF(800))
        if shared is not None:
            self.st.key(shared)

    async def start(self):
        '''Run the start-of-session handshake.

        Sends 16 random bytes followed by b'reqreset', ratchets the
        Strobe state once an authenticated reply arrives, then sends
        b'confirm'.

        Raises RuntimeError if the reply to b'confirm' is not
        b'confirmed'.'''

        # The contents of this reply are not inspected here, it only
        # has to authenticate.
        await self.sendrecvvalid(os.urandom(16) + b'reqreset')

        self.st.ratchet()

        pkt = await self.sendrecvvalid(b'confirm')

        if pkt != b'confirmed':
            raise RuntimeError('got invalid response: %s' %
                repr(pkt))

    async def sendrecvvalid(self, msg):
        '''Encrypt and send msg, and return the decrypted reply.

        The message is resent (via sendtillrecv) until a reply
        arrives whose MAC verifies.  Empty replies and replies that
        fail authentication are discarded.'''

        msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)

        # Snapshot taken after sending, so that a failed receive
        # attempt can be undone before trying again.
        origstate = self.st.copy()

        while True:
            resp = await self.sd.sendtillrecv(msg, 1)

            # skip empty messages
            if not resp:
                continue

            try:
                decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
                self.st.recv_mac(resp[-self.MAC_LEN:])
            except AuthenticationFailed:
                # didn't get a valid packet, restore
                # state and retry
                self.st.set_state_from(origstate)
            else:
                return decmsg

    @staticmethod
    def _encodeargs(*args):
        '''Return the args, each encoded as a 4 byte little endian
        unsigned integer, concatenated together.'''

        return b''.join(i.to_bytes(4, byteorder='little')
            for i in args)

    async def _sendcmd(self, cmd, *args):
        '''Send the command cmd (a single unsigned byte) with its
        encoded args and wait for the reply.

        Raises RuntimeError if the first byte of the reply is not
        the command byte that was sent.'''

        cmdbyte = cmd.to_bytes(1, byteorder='little')
        resp = await self.sendrecvvalid(cmdbyte +
            self._encodeargs(*args))

        if resp[0:1] != cmdbyte:
            raise RuntimeError(
                'response does not match, got: %s, expected: %s' %
                (repr(resp[0:1]), repr(cmdbyte)))

    async def waitfor(self, length):
        '''Send CMD_WAITFOR with the argument length.'''

        return await self._sendcmd(CMD_WAITFOR, length)

    async def runfor(self, chan, length):
        '''Send CMD_RUNFOR with the arguments chan and length.'''

        return await self._sendcmd(CMD_RUNFOR, chan, length)

    async def ping(self):
        '''Send CMD_PING, which takes no arguments.'''

        return await self._sendcmd(CMD_PING)

    async def terminate(self):
        '''Send CMD_TERMINATE, which takes no arguments.'''

        return await self._sendcmd(CMD_TERMINATE)
-
class SyncDatagram(object):
    '''Base interface for a more simple synchronous interface.

    Subclasses implement recv and send; sendtillrecv is built on top
    of those two.'''

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a datagram.  If timeout is not None, wait that many
        seconds, and if nothing is received in that time, raise a
        TimeoutError exception.'''

        raise NotImplementedError

    async def send(self, data): #pragma: no cover
        '''Send the datagram in data.'''

        raise NotImplementedError

    async def sendtillrecv(self, data, freq):
        '''Send the datagram in data, every freq seconds until a
        datagram is received, and return the received datagram.

        There is no overall timeout: data keeps being resent for as
        long as recv keeps timing out.'''

        while True:
            await self.send(data)
            try:
                return await self.recv(freq)
            except (TimeoutError, asyncio.TimeoutError):
                # Nothing arrived within freq seconds, so go around
                # and send again.  Before Python 3.11,
                # asyncio.TimeoutError (what asyncio.wait_for raises)
                # is a different class than the builtin TimeoutError,
                # so both have to be listed.
                pass
-
class MulticastSyncDatagram(SyncDatagram):
    '''
    An implementation of SyncDatagram that uses the provided
    multicast address maddr as the source/sink of the packets.

    Note that once created, the start coroutine needs to be
    await'd before being passed to a LORANode so that everything
    is running.
    '''

    # Note: sent packets will be received.  To keep them from being
    # passed back up, send records each packet in _ignpkts, and _recv
    # drops (and forgets) a received packet that is found there.

    def __init__(self, maddr):
        '''Remember maddr; no sockets are created until start.'''

        self.maddr = maddr
        self._ignpkts = set()

    async def start(self):
        '''Create the multicast receiver (self.mr) and transmitter
        (self.mt) for maddr.'''

        self.mr = await multicast.create_multicast_receiver(self.maddr)
        self.mt = await multicast.create_multicast_transmitter(
            self.maddr)

    async def _recv(self):
        '''Return the next received packet that is not one recorded
        by send.'''

        while True:
            # only the first element of what the receiver returns
            # is used
            data = (await self.mr.recv())[0]

            if data in self._ignpkts:
                # a copy of something we sent, skip it
                self._ignpkts.remove(data)
                continue

            return data

    async def recv(self, timeout=None): #pragma: no cover
        '''Receive a packet, giving up after timeout seconds when
        timeout is not None (see asyncio.wait_for).'''

        return await asyncio.wait_for(self._recv(), timeout=timeout)

    async def send(self, data): #pragma: no cover
        '''Transmit data, recording it so that the copy that gets
        received is ignored.'''

        self._ignpkts.add(bytes(data))
        await self.mt.send(data)

    def close(self):
        '''Shutdown communications.'''

        # receiver first, then transmitter
        for attr in ('mr', 'mt'):
            getattr(self, attr).close()
            setattr(self, attr, None)
-
class MockSyncDatagram(SyncDatagram):
    '''A testing version of SyncDatagram.  Define a method runner which
    implements part of the sequence.  In the function, await on either
    self.get, to wait for the other side to send something, or await
    self.put w/ data to send.

    The runner is started as a task when the object is created, so
    objects have to be created while an event loop is running.'''

    def __init__(self):
        # what the code under test sends ends up in sendq, what it
        # receives is taken from recvq
        self.sendq = asyncio.Queue()
        self.recvq = asyncio.Queue()
        self.task = asyncio.create_task(self.runner())

        # the runner's view of the two queues
        self.get = self.sendq.get
        self.put = self.recvq.put

    async def drain(self):
        '''Wait for the runner task to finish up, returning (or
        raising) whatever the runner did.'''

        return await self.task

    async def runner(self): #pragma: no cover
        '''The scripted other side, to be provided by a subclass.'''

        raise NotImplementedError

    async def recv(self, timeout=None):
        '''Return the next item the runner put.  The timeout argument
        is ignored.'''

        return await self.recvq.get()

    async def send(self, data):
        '''Queue data for the runner to get.'''

        return await self.sendq.put(data)

    def __del__(self): #pragma: no cover
        # don't leave a still running runner task behind
        task = self.task
        if task is None or task.done():
            return

        task.cancel()
-
class TestSyncData(unittest.IsolatedAsyncioTestCase):
    '''Tests for the SyncDatagram base class.'''

    async def test_syncsendtillrecv(self):
        # A recv that times out once and then delivers a datagram
        # must make sendtillrecv send twice and return the datagram.

        class MySync(SyncDatagram):
            def __init__(self):
                self.sendq = []
                # scripted outcomes of successive recv calls
                self.resp = [ TimeoutError(), b'a' ]

            async def recv(self, timeout=None):
                assert timeout == 1

                outcome = self.resp.pop(0)
                if not isinstance(outcome, Exception):
                    return outcome

                raise outcome

            async def send(self, data):
                self.sendq.append(data)

        sync = MySync()

        received = await sync.sendtillrecv(b'foo', 1)

        self.assertEqual(received, b'a')
        self.assertEqual(sync.sendq, [ b'foo', b'foo' ])
-
class AsyncSequence(object):
    '''
    Object used for sequencing async functions.  To use, use the
    asynchronous context manager created by the sync method.  For
    example:
    seq = AsyncSequence()
    async def func1():
        async with seq.sync(1):
            second_fun()

    async def func2():
        async with seq.sync(0):
            first_fun()

    This will make sure that function first_fun is run before running
    the function second_fun.  If a block raises an Exception, it will
    be passed up, and all blocks that are waiting (and any future
    ones) will raise a CancelledError to help ensure that any tasks
    are properly cleaned up.
    '''

    def __init__(self, positerfactory=lambda: itertools.count()):
        '''The argument positerfactory, is a factory that will
        create an iterator that will be used for the values that
        are passed to the sync method.'''

        self.positer = positerfactory()

        # marker stored in waiting for a position that may run
        # right away
        self.token = object()

        # set once a block has raised an exception
        self.die = False

        # Maps a position to either token (it is that position's
        # turn) or to the future that its block is waiting on.
        self.waiting = { next(self.positer): self.token }

    async def simpsync(self, pos):
        '''Take the turn for pos without doing anything in it.'''

        async with self.sync(pos):
            pass

    @contextlib.asynccontextmanager
    async def sync(self, pos):
        '''An async context manager whose body will be run when its
        turn arrives.  It will only run when all the previous
        items in the iterator have been successfully run.

        Raises CancelledError if an earlier block raised an
        exception, and RuntimeError if another block is already
        waiting for pos.'''

        if self.die:
            raise asyncio.CancelledError('seq cancelled')

        try:
            current = self.waiting[pos]
        except KeyError:
            # not our turn yet, park until a previous block
            # finishes and wakes us up
            fut = asyncio.Future()
            self.waiting[pos] = fut
            await fut
        else:
            if current is not self.token:
                raise RuntimeError('pos already waiting!')

        # our time to shine!
        del self.waiting[pos]

        try:
            yield None
        except Exception:
            # if we got an exception, things went pear shaped,
            # shut everything down, and any future calls.
            self.die = True

            # cancel existing blocks
            while self.waiting:
                _, pending = self.waiting.popitem()
                if pending is not self.token:
                    pending.cancel()

            # populate real exception up
            raise
        else:
            # hand the turn over to the next position
            upcoming = next(self.positer)

            if upcoming in self.waiting:
                # its block is already waiting, wake it up
                self.waiting[upcoming].set_result(None)
            else:
                # nobody there yet, leave a note that it can
                # go right away
                self.waiting[upcoming] = self.token
-
class TestSequencing(unittest.IsolatedAsyncioTestCase):
    '''Tests for AsyncSequence.'''

    @timeout(2)
    async def test_seq_alreadywaiting(self):
        # A second block asking for a position that another block
        # is already waiting on must get a RuntimeError.

        waitseq = AsyncSequence()

        # only used to order the two tasks below
        seq = AsyncSequence()

        async def first_waiter():
            await waitseq.simpsync(1)

        async def second_waiter():
            async with seq.sync(1):
                async with waitseq.sync(1): # pragma: no cover
                    pass

        waiter1 = asyncio.create_task(first_waiter())
        waiter2 = asyncio.create_task(second_waiter())

        # spin things to make sure things advance
        await asyncio.sleep(0)

        await seq.simpsync(0)

        with self.assertRaises(RuntimeError):
            await waiter2

        # let the first waiter have its turn
        await waitseq.simpsync(0)

        await waiter1

    @timeout(2)
    async def test_seqexc(self):
        # seq orders the three functions, excseq is the sequence
        # in which the exception gets raised
        seq = AsyncSequence()

        excseq = AsyncSequence()

        async def raiser():
            await seq.simpsync(1)

            async with excseq.sync(0):
                raise ValueError('foo')

        # that a block that enters first, but runs after
        # raises an exception
        async def already_waiting():
            await seq.simpsync(0)

            async with excseq.sync(1): # pragma: no cover
                pass

        # that a block that enters after, raises an
        # exception
        async def latecomer():
            await seq.simpsync(2)

            async with excseq.sync(2): # pragma: no cover
                pass

        raisertask = asyncio.create_task(raiser())
        waitingtask = asyncio.create_task(already_waiting())
        latetask = asyncio.create_task(latecomer())

        with self.assertRaises(ValueError):
            await raisertask

        for cancelled in (waitingtask, latetask):
            with self.assertRaises(asyncio.CancelledError):
                await cancelled

    @timeout(2)
    async def test_seq(self):
        # test that a seq object when created
        seq = AsyncSequence(lambda: itertools.count(1))

        order = []

        async def fun1():
            for pos in (1, 2, 4):
                async with seq.sync(pos):
                    order.append(pos)

        async def fun2():
            for pos in (3, 6):
                async with seq.sync(pos):
                    order.append(pos)

        async def fun3():
            async with seq.sync(5):
                order.append(5)

        # and various functions are run
        task1 = asyncio.create_task(fun1())
        task2 = asyncio.create_task(fun2())
        task3 = asyncio.create_task(fun3())

        # and the functions complete
        await task3
        await task2
        await task1

        # that the order they ran in was correct
        self.assertEqual(order, list(range(1, 7)))
-
class TestLORANode(unittest.IsolatedAsyncioTestCase):
    '''Run LORANode against a scripted Strobe peer and against the
    lora_comms implementation.'''

    @staticmethod
    def _seed_prng():
        # seed the lora_comms RNG with a fixed value
        from ctypes import c_uint8

        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

    @timeout(2)
    async def test_lora(self):
        testcase = self
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def sendgettest(self, msg):
                '''Send the message, but make sure that if a
                bad message is sent afterward, that it replies
                w/ the same previous message.
                '''

                await self.put(msg)
                first = await self.get()

                await self.put(b'bogusmsg' * 5)

                second = await self.get()

                testcase.assertEqual(first, second)

                return first

            async def runner(self):
                peer = Strobe(domain, F=KeccakF(800))

                peer.key(shared_key)

                def seal(data):
                    # encrypt data and append the MAC
                    return peer.send_enc(data) + peer.send_mac(8)

                def unseal(raw):
                    # decrypt raw and verify its trailing MAC
                    body = peer.recv_enc(raw[:-8])
                    peer.recv_mac(raw[-8:])
                    return body

                # start handshake
                raw = await self.get()

                pkt = unseal(raw)

                assert pkt.endswith(b'reqreset')

                # make sure junk gets ignored, and that the
                # packet remains the same, a few times over
                for junk in (b'sdlfkj', b'0' * 24, b'0' * 32):
                    await self.put(junk)
                    testcase.assertEqual(raw, await self.get())

                # send the response
                await self.put(seal(os.urandom(16)))

                # require no more back tracking at this point
                peer.ratchet()

                # get the confirmation message
                raw = await self.get()

                # test the resend capabilities
                await self.put(b'0' * 24)
                testcase.assertEqual(raw, await self.get())

                # decode confirmation message
                confirm = unseal(raw)

                # assert that we got it
                testcase.assertEqual(confirm, b'confirm')

                # send confirmed reply
                raw = await self.sendgettest(seal(b'confirmed'))

                # test and decode remaining command messages
                cmd = unseal(raw)

                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:],
                    byteorder='little') == 30

                raw = await self.sendgettest(seal(cmd[0:1]))

                cmd = unseal(raw)

                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5],
                    byteorder='little') == 1
                assert int.from_bytes(cmd[5:],
                    byteorder='little') == 50

                raw = await self.sendgettest(seal(cmd[0:1]))

                cmd = unseal(raw)

                assert cmd[0] == CMD_TERMINATE

                await self.put(seal(cmd[0:1]))

        tsd = TestSD()
        node = LORANode(tsd, shared=shared_key)

        await node.start()

        await node.waitfor(30)

        await node.runfor(1, 50)

        await node.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        for queue in (tsd.sendq, tsd.recvq):
            self.assertTrue(queue.empty())

    @timeout(2)
    async def test_ccode_badmsgs(self):
        # Test to make sure that various bad messages in the
        # handshake process are rejected even if the attacker
        # has the correct key.  This just keeps the protocol
        # tight allowing for variations in the future.

        self._seed_prng()

        # Create the state for testing
        commstate = lora_comms.CommsState()

        cb = lora_comms.process_msgfunc_t(lambda msg, outbuf: None)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture, only use it to init crypto state
        tsd = SyncDatagram()
        node = LORANode(tsd, shared=shared_key)

        def reply_to_forged(plaintext):
            # Encrypt plaintext with a copy of the node's crypto
            # state (so the node's own state stays untouched),
            # and return what processing it produces.
            cstate = node.st.copy()
            pkt = cstate.send_enc(plaintext) + \
                cstate.send_mac(node.MAC_LEN)

            return lora_comms.comms_process_wrap(commstate, pkt)

        # incorrect init messages
        for badtail in (b'othre', b' eqreset'):
            self.assertFalse(reply_to_forged(os.urandom(16) +
                badtail))

        # compose the correct init message
        pkt = os.urandom(16) + b'reqreset'
        pkt = node.st.send_enc(pkt) + node.st.send_mac(node.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, pkt)

        node.st.recv_enc(out[:-node.MAC_LEN])
        node.st.recv_mac(out[-node.MAC_LEN:])

        node.st.ratchet()

        # incorrect confirmed messages
        for badconfirm in (b'onfirm', b' onfirm'):
            self.assertFalse(reply_to_forged(badconfirm))

    @timeout(2)
    async def test_ccode(self):
        testcase = self

        self._seed_prng()

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is taken 4 bytes
            # at a time
            args = [ int.from_bytes(msgbuf[off:off + 4],
                byteorder='little') for off in range(1,
                len(msgbuf), 4) ]

            if exptmsgs[0] != (cmd, args): #pragma: no cover
                raise RuntimeError('cmd not found')

            exptmsgs.pop(0)

            # the reply is just the command byte
            outbuf[0].pkt[0] = cmd
            outbuf[0].pktlen = 1

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
                    # get message
                    inmsg = await self.get()

                    # process the test message
                    reply = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure the reply matches length
                    testcase.assertEqual(expectlen, len(reply))

                    # pretend that the reply didn't make it
                    again = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure that the reply matches
                    # the previous
                    testcase.assertEqual(reply, again)

                    # pass the reply back
                    await self.put(again)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        node = LORANode(tsd, shared=shared_key)

        # Send various messages
        await node.start()

        await node.waitfor(30)

        await node.runfor(1, 50)

        await node.ping()

        await node.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        for queue in (tsd.sendq, tsd.recvq):
            self.assertTrue(queue.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)

    @timeout(2)
    async def test_ccode_newsession(self):
        '''This test is to make sure that if an existing session
        is running, that a new session can be established, and that
        when it does, the old session becomes inactive.
        '''

        testcase = self

        seq = AsyncSequence()

        self._seed_prng()

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_WAITFOR, [ 70 ]),
            (CMD_WAITFOR, [ 40 ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is taken 4 bytes
            # at a time
            args = [ int.from_bytes(msgbuf[off:off + 4],
                byteorder='little') for off in range(1,
                len(msgbuf), 4) ]

            if exptmsgs[0] != (cmd, args): #pragma: no cover
                raise RuntimeError('cmd not found: %d' % cmd)

            exptmsgs.pop(0)

            # the reply is just the command byte
            outbuf[0].pkt[0] = cmd
            outbuf[0].pktlen = 1

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class FlipMsg(object):
            async def flipmsg(self):
                # get message
                inmsg = await self.get()

                # process the test message
                reply = lora_comms.comms_process_wrap(
                    commstate, inmsg)

                # pass the reply back
                await self.put(reply)

        # this class always passes messages, this is
        # used for the first session.
        class CCodeSD1(MockSyncDatagram, FlipMsg):
            async def runner(self):
                for _ in range(3):
                    await self.flipmsg()

                async with seq.sync(0):
                    # create bogus message
                    bogus = b'0'*24

                    # process the bogus message
                    reply = lora_comms.comms_process_wrap(
                        commstate, bogus)

                    # make sure there was not a response
                    testcase.assertFalse(reply)

                # The message handled here is only sent once
                # position 1 of seq is running, so this has to
                # be outside of the block for position 0.
                await self.flipmsg()

        # this one is special in that it will pause after the first
        # message to ensure that the previous session will continue
        # to work, AND that if a new "new" session comes along, it
        # will override the previous new session that hasn't been
        # confirmed yet.
        class CCodeSD2(MockSyncDatagram, FlipMsg):
            async def runner(self):
                # pass one message from the new session
                async with seq.sync(1):
                    # There might be a missing case
                    # handled for when the confirmed
                    # message is generated, but lost.
                    await self.flipmsg()

                # and the old session is still active
                await oldnode.waitfor(70)

                async with seq.sync(2):
                    for _ in range(3):
                        await self.flipmsg()

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD1()
        oldnode = LORANode(tsd, shared=shared_key)

        # Send various messages
        await oldnode.start()

        await oldnode.waitfor(30)

        # Ensure that a new one can take over
        tsd2 = CCodeSD2()

        newnode = LORANode(tsd2, shared=shared_key)

        # Send various messages
        await newnode.start()

        await newnode.waitfor(40)

        await newnode.terminate()

        await tsd.drain()
        await tsd2.drain()

        # Make sure all messages have been processed
        for queue in (tsd.sendq, tsd.recvq, tsd2.sendq, tsd2.recvq):
            self.assertTrue(queue.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
-
class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase):
    '''Test LORANode running on top of MulticastSyncDatagram.'''

    # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
    maddr = ('224.0.0.198', 48542)

    @timeout(2)
    async def test_multisyncdgram(self):
        # Test the implementation of the multicast version of
        # SyncDatagram

        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is taken 4 bytes
            # at a time
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)

                # the reply is just the command byte
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # create the object we are testing
        msd = MulticastSyncDatagram(self.maddr)

        # used to hold the node back until the client task has its
        # sockets created
        seq = AsyncSequence()

        async def clienttask():
            # The other end of the link: feed every packet seen on
            # the multicast address to lora_comms, and transmit
            # whatever reply it produces.
            mr = await multicast.create_multicast_receiver(
                self.maddr)
            mt = await multicast.create_multicast_transmitter(
                self.maddr)

            try:
                # signal that the sockets above exist
                await seq.simpsync(0)

                while True:
                    pkt = await mr.recv()
                    msg = pkt[0]

                    out = lora_comms.comms_process_wrap(
                        commstate, msg)

                    if out:
                        await mt.send(out)
            finally:
                mr.close()
                mt.close()

        task = asyncio.create_task(clienttask())

        # start it
        await msd.start()

        # pass it to a node
        node = LORANode(msd, shared=shared_key)

        # wait for the client task to be ready
        await seq.simpsync(1)

        # Send various messages
        await node.start()

        await node.waitfor(30)

        await node.ping()

        await node.terminate()

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)

        # shut things down
        # (the original assigned None to a misspelled name, which
        # left the node referenced)
        node = None
        msd.close()

        task.cancel()

        with self.assertRaises(asyncio.CancelledError):
            await task
|