This also adds a check to make sure that the allocated structure matches the C code so that things won't break. This breaks out the session state into it's own object... allowing a common function to handle the state machine... This also adds a new powerful testing tool. It's a Synchronization primitive that will ensure blocks of code run in the expected order, and only run when the previous block has fully run... This helps ensure ordering between many tasks, to cause race conditions that would otherwise be hard to cause..irr_shared
@@ -44,6 +44,13 @@ comms_pktbuf_equal(struct pktbuf a, struct pktbuf b) | |||
return memcmp(a.pkt, b.pkt, a.pktlen) == 0; | |||
} | |||
size_t | |||
_comms_state_size() | |||
{ | |||
return sizeof(struct comms_state); | |||
} | |||
size_t | |||
_strobe_state_size() | |||
{ | |||
@@ -56,7 +63,6 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared) | |||
{ | |||
*cs = (struct comms_state){ | |||
.cs_comm_state = COMMS_WAIT_REQUEST, | |||
.cs_procmsg = pmf, | |||
}; | |||
@@ -66,45 +72,44 @@ comms_init(struct comms_state *cs, process_msgfunc_t pmf, struct pktbuf *shared) | |||
strobe_key(&cs->cs_start, SYM_KEY, shared->pkt, shared->pktlen); | |||
/* copy starting state over to initial state */ | |||
cs->cs_state = cs->cs_start; | |||
cs->cs_active = (struct comms_session){ | |||
.cs_crypto = cs->cs_start, | |||
.cs_state = COMMS_WAIT_REQUEST, | |||
}; | |||
cs->cs_pending = cs->cs_active; | |||
} | |||
#define CONFIRMED_STR_BASE "confirmed" | |||
#define CONFIRMED_STR ((const uint8_t *)CONFIRMED_STR_BASE) | |||
#define CONFIRMED_STR_LEN (sizeof(CONFIRMED_STR_BASE) - 1) | |||
/* | |||
* encrypted data to be processed is passed in via pbin. | |||
* | |||
* The pktbuf pointed to by pbout contains the buffer that a [encrypted] | |||
* response will be written to. The length needs to be updated, where 0 | |||
* means no reply. | |||
*/ | |||
void | |||
comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
static void | |||
_comms_process_session(struct comms_state *cs, struct comms_session *sess, struct pktbuf pbin, struct pktbuf *pbout) | |||
{ | |||
strobe_s tmp; | |||
uint8_t buf[64] = {}; | |||
struct pktbuf pbmsg, pbrep; | |||
ssize_t cnt, ret, msglen; | |||
/* if the current msg matches the previous */ | |||
if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) { | |||
/* send the previous response */ | |||
pbout->pktlen = cs->cs_prevmsgresp.pktlen; | |||
memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen); | |||
return; | |||
} | |||
/* save the state incase the message is bad */ | |||
tmp = sess->cs_crypto; | |||
strobe_attach_buffer(&cs->cs_state, pbin.pkt, pbin.pktlen); | |||
strobe_attach_buffer(&sess->cs_crypto, pbin.pkt, pbin.pktlen); | |||
cnt = strobe_get(&cs->cs_state, APP_CIPHERTEXT, buf, pbin.pktlen - | |||
cnt = strobe_get(&sess->cs_crypto, APP_CIPHERTEXT, buf, pbin.pktlen - | |||
MAC_LEN); | |||
msglen = cnt; | |||
cnt = strobe_get(&cs->cs_state, MAC, pbin.pkt + | |||
cnt = strobe_get(&sess->cs_crypto, MAC, pbin.pkt + | |||
(pbin.pktlen - MAC_LEN), MAC_LEN); | |||
/* XXX - cnt != MAC_LEN test case */ | |||
/* MAC check failed */ | |||
if (cnt == -1) { | |||
/* restore the previous state */ | |||
sess->cs_crypto = tmp; | |||
pbout->pktlen = 0; | |||
return; | |||
} | |||
/* | |||
* if we have arrived here, MAC has been verified, and buf now | |||
@@ -112,29 +117,29 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
*/ | |||
/* attach the buffer for output */ | |||
strobe_attach_buffer(&cs->cs_state, pbout->pkt, pbout->pktlen); | |||
strobe_attach_buffer(&sess->cs_crypto, pbout->pkt, pbout->pktlen); | |||
ret = 0; | |||
switch (cs->cs_comm_state) { | |||
switch (sess->cs_state) { | |||
case COMMS_WAIT_REQUEST: | |||
/* XXX - reqreset check */ | |||
bare_strobe_randomize(buf, CHALLENGE_LEN); | |||
ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, buf, | |||
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, buf, | |||
CHALLENGE_LEN); | |||
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); | |||
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); | |||
strobe_operate(&cs->cs_state, RATCHET, NULL, 32); | |||
strobe_operate(&sess->cs_crypto, RATCHET, NULL, 32); | |||
cs->cs_comm_state = COMMS_WAIT_CONFIRM; | |||
sess->cs_state = COMMS_WAIT_CONFIRM; | |||
break; | |||
case COMMS_WAIT_CONFIRM: | |||
/* XXX - confirm check */ | |||
ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, CONFIRMED_STR, | |||
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, CONFIRMED_STR, | |||
CONFIRMED_STR_LEN); | |||
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); | |||
cs->cs_comm_state = COMMS_PROCESS_MSGS; | |||
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); | |||
sess->cs_state = COMMS_PROCESS_MSGS; | |||
break; | |||
case COMMS_PROCESS_MSGS: { | |||
@@ -150,9 +155,9 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
cs->cs_procmsg(pbmsg, &pbrep); | |||
ret = strobe_put(&cs->cs_state, APP_CIPHERTEXT, repbuf, | |||
ret = strobe_put(&sess->cs_crypto, APP_CIPHERTEXT, repbuf, | |||
pbrep.pktlen); | |||
ret += strobe_put(&cs->cs_state, MAC, NULL, MAC_LEN); | |||
ret += strobe_put(&sess->cs_crypto, MAC, NULL, MAC_LEN); | |||
break; | |||
} | |||
@@ -161,8 +166,36 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
/* set the output buffer length */ | |||
pbout->pktlen = ret; | |||
if (ret != 0) { | |||
} | |||
/* | |||
* encrypted data to be processed is passed in via pbin. | |||
* | |||
* The pktbuf pointed to by pbout contains the buffer that a [encrypted] | |||
* response will be written to. The length needs to be updated, where 0 | |||
* means no reply. | |||
*/ | |||
void | |||
comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
{ | |||
struct pktbuf pbouttmp; | |||
/* if the current msg matches the previous */ | |||
if (comms_pktbuf_equal(pbin, cs->cs_prevmsg)) { | |||
/* send the previous response */ | |||
pbout->pktlen = cs->cs_prevmsgresp.pktlen; | |||
memcpy(pbout->pkt, cs->cs_prevmsgresp.pkt, pbout->pktlen); | |||
return; | |||
} | |||
/* try to use the active session */ | |||
pbouttmp = *pbout; | |||
_comms_process_session(cs, &cs->cs_active, pbin, &pbouttmp); | |||
if (pbouttmp.pktlen != 0) { | |||
retmsg: | |||
/* we accepted a new message store it */ | |||
*pbout = pbouttmp; | |||
/* store the req */ | |||
cs->cs_prevmsg.pkt = cs->cs_prevmsgbuf; | |||
@@ -173,5 +206,39 @@ comms_process(struct comms_state *cs, struct pktbuf pbin, struct pktbuf *pbout) | |||
cs->cs_prevmsgresp.pkt = cs->cs_prevmsgrespbuf; | |||
cs->cs_prevmsgresp.pktlen = pbout->pktlen; | |||
memcpy(cs->cs_prevmsgresp.pkt, pbout->pkt, pbout->pktlen); | |||
} else { | |||
/* active session didn't work, try cs_pending */ | |||
pbouttmp = *pbout; | |||
_comms_process_session(cs, &cs->cs_pending, pbin, &pbouttmp); | |||
if (cs->cs_pending.cs_state == COMMS_PROCESS_MSGS) { | |||
/* new active state */ | |||
cs->cs_active = cs->cs_pending; | |||
cs->cs_pending = (struct comms_session){ | |||
.cs_crypto = cs->cs_start, | |||
.cs_state = COMMS_WAIT_REQUEST, | |||
}; | |||
goto retmsg; | |||
} | |||
/* pending session didn't work, maybe new */ | |||
struct comms_session tmpsess; | |||
tmpsess = (struct comms_session){ | |||
.cs_crypto = cs->cs_start, | |||
.cs_state = COMMS_WAIT_REQUEST, | |||
}; | |||
pbouttmp = *pbout; | |||
_comms_process_session(cs, &tmpsess, pbin, &pbouttmp); | |||
if (tmpsess.cs_state == COMMS_WAIT_CONFIRM) { | |||
/* new request for session */ | |||
cs->cs_pending = tmpsess; | |||
*pbout = pbouttmp; | |||
} else { | |||
/* no packet to reply with */ | |||
pbout->pktlen = 0; | |||
} | |||
} | |||
} |
@@ -45,9 +45,33 @@ enum comm_state { | |||
COMMS_PROCESS_MSGS, | |||
}; | |||
struct comms_session { | |||
strobe_s cs_crypto; | |||
enum comm_state cs_state; | |||
}; | |||
/* | |||
* Each message will be passed to each state. | |||
* | |||
* cs_active can be in any state. | |||
* cs_pending can only be in a _WAIT_* state. | |||
* | |||
* When cs_pending advances to _PROCESS_MSGS, it will | |||
* replace cs_active, and cs_pending w/ be copied from cache | |||
* and set to _WAIT_REQUEST. | |||
* | |||
* If any message was not processed by the first to, a new session | |||
* will be attempted w/ the _start crypto state, and if it progresses | |||
* to _WAIT_CONFIG, it will replace cs_pending. | |||
* | |||
* We don't have to save the reply from a new session, because if the | |||
* reply gets lost, the initiator will send the request again and we'll | |||
* restart the session. | |||
*/ | |||
struct comms_state { | |||
strobe_s cs_state; | |||
enum comm_state cs_comm_state; | |||
struct comms_session cs_active; /* current active session */ | |||
struct comms_session cs_pending; /* current pending session */ | |||
strobe_s cs_start; /* special starting state cache */ | |||
process_msgfunc_t cs_procmsg; | |||
@@ -60,6 +84,7 @@ struct comms_state { | |||
}; | |||
size_t _strobe_state_size(); | |||
size_t _comms_state_size(); | |||
void comms_init(struct comms_state *, process_msgfunc_t, struct pktbuf *); | |||
void comms_process(struct comms_state *, struct pktbuf, struct pktbuf *); |
@@ -23,7 +23,9 @@ | |||
# | |||
import asyncio | |||
import contextlib | |||
import functools | |||
import itertools | |||
import os | |||
import unittest | |||
@@ -42,6 +44,7 @@ CMD_TERMINATE = 1 # no args: terminate the sesssion, reply confirms | |||
# The follow commands are queue up, but will be acknoledged when queued | |||
CMD_WAITFOR = 2 # arg: (length): waits for length seconds | |||
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds | |||
CMD_PING = 4 # arg: (): a no op command | |||
class LORANode(object): | |||
'''Implement a LORANode initiator.''' | |||
@@ -62,7 +65,8 @@ class LORANode(object): | |||
pkt = await self.sendrecvvalid(b'confirm') | |||
if pkt != b'confirmed': | |||
raise RuntimeError | |||
raise RuntimeError('got invalid response: %s' % | |||
repr(pkt)) | |||
async def sendrecvvalid(self, msg): | |||
msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN) | |||
@@ -73,6 +77,10 @@ class LORANode(object): | |||
resp = await self.sd.sendtillrecv(msg, 1) | |||
#_debprint('got:', resp) | |||
# skip empty messages | |||
if len(resp) == 0: | |||
continue | |||
try: | |||
decmsg = self.st.recv_enc(resp[:-self.MAC_LEN]) | |||
self.st.recv_mac(resp[-self.MAC_LEN:]) | |||
@@ -110,6 +118,9 @@ class LORANode(object): | |||
async def runfor(self, chan, length): | |||
return await self._sendcmd(CMD_RUNFOR, chan, length) | |||
async def ping(self): | |||
return await self._sendcmd(CMD_PING) | |||
async def terminate(self): | |||
return await self._sendcmd(CMD_TERMINATE) | |||
@@ -127,7 +138,6 @@ class SyncDatagram(object): | |||
raise NotImplementedError | |||
async def send(self, data): #pragma: no cover | |||
'''Send a datagram.''' | |||
raise NotImplementedError | |||
@@ -153,7 +163,6 @@ class MockSyncDatagram(SyncDatagram): | |||
def __init__(self): | |||
self.sendq = asyncio.Queue() | |||
self.recvq = asyncio.Queue() | |||
self.task = None | |||
self.task = asyncio.create_task(self.runner()) | |||
self.get = self.sendq.get | |||
@@ -225,6 +234,211 @@ def _debprint(*args): # pragma: no cover | |||
*args) | |||
sys.stdout.flush() | |||
class AsyncSequence(object): | |||
''' | |||
Object used for sequencing async functions. To use, use the | |||
asynchronous context manager created by the sync method. For | |||
example: | |||
seq = AsyncSequence() | |||
async func1(): | |||
async with seq.sync(1): | |||
second_fun() | |||
async func2(): | |||
async with seq.sync(0): | |||
first_fun() | |||
This will make sure that function first_fun is run before running | |||
the function second_fun. If a previous block raises an Exception, | |||
it will be passed up, and all remaining blocks (and future ones) | |||
will raise a CancelledError to help ensure that any tasks are | |||
properly cleaned up. | |||
''' | |||
def __init__(self, positerfactory=lambda: itertools.count()): | |||
'''The argument positerfactory, is a factory that will | |||
create an iterator that will be used for the values that | |||
are passed to the sync method.''' | |||
self.positer = positerfactory() | |||
self.token = object() | |||
self.die = False | |||
self.waiting = { | |||
next(self.positer): self.token | |||
} | |||
@contextlib.asynccontextmanager | |||
async def sync(self, pos): | |||
'''An async context manager that will be run when it's | |||
turn arrives. It will only run when all the previous | |||
items in the iterator has been successfully run.''' | |||
if self.die: | |||
raise asyncio.CancelledError('seq cancelled') | |||
if pos in self.waiting: | |||
if self.waiting[pos] is not self.token: | |||
raise RuntimeError('pos already waiting!') | |||
else: | |||
fut = asyncio.Future() | |||
self.waiting[pos] = fut | |||
await fut | |||
# our time to shine! | |||
del self.waiting[pos] | |||
try: | |||
yield None | |||
except Exception as e: | |||
# if we got an exception, things went pear shaped, | |||
# shut everything down, and any future calls. | |||
#_debprint('dieing...', repr(e)) | |||
self.die = True | |||
# cancel existing blocks | |||
while self.waiting: | |||
k, v = self.waiting.popitem() | |||
#_debprint('canceling: %s' % repr(k)) | |||
if v is self.token: | |||
continue | |||
# for Python 3.9: | |||
# msg='pos %s raised exception: %s' % | |||
# (repr(pos), repr(e)) | |||
v.cancel() | |||
# populate real exception up | |||
raise | |||
else: | |||
# handle next | |||
nextpos = next(self.positer) | |||
if nextpos in self.waiting: | |||
#_debprint('np:', repr(self), nextpos, | |||
# repr(self.waiting[nextpos])) | |||
self.waiting[nextpos].set_result(None) | |||
else: | |||
self.waiting[nextpos] = self.token | |||
class TestSequencing(unittest.IsolatedAsyncioTestCase): | |||
@timeout(2) | |||
async def test_seq_alreadywaiting(self): | |||
waitseq = AsyncSequence() | |||
seq = AsyncSequence() | |||
async def fun1(): | |||
async with waitseq.sync(1): | |||
pass | |||
async def fun2(): | |||
async with seq.sync(1): | |||
async with waitseq.sync(1): # pragma: no cover | |||
pass | |||
task1 = asyncio.create_task(fun1()) | |||
task2 = asyncio.create_task(fun2()) | |||
# spin things to make sure things advance | |||
await asyncio.sleep(0) | |||
async with seq.sync(0): | |||
pass | |||
with self.assertRaises(RuntimeError): | |||
await task2 | |||
async with waitseq.sync(0): | |||
pass | |||
await task1 | |||
@timeout(2) | |||
async def test_seqexc(self): | |||
seq = AsyncSequence() | |||
excseq = AsyncSequence() | |||
async def excfun1(): | |||
async with seq.sync(1): | |||
pass | |||
async with excseq.sync(0): | |||
raise ValueError('foo') | |||
# that a block that enters first, but runs after | |||
# raises an exception | |||
async def excfun2(): | |||
async with seq.sync(0): | |||
pass | |||
async with excseq.sync(1): # pragma: no cover | |||
pass | |||
# that a block that enters after, raises an | |||
# exception | |||
async def excfun3(): | |||
async with seq.sync(2): | |||
pass | |||
async with excseq.sync(2): # pragma: no cover | |||
pass | |||
task1 = asyncio.create_task(excfun1()) | |||
task2 = asyncio.create_task(excfun2()) | |||
task3 = asyncio.create_task(excfun3()) | |||
with self.assertRaises(ValueError): | |||
await task1 | |||
with self.assertRaises(asyncio.CancelledError): | |||
await task2 | |||
with self.assertRaises(asyncio.CancelledError): | |||
await task3 | |||
@timeout(2) | |||
async def test_seq(self): | |||
# test that a seq object when created | |||
seq = AsyncSequence(lambda: itertools.count(1)) | |||
col = [] | |||
async def fun1(): | |||
async with seq.sync(1): | |||
col.append(1) | |||
async with seq.sync(2): | |||
col.append(2) | |||
async with seq.sync(4): | |||
col.append(4) | |||
async def fun2(): | |||
async with seq.sync(3): | |||
col.append(3) | |||
async with seq.sync(6): | |||
col.append(6) | |||
async def fun3(): | |||
async with seq.sync(5): | |||
col.append(5) | |||
# and various functions are run | |||
task1 = asyncio.create_task(fun1()) | |||
task2 = asyncio.create_task(fun2()) | |||
task3 = asyncio.create_task(fun3()) | |||
# and the functions complete | |||
await task3 | |||
await task2 | |||
await task1 | |||
# that the order they ran in was correct | |||
self.assertEqual(col, list(range(1, 7))) | |||
class TestLORANode(unittest.IsolatedAsyncioTestCase): | |||
@timeout(2) | |||
async def test_lora(self): | |||
@@ -365,11 +579,11 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): | |||
exptmsgs = [ | |||
(CMD_WAITFOR, [ 30 ]), | |||
(CMD_RUNFOR, [ 1, 50 ]), | |||
(CMD_PING, [ ]), | |||
(CMD_TERMINATE, [ ]), | |||
] | |||
def procmsg(msg, outbuf): | |||
msgbuf = msg._from() | |||
#print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf)) | |||
cmd = msgbuf[0] | |||
args = [ int.from_bytes(msgbuf[x:x + 4], | |||
byteorder='little') for x in range(1, len(msgbuf), | |||
@@ -387,7 +601,7 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): | |||
class CCodeSD(MockSyncDatagram): | |||
async def runner(self): | |||
for expectlen in [ 24, 17, 9, 9, 9 ]: | |||
for expectlen in [ 24, 17, 9, 9, 9, 9 ]: | |||
# get message | |||
gb = await self.get() | |||
r = make_pktbuf(gb) | |||
@@ -438,6 +652,8 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): | |||
await l.runfor(1, 50) | |||
await l.ping() | |||
await l.terminate() | |||
await tsd.drain() | |||
@@ -450,3 +666,148 @@ class TestLORANode(unittest.IsolatedAsyncioTestCase): | |||
# processed. | |||
self.assertFalse(exptmsgs) | |||
#_debprint('done') | |||
@timeout(2) | |||
async def test_ccode_newsession(self): | |||
'''This test is to make sure that if an existing session | |||
is running, that a new session can be established, and that | |||
when it does, the old session becomes inactive. | |||
''' | |||
_self = self | |||
from ctypes import pointer, sizeof, c_uint8 | |||
seq = AsyncSequence() | |||
# seed the RNG | |||
prngseed = b'abc123' | |||
lora_comms.strobe_seed_prng((c_uint8 * | |||
len(prngseed))(*prngseed), len(prngseed)) | |||
# Create the state for testing | |||
commstate = lora_comms.CommsState() | |||
# These are the expected messages and their arguments | |||
exptmsgs = [ | |||
(CMD_WAITFOR, [ 30 ]), | |||
(CMD_WAITFOR, [ 70 ]), | |||
(CMD_WAITFOR, [ 40 ]), | |||
(CMD_TERMINATE, [ ]), | |||
] | |||
def procmsg(msg, outbuf): | |||
msgbuf = msg._from() | |||
cmd = msgbuf[0] | |||
args = [ int.from_bytes(msgbuf[x:x + 4], | |||
byteorder='little') for x in range(1, len(msgbuf), | |||
4) ] | |||
if exptmsgs[0] == (cmd, args): | |||
exptmsgs.pop(0) | |||
outbuf[0].pkt[0] = cmd | |||
outbuf[0].pktlen = 1 | |||
else: #pragma: no cover | |||
raise RuntimeError('cmd not found: %d' % cmd) | |||
# wrap the callback function | |||
cb = lora_comms.process_msgfunc_t(procmsg) | |||
class FlipMsg(object): | |||
async def flipmsg(self): | |||
# get message | |||
gb = await self.get() | |||
r = make_pktbuf(gb) | |||
outbytes = bytearray(64) | |||
outbuf = make_pktbuf(outbytes) | |||
# process the test message | |||
lora_comms.comms_process(commstate, r, | |||
outbuf) | |||
# pass the reply back | |||
pkt = outbytes[:outbuf.pktlen] | |||
await self.put(pkt) | |||
# this class always passes messages, this is | |||
# used for the first session. | |||
class CCodeSD1(MockSyncDatagram, FlipMsg): | |||
async def runner(self): | |||
for i in range(3): | |||
await self.flipmsg() | |||
async with seq.sync(0): | |||
# create bogus message | |||
r = make_pktbuf(b'0'*24) | |||
outbytes = bytearray(64) | |||
outbuf = make_pktbuf(outbytes) | |||
# process the bogus message | |||
lora_comms.comms_process(commstate, r, | |||
outbuf) | |||
# make sure there was not a response | |||
_self.assertEqual(outbuf.pktlen, 0) | |||
await self.flipmsg() | |||
# this one is special in that it will pause after the first | |||
# message to ensure that the previous session will continue | |||
# to work, AND that if a new "new" session comes along, it | |||
# will override the previous new session that hasn't been | |||
# confirmed yet. | |||
class CCodeSD2(MockSyncDatagram, FlipMsg): | |||
async def runner(self): | |||
# pass one message from the new session | |||
async with seq.sync(1): | |||
# There might be a missing case | |||
# handled for when the confirmed | |||
# message is generated, but lost. | |||
await self.flipmsg() | |||
# and the old session is still active | |||
await l.waitfor(70) | |||
async with seq.sync(2): | |||
for i in range(3): | |||
await self.flipmsg() | |||
# Generate shared key | |||
shared_key = os.urandom(32) | |||
# Initialize everything | |||
lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key)) | |||
# Create test fixture | |||
tsd = CCodeSD1() | |||
l = LORANode(tsd, shared=shared_key) | |||
# Send various messages | |||
await l.start() | |||
await l.waitfor(30) | |||
# Ensure that a new one can take over | |||
tsd2 = CCodeSD2() | |||
l2 = LORANode(tsd2, shared=shared_key) | |||
# Send various messages | |||
await l2.start() | |||
await l2.waitfor(40) | |||
await l2.terminate() | |||
await tsd.drain() | |||
await tsd2.drain() | |||
# Make sure all messages have been processed | |||
self.assertTrue(tsd.sendq.empty()) | |||
self.assertTrue(tsd.recvq.empty()) | |||
self.assertTrue(tsd2.sendq.empty()) | |||
self.assertTrue(tsd2.recvq.empty()) | |||
# Make sure all the expected messages have been | |||
# processed. | |||
self.assertFalse(exptmsgs) |
@@ -22,10 +22,15 @@ | |||
# SUCH DAMAGE. | |||
# | |||
from ctypes import Structure, POINTER, CFUNCTYPE, pointer | |||
from ctypes import Structure, POINTER, CFUNCTYPE, pointer, sizeof | |||
from ctypes import c_uint8, c_uint16, c_ssize_t, c_size_t, c_uint64, c_int | |||
from ctypes import CDLL | |||
class StructureRepr(object): | |||
def __repr__(self): #pragma: no cover | |||
return '%s(%s)' % (self.__class__.__name__, ', '.join('%s=%s' % | |||
(k, getattr(self, k)) for k, v in self._fields_)) | |||
class PktBuf(Structure): | |||
_fields_ = [ | |||
('pkt', POINTER(c_uint8)), | |||
@@ -63,11 +68,17 @@ _lib._strobe_state_size.restype = c_size_t | |||
_lib._strobe_state_size.argtypes = () | |||
_strobe_state_u64_cnt = (_lib._strobe_state_size() + 7) // 8 | |||
class CommsState(Structure): | |||
class CommsSession(Structure,StructureRepr): | |||
_fields_ = [ | |||
('cs_crypto', c_uint64 * _strobe_state_u64_cnt), | |||
('cs_state', c_int), | |||
] | |||
class CommsState(Structure,StructureRepr): | |||
_fields_ = [ | |||
# The alignment of these may be off | |||
('cs_state', c_uint64 * _strobe_state_u64_cnt), | |||
('cs_comm_state', c_int), | |||
('cs_active', CommsSession), | |||
('cs_pending', CommsSession), | |||
('cs_start', c_uint64 * _strobe_state_u64_cnt), | |||
('cs_procmsg', process_msgfunc_t), | |||
@@ -78,8 +89,15 @@ class CommsState(Structure): | |||
('cs_prevmsgrespbuf', c_uint8 * 64), | |||
] | |||
_lib._comms_state_size.restype = c_size_t | |||
_lib._comms_state_size.argtypes = () | |||
if _lib._comms_state_size() != sizeof(CommsState): # pragma: no cover | |||
raise RuntimeError('CommsState structure size mismatch!') | |||
for func, ret, args in [ | |||
('comms_init', None, (POINTER(CommsState), process_msgfunc_t, POINTER(PktBuf))), | |||
('comms_init', None, (POINTER(CommsState), process_msgfunc_t, | |||
POINTER(PktBuf))), | |||
('comms_process', None, (POINTER(CommsState), PktBuf, POINTER(PktBuf))), | |||
('strobe_seed_prng', None, (POINTER(c_uint8), c_ssize_t)), | |||
]: | |||