@@ -23,7 +23,9 @@
#
import asyncio
import contextlib
import functools
import itertools
import os
import unittest
@@ -42,6 +44,7 @@ CMD_TERMINATE = 1 # no args: terminate the sesssion, reply confirms
# The follow commands are queue up, but will be acknoledged when queued
CMD_WAITFOR = 2 # arg: (length): waits for length seconds
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds
CMD_PING = 4 # arg: (): a no op command
class LORANode(object):
'''Implement a LORANode initiator.'''
@@ -62,7 +65,8 @@ class LORANode(object):
pkt = await self.sendrecvvalid(b'confirm')
if pkt != b'confirmed':
raise RuntimeError
raise RuntimeError('got invalid response: %s' %
repr(pkt))
async def sendrecvvalid(self, msg):
msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)
@@ -73,6 +77,10 @@ class LORANode(object):
resp = await self.sd.sendtillrecv(msg, 1)
#_debprint('got:', resp)
# skip empty messages
if len(resp) == 0:
continue
try:
decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
self.st.recv_mac(resp[-self.MAC_LEN:])
@@ -110,6 +118,9 @@ class LORANode(object):
async def runfor(self, chan, length):
return await self._sendcmd(CMD_RUNFOR, chan, length)
async def ping(self):
return await self._sendcmd(CMD_PING)
async def terminate(self):
return await self._sendcmd(CMD_TERMINATE)
@@ -127,7 +138,6 @@ class SyncDatagram(object):
raise NotImplementedError
async def send(self, data): #pragma: no cover
'''Send a datagram.'''
raise NotImplementedError
@@ -153,7 +163,6 @@ class MockSyncDatagram(SyncDatagram):
def __init__(self):
self.sendq = asyncio.Queue()
self.recvq = asyncio.Queue()
self.task = None
self.task = asyncio.create_task(self.runner())
self.get = self.sendq.get
@@ -225,6 +234,211 @@ def _debprint(*args): # pragma: no cover
*args)
sys.stdout.flush()
class AsyncSequence(object):
'''
Object used for sequencing async functions. To use, use the
asynchronous context manager created by the sync method. For
example:
seq = AsyncSequence()
async func1():
async with seq.sync(1):
second_fun()
async func2():
async with seq.sync(0):
first_fun()
This will make sure that function first_fun is run before running
the function second_fun. If a previous block raises an Exception,
it will be passed up, and all remaining blocks (and future ones)
will raise a CancelledError to help ensure that any tasks are
properly cleaned up.
'''
def __init__(self, positerfactory=lambda: itertools.count()):
'''The argument positerfactory, is a factory that will
create an iterator that will be used for the values that
are passed to the sync method.'''
self.positer = positerfactory()
self.token = object()
self.die = False
self.waiting = {
next(self.positer): self.token
}
@contextlib.asynccontextmanager
async def sync(self, pos):
'''An async context manager that will be run when it's
turn arrives. It will only run when all the previous
items in the iterator has been successfully run.'''
if self.die:
raise asyncio.CancelledError('seq cancelled')
if pos in self.waiting:
if self.waiting[pos] is not self.token:
raise RuntimeError('pos already waiting!')
else:
fut = asyncio.Future()
self.waiting[pos] = fut
await fut
# our time to shine!
del self.waiting[pos]
try:
yield None
except Exception as e:
# if we got an exception, things went pear shaped,
# shut everything down, and any future calls.
#_debprint('dieing...', repr(e))
self.die = True
# cancel existing blocks
while self.waiting:
k, v = self.waiting.popitem()
#_debprint('canceling: %s' % repr(k))
if v is self.token:
continue
# for Python 3.9:
# msg='pos %s raised exception: %s' %
# (repr(pos), repr(e))
v.cancel()
# populate real exception up
raise
else:
# handle next
nextpos = next(self.positer)
if nextpos in self.waiting:
#_debprint('np:', repr(self), nextpos,
# repr(self.waiting[nextpos]))
self.waiting[nextpos].set_result(None)
else:
self.waiting[nextpos] = self.token
class TestSequencing(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_seq_alreadywaiting(self):
waitseq = AsyncSequence()
seq = AsyncSequence()
async def fun1():
async with waitseq.sync(1):
pass
async def fun2():
async with seq.sync(1):
async with waitseq.sync(1): # pragma: no cover
pass
task1 = asyncio.create_task(fun1())
task2 = asyncio.create_task(fun2())
# spin things to make sure things advance
await asyncio.sleep(0)
async with seq.sync(0):
pass
with self.assertRaises(RuntimeError):
await task2
async with waitseq.sync(0):
pass
await task1
@timeout(2)
async def test_seqexc(self):
seq = AsyncSequence()
excseq = AsyncSequence()
async def excfun1():
async with seq.sync(1):
pass
async with excseq.sync(0):
raise ValueError('foo')
# that a block that enters first, but runs after
# raises an exception
async def excfun2():
async with seq.sync(0):
pass
async with excseq.sync(1): # pragma: no cover
pass
# that a block that enters after, raises an
# exception
async def excfun3():
async with seq.sync(2):
pass
async with excseq.sync(2): # pragma: no cover
pass
task1 = asyncio.create_task(excfun1())
task2 = asyncio.create_task(excfun2())
task3 = asyncio.create_task(excfun3())
with self.assertRaises(ValueError):
await task1
with self.assertRaises(asyncio.CancelledError):
await task2
with self.assertRaises(asyncio.CancelledError):
await task3
@timeout(2)
async def test_seq(self):
# test that a seq object when created
seq = AsyncSequence(lambda: itertools.count(1))
col = []
async def fun1():
async with seq.sync(1):
col.append(1)
async with seq.sync(2):
col.append(2)
async with seq.sync(4):
col.append(4)
async def fun2():
async with seq.sync(3):
col.append(3)
async with seq.sync(6):
col.append(6)
async def fun3():
async with seq.sync(5):
col.append(5)
# and various functions are run
task1 = asyncio.create_task(fun1())
task2 = asyncio.create_task(fun2())
task3 = asyncio.create_task(fun3())
# and the functions complete
await task3
await task2
await task1
# that the order they ran in was correct
self.assertEqual(col, list(range(1, 7)))
class TestLORANode(unittest.IsolatedAsyncioTestCase):
@timeout(2)
async def test_lora(self):
@@ -365,11 +579,11 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
exptmsgs = [
(CMD_WAITFOR, [ 30 ]),
(CMD_RUNFOR, [ 1, 50 ]),
(CMD_PING, [ ]),
(CMD_TERMINATE, [ ]),
]
def procmsg(msg, outbuf):
msgbuf = msg._from()
#print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf))
cmd = msgbuf[0]
args = [ int.from_bytes(msgbuf[x:x + 4],
byteorder='little') for x in range(1, len(msgbuf),
@@ -387,7 +601,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
class CCodeSD(MockSyncDatagram):
async def runner(self):
for expectlen in [ 24, 17, 9, 9, 9 ]:
for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
# get message
gb = await self.get()
r = make_pktbuf(gb)
@@ -438,6 +652,8 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
await l.runfor(1, 50)
await l.ping()
await l.terminate()
await tsd.drain()
@@ -450,3 +666,148 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase):
# processed.
self.assertFalse(exptmsgs)
#_debprint('done')
@timeout(2)
async def test_ccode_newsession(self):
'''This test is to make sure that if an existing session
is running, that a new session can be established, and that
when it does, the old session becomes inactive.
'''
_self = self
from ctypes import pointer, sizeof, c_uint8
seq = AsyncSequence()
# seed the RNG
prngseed = b'abc123'
lora_comms.strobe_seed_prng((c_uint8 *
len(prngseed))(*prngseed), len(prngseed))
# Create the state for testing
commstate = lora_comms.CommsState()
# These are the expected messages and their arguments
exptmsgs = [
(CMD_WAITFOR, [ 30 ]),
(CMD_WAITFOR, [ 70 ]),
(CMD_WAITFOR, [ 40 ]),
(CMD_TERMINATE, [ ]),
]
def procmsg(msg, outbuf):
msgbuf = msg._from()
cmd = msgbuf[0]
args = [ int.from_bytes(msgbuf[x:x + 4],
byteorder='little') for x in range(1, len(msgbuf),
4) ]
if exptmsgs[0] == (cmd, args):
exptmsgs.pop(0)
outbuf[0].pkt[0] = cmd
outbuf[0].pktlen = 1
else: #pragma: no cover
raise RuntimeError('cmd not found: %d' % cmd)
# wrap the callback function
cb = lora_comms.process_msgfunc_t(procmsg)
class FlipMsg(object):
async def flipmsg(self):
# get message
gb = await self.get()
r = make_pktbuf(gb)
outbytes = bytearray(64)
outbuf = make_pktbuf(outbytes)
# process the test message
lora_comms.comms_process(commstate, r,
outbuf)
# pass the reply back
pkt = outbytes[:outbuf.pktlen]
await self.put(pkt)
# this class always passes messages, this is
# used for the first session.
class CCodeSD1(MockSyncDatagram, FlipMsg):
async def runner(self):
for i in range(3):
await self.flipmsg()
async with seq.sync(0):
# create bogus message
r = make_pktbuf(b'0'*24)
outbytes = bytearray(64)
outbuf = make_pktbuf(outbytes)
# process the bogus message
lora_comms.comms_process(commstate, r,
outbuf)
# make sure there was not a response
_self.assertEqual(outbuf.pktlen, 0)
await self.flipmsg()
# this one is special in that it will pause after the first
# message to ensure that the previous session will continue
# to work, AND that if a new "new" session comes along, it
# will override the previous new session that hasn't been
# confirmed yet.
class CCodeSD2(MockSyncDatagram, FlipMsg):
async def runner(self):
# pass one message from the new session
async with seq.sync(1):
# There might be a missing case
# handled for when the confirmed
# message is generated, but lost.
await self.flipmsg()
# and the old session is still active
await l.waitfor(70)
async with seq.sync(2):
for i in range(3):
await self.flipmsg()
# Generate shared key
shared_key = os.urandom(32)
# Initialize everything
lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))
# Create test fixture
tsd = CCodeSD1()
l = LORANode(tsd, shared=shared_key)
# Send various messages
await l.start()
await l.waitfor(30)
# Ensure that a new one can take over
tsd2 = CCodeSD2()
l2 = LORANode(tsd2, shared=shared_key)
# Send various messages
await l2.start()
await l2.waitfor(40)
await l2.terminate()
await tsd.drain()
await tsd2.drain()
# Make sure all messages have been processed
self.assertTrue(tsd.sendq.empty())
self.assertTrue(tsd.recvq.empty())
self.assertTrue(tsd2.sendq.empty())
self.assertTrue(tsd2.recvq.empty())
# Make sure all the expected messages have been
# processed.
self.assertFalse(exptmsgs)