Implement a secure ICS protocol targeting LoRa Node151 microcontroller for controlling irrigation.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1015 lines
24 KiB

  1. # Copyright 2021 John-Mark Gurney.
  2. #
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions
  5. # are met:
  6. # 1. Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # 2. Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. #
  12. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  13. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  14. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  15. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  16. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  17. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  18. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  19. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  20. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  21. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  22. # SUCH DAMAGE.
  23. #
  24. import asyncio
  25. import contextlib
  26. import functools
  27. import itertools
  28. import os
  29. import unittest
  30. from Strobe.Strobe import Strobe, KeccakF
  31. from Strobe.Strobe import AuthenticationFailed
  32. import lora_comms
  33. from lora_comms import make_pktbuf
  34. import multicast
  35. from util import *
# Domain string passed as the first argument to every Strobe() state
# created in this file; both ends must use the same value.
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'

# Response to command will be the CMD and any arguments if needed.
# The command is encoded as an unsigned byte
CMD_TERMINATE = 1 # no args: terminate the session, reply confirms

# The following commands are queued up, but will be acknowledged when queued
CMD_WAITFOR = 2 # arg: (length): waits for length seconds
CMD_RUNFOR = 3 # arg: (chan, length): turns on chan for length seconds
CMD_PING = 4 # arg: (): a no op command
  44. class LORANode(object):
  45. '''Implement a LORANode initiator.'''
  46. MAC_LEN = 8
  47. def __init__(self, syncdatagram, shared=None):
  48. self.sd = syncdatagram
  49. self.st = Strobe(domain, F=KeccakF(800))
  50. if shared is not None:
  51. self.st.key(shared)
  52. async def start(self):
  53. resp = await self.sendrecvvalid(os.urandom(16) + b'reqreset')
  54. self.st.ratchet()
  55. pkt = await self.sendrecvvalid(b'confirm')
  56. if pkt != b'confirmed':
  57. raise RuntimeError('got invalid response: %s' %
  58. repr(pkt))
  59. async def sendrecvvalid(self, msg):
  60. msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)
  61. origstate = self.st.copy()
  62. while True:
  63. resp = await self.sd.sendtillrecv(msg, 1)
  64. #_debprint('got:', resp)
  65. # skip empty messages
  66. if len(resp) == 0:
  67. continue
  68. try:
  69. decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
  70. self.st.recv_mac(resp[-self.MAC_LEN:])
  71. break
  72. except AuthenticationFailed:
  73. # didn't get a valid packet, restore
  74. # state and retry
  75. #_debprint('failed')
  76. self.st.set_state_from(origstate)
  77. #_debprint('got rep:', repr(resp), repr(decmsg))
  78. return decmsg
  79. @staticmethod
  80. def _encodeargs(*args):
  81. r = []
  82. for i in args:
  83. r.append(i.to_bytes(4, byteorder='little'))
  84. return b''.join(r)
  85. async def _sendcmd(self, cmd, *args):
  86. cmdbyte = cmd.to_bytes(1, byteorder='little')
  87. resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args))
  88. if resp[0:1] != cmdbyte:
  89. raise RuntimeError(
  90. 'response does not match, got: %s, expected: %s' %
  91. (repr(resp[0:1]), repr(cmdbyte)))
  92. async def waitfor(self, length):
  93. return await self._sendcmd(CMD_WAITFOR, length)
  94. async def runfor(self, chan, length):
  95. return await self._sendcmd(CMD_RUNFOR, chan, length)
  96. async def ping(self):
  97. return await self._sendcmd(CMD_PING)
  98. async def terminate(self):
  99. return await self._sendcmd(CMD_TERMINATE)
  100. class SyncDatagram(object):
  101. '''Base interface for a more simple synchronous interface.'''
  102. async def recv(self, timeout=None): #pragma: no cover
  103. '''Receive a datagram. If timeout is not None, wait that many
  104. seconds, and if nothing is received in that time, raise an
  105. TimeoutError exception.'''
  106. raise NotImplementedError
  107. async def send(self, data): #pragma: no cover
  108. raise NotImplementedError
  109. async def sendtillrecv(self, data, freq):
  110. '''Send the datagram in data, every freq seconds until a datagram
  111. is received. If timeout seconds happen w/o receiving a datagram,
  112. then raise an TimeoutError exception.'''
  113. while True:
  114. #_debprint('sending:', repr(data))
  115. await self.send(data)
  116. try:
  117. return await self.recv(freq)
  118. except TimeoutError:
  119. pass
  120. class MulticastSyncDatagram(SyncDatagram):
  121. '''
  122. An implementation of SyncDatagram that uses the provided
  123. multicast address maddr as the source/sink of the packets.
  124. Note that once created, the start coroutine needs to be
  125. await'd before being passed to a LORANode so that everything
  126. is running.
  127. '''
  128. # Note: sent packets will be received. A similar method to
  129. # what was done in multicast.{to,from}_loragw could be done
  130. # here as well, that is passing in a set of packets to not
  131. # pass back up.
  132. def __init__(self, maddr):
  133. self.maddr = maddr
  134. self._ignpkts = set()
  135. async def start(self):
  136. self.mr = await multicast.create_multicast_receiver(self.maddr)
  137. self.mt = await multicast.create_multicast_transmitter(
  138. self.maddr)
  139. async def _recv(self):
  140. while True:
  141. pkt = await self.mr.recv()
  142. pkt = pkt[0]
  143. if pkt not in self._ignpkts:
  144. return pkt
  145. self._ignpkts.remove(pkt)
  146. async def recv(self, timeout=None): #pragma: no cover
  147. r = await asyncio.wait_for(self._recv(), timeout=timeout)
  148. return r
  149. async def send(self, data): #pragma: no cover
  150. self._ignpkts.add(bytes(data))
  151. await self.mt.send(data)
  152. def close(self):
  153. '''Shutdown communications.'''
  154. self.mr.close()
  155. self.mr = None
  156. self.mt.close()
  157. self.mt = None
  158. class MockSyncDatagram(SyncDatagram):
  159. '''A testing version of SyncDatagram. Define a method runner which
  160. implements part of the sequence. In the function, await on either
  161. self.get, to wait for the other side to send something, or await
  162. self.put w/ data to send.'''
  163. def __init__(self):
  164. self.sendq = asyncio.Queue()
  165. self.recvq = asyncio.Queue()
  166. self.task = asyncio.create_task(self.runner())
  167. self.get = self.sendq.get
  168. self.put = self.recvq.put
  169. async def drain(self):
  170. '''Wait for the runner thread to finish up.'''
  171. return await self.task
  172. async def runner(self): #pragma: no cover
  173. raise NotImplementedError
  174. async def recv(self, timeout=None):
  175. return await self.recvq.get()
  176. async def send(self, data):
  177. return await self.sendq.put(data)
  178. def __del__(self): #pragma: no cover
  179. if self.task is not None and not self.task.done():
  180. self.task.cancel()
  181. class TestSyncData(unittest.IsolatedAsyncioTestCase):
  182. async def test_syncsendtillrecv(self):
  183. class MySync(SyncDatagram):
  184. def __init__(self):
  185. self.sendq = []
  186. self.resp = [ TimeoutError(), b'a' ]
  187. async def recv(self, timeout=None):
  188. assert timeout == 1
  189. r = self.resp.pop(0)
  190. if isinstance(r, Exception):
  191. raise r
  192. return r
  193. async def send(self, data):
  194. self.sendq.append(data)
  195. ms = MySync()
  196. r = await ms.sendtillrecv(b'foo', 1)
  197. self.assertEqual(r, b'a')
  198. self.assertEqual(ms.sendq, [ b'foo', b'foo' ])
  199. class AsyncSequence(object):
  200. '''
  201. Object used for sequencing async functions. To use, use the
  202. asynchronous context manager created by the sync method. For
  203. example:
  204. seq = AsyncSequence()
  205. async func1():
  206. async with seq.sync(1):
  207. second_fun()
  208. async func2():
  209. async with seq.sync(0):
  210. first_fun()
  211. This will make sure that function first_fun is run before running
  212. the function second_fun. If a previous block raises an Exception,
  213. it will be passed up, and all remaining blocks (and future ones)
  214. will raise a CancelledError to help ensure that any tasks are
  215. properly cleaned up.
  216. '''
  217. def __init__(self, positerfactory=lambda: itertools.count()):
  218. '''The argument positerfactory, is a factory that will
  219. create an iterator that will be used for the values that
  220. are passed to the sync method.'''
  221. self.positer = positerfactory()
  222. self.token = object()
  223. self.die = False
  224. self.waiting = {
  225. next(self.positer): self.token
  226. }
  227. async def simpsync(self, pos):
  228. async with self.sync(pos):
  229. pass
  230. @contextlib.asynccontextmanager
  231. async def sync(self, pos):
  232. '''An async context manager that will be run when it's
  233. turn arrives. It will only run when all the previous
  234. items in the iterator has been successfully run.'''
  235. if self.die:
  236. raise asyncio.CancelledError('seq cancelled')
  237. if pos in self.waiting:
  238. if self.waiting[pos] is not self.token:
  239. raise RuntimeError('pos already waiting!')
  240. else:
  241. fut = asyncio.Future()
  242. self.waiting[pos] = fut
  243. await fut
  244. # our time to shine!
  245. del self.waiting[pos]
  246. try:
  247. yield None
  248. except Exception as e:
  249. # if we got an exception, things went pear shaped,
  250. # shut everything down, and any future calls.
  251. #_debprint('dieing...', repr(e))
  252. self.die = True
  253. # cancel existing blocks
  254. while self.waiting:
  255. k, v = self.waiting.popitem()
  256. #_debprint('canceling: %s' % repr(k))
  257. if v is self.token:
  258. continue
  259. # for Python 3.9:
  260. # msg='pos %s raised exception: %s' %
  261. # (repr(pos), repr(e))
  262. v.cancel()
  263. # populate real exception up
  264. raise
  265. else:
  266. # handle next
  267. nextpos = next(self.positer)
  268. if nextpos in self.waiting:
  269. #_debprint('np:', repr(self), nextpos,
  270. # repr(self.waiting[nextpos]))
  271. self.waiting[nextpos].set_result(None)
  272. else:
  273. self.waiting[nextpos] = self.token
# Tests for AsyncSequence.  Each test is wrapped in timeout(2) (from
# util) so that a sequencing mistake shows up as a failure instead of
# a hang.
class TestSequencing(unittest.IsolatedAsyncioTestCase):
    @timeout(2)
    async def test_seq_alreadywaiting(self):
        # Entering sync() for a position that another block is
        # already parked on must raise RuntimeError.
        waitseq = AsyncSequence()
        seq = AsyncSequence()

        async def fun1():
            async with waitseq.sync(1):
                pass

        async def fun2():
            # seq orders this after fun1 has parked on waitseq 1
            async with seq.sync(1):
                async with waitseq.sync(1): # pragma: no cover
                    pass

        task1 = asyncio.create_task(fun1())
        task2 = asyncio.create_task(fun2())

        # spin things to make sure things advance
        await asyncio.sleep(0)

        # release fun2, which then collides with fun1 on waitseq 1
        async with seq.sync(0):
            pass

        with self.assertRaises(RuntimeError):
            await task2

        # fun1 is unaffected and completes once its turn arrives
        async with waitseq.sync(0):
            pass

        await task1

    @timeout(2)
    async def test_seqexc(self):
        # seq only orders how the functions reach excseq; excseq is
        # the sequence whose failure handling is being tested.
        seq = AsyncSequence()
        excseq = AsyncSequence()

        async def excfun1():
            async with seq.sync(1):
                pass

            async with excseq.sync(0):
                raise ValueError('foo')

        # that a block that enters first, but runs after
        # raises an exception
        async def excfun2():
            async with seq.sync(0):
                pass

            async with excseq.sync(1): # pragma: no cover
                pass

        # that a block that enters after, raises an
        # exception
        async def excfun3():
            async with seq.sync(2):
                pass

            async with excseq.sync(2): # pragma: no cover
                pass

        task1 = asyncio.create_task(excfun1())
        task2 = asyncio.create_task(excfun2())
        task3 = asyncio.create_task(excfun3())

        # the failing block sees its own exception
        with self.assertRaises(ValueError):
            await task1

        # everybody else gets cancelled
        with self.assertRaises(asyncio.CancelledError):
            await task2

        with self.assertRaises(asyncio.CancelledError):
            await task3

    @timeout(2)
    async def test_seq(self):
        # test that a seq object when created
        # (here with positions starting at 1 rather than 0)
        seq = AsyncSequence(lambda: itertools.count(1))

        col = []

        async def fun1():
            async with seq.sync(1):
                col.append(1)

            async with seq.sync(2):
                col.append(2)

            async with seq.sync(4):
                col.append(4)

        async def fun2():
            async with seq.sync(3):
                col.append(3)

            async with seq.sync(6):
                col.append(6)

        async def fun3():
            async with seq.sync(5):
                col.append(5)

        # and various functions are run
        task1 = asyncio.create_task(fun1())
        task2 = asyncio.create_task(fun2())
        task3 = asyncio.create_task(fun3())

        # and the functions complete
        await task3
        await task2
        await task1

        # that the order they ran in was correct
        self.assertEqual(col, list(range(1, 7)))
# Tests for LORANode: test_lora runs against a pure Python peer, the
# test_ccode* tests run against the implementation exposed by the
# lora_comms module.
class TestLORANode(unittest.IsolatedAsyncioTestCase):
    @timeout(2)
    async def test_lora(self):
        # _self lets the nested class reach this test case's assert
        # methods, as self is rebound inside its methods.
        _self = self
        shared_key = os.urandom(32)

        class TestSD(MockSyncDatagram):
            async def sendgettest(self, msg):
                '''Send the message, but make sure that if a
                bad message is sent afterward, that it replies
                w/ the same previous message.
                '''

                await self.put(msg)
                resp = await self.get()

                await self.put(b'bogusmsg' * 5)

                resp2 = await self.get()

                _self.assertEqual(resp, resp2)

                return resp

            async def runner(self):
                # peer side crypto state, keyed the same way as
                # the LORANode under test
                l = Strobe(domain, F=KeccakF(800))

                l.key(shared_key)

                # start handshake
                r = await self.get()

                pkt = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert pkt.endswith(b'reqreset')

                # make sure junk gets ignored
                await self.put(b'sdlfkj')

                # and that the packet remains the same
                _self.assertEqual(r, await self.get())

                # and a couple more times
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())
                await self.put(b'0' * 32)
                _self.assertEqual(r, await self.get())

                # send the response
                await self.put(l.send_enc(os.urandom(16)) +
                    l.send_mac(8))

                # require no more back tracking at this point
                l.ratchet()

                # get the confirmation message
                r = await self.get()

                # test the resend capabilities
                await self.put(b'0' * 24)
                _self.assertEqual(r, await self.get())

                # decode confirmation message
                c = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                # assert that we got it
                _self.assertEqual(c, b'confirm')

                # send confirmed reply
                r = await self.sendgettest(l.send_enc(
                    b'confirmed') + l.send_mac(8))

                # test and decode remaining command messages
                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_WAITFOR
                assert int.from_bytes(cmd[1:],
                    byteorder='little') == 30

                # each command is acknowledged by echoing its
                # command byte
                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_RUNFOR
                assert int.from_bytes(cmd[1:5],
                    byteorder='little') == 1
                assert int.from_bytes(cmd[5:],
                    byteorder='little') == 50

                r = await self.sendgettest(l.send_enc(
                    cmd[0:1]) + l.send_mac(8))

                cmd = l.recv_enc(r[:-8])
                l.recv_mac(r[-8:])

                assert cmd[0] == CMD_TERMINATE

                await self.put(l.send_enc(cmd[0:1]) +
                    l.send_mac(8))

        tsd = TestSD()
        l = LORANode(tsd, shared=shared_key)

        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.terminate()

        # wait for the runner to finish, surfacing any failure
        # raised inside it
        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())
        #_debprint('done')

    @timeout(2)
    async def test_ccode_badmsgs(self):
        # Test to make sure that various bad messages in the
        # handshake process are rejected even if the attacker
        # has the correct key. This just keeps the protocol
        # tight allowing for variations in the future.

        # seed the RNG
        prngseed = b'abc123'
        from ctypes import c_uint8
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # the message callback does nothing in this test
        cb = lora_comms.process_msgfunc_t(lambda msg, outbuf: None)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture, only use it to init crypto state
        tsd = SyncDatagram()
        l = LORANode(tsd, shared=shared_key)

        # copy the crypto state
        # (bad messages are built on copies so that l.st itself
        # stays in step with commstate)
        cstate = l.st.copy()

        # compose an incorrect init message
        msg = os.urandom(16) + b'othre'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect init message
        msg = os.urandom(16) + b' eqreset'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # compose the correct init message
        msg = os.urandom(16) + b'reqreset'
        msg = l.st.send_enc(msg) + l.st.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        # consume the reply to keep l.st in step, then ratchet as
        # LORANode.start does
        l.st.recv_enc(out[:-l.MAC_LEN])
        l.st.recv_mac(out[-l.MAC_LEN:])

        l.st.ratchet()

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect confirmed message
        msg = b'onfirm'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

        # copy the crypto state
        cstate = l.st.copy()

        # compose an incorrect confirmed message
        msg = b' onfirm'
        msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

        out = lora_comms.comms_process_wrap(commstate, msg)

        self.assertFalse(out)

    @timeout(2)
    async def test_ccode(self):
        _self = self
        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_RUNFOR, [ 1, 50 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is a run of 4 byte
            # little endian integers
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                # reply with just the command byte
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        class CCodeSD(MockSyncDatagram):
            async def runner(self):
                # expected reply length for each of the six
                # exchanges: two for the handshake, then the four
                # commands (presumably one command byte plus the
                # 8 byte MAC each)
                for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
                    # get message
                    inmsg = await self.get()

                    # process the test message
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure the reply matches length
                    _self.assertEqual(expectlen, len(out))

                    # save what was originally replied
                    origmsg = out

                    # pretend that the reply didn't make it
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure that the reply matches
                    # the previous
                    _self.assertEqual(origmsg, out)

                    # pass the reply back
                    await self.put(out)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.runfor(1, 50)

        await l.ping()

        await l.terminate()

        await tsd.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
        #_debprint('done')

    @timeout(2)
    async def test_ccode_newsession(self):
        '''This test is to make sure that if an existing session
        is running, that a new session can be established, and that
        when it does, the old session becomes inactive.
        '''

        _self = self
        from ctypes import c_uint8

        # orders the steps of the two runner tasks below
        seq = AsyncSequence()

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_WAITFOR, [ 70 ]),
            (CMD_WAITFOR, [ 40 ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is a run of 4 byte
            # little endian integers
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                # reply with just the command byte
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found: %d' % cmd)

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        # mixin: pass one message through commstate and return the
        # reply to the sender
        class FlipMsg(object):
            async def flipmsg(self):
                # get message
                inmsg = await self.get()

                # process the test message
                out = lora_comms.comms_process_wrap(
                    commstate, inmsg)

                # pass the reply back
                await self.put(out)

        # this class always passes messages, this is
        # used for the first session.
        class CCodeSD1(MockSyncDatagram, FlipMsg):
            async def runner(self):
                # the two handshake messages and waitfor(30)
                for i in range(3):
                    await self.flipmsg()

                async with seq.sync(0):
                    # create bogus message
                    inmsg = b'0'*24

                    # process the bogus message
                    out = lora_comms.comms_process_wrap(
                        commstate, inmsg)

                    # make sure there was not a response
                    _self.assertFalse(out)

                # the waitfor(70) that is issued on this (old)
                # session while the new one is being set up
                await self.flipmsg()

        # this one is special in that it will pause after the first
        # message to ensure that the previous session will continue
        # to work, AND that if a new "new" session comes along, it
        # will override the previous new session that hasn't been
        # confirmed yet.
        class CCodeSD2(MockSyncDatagram, FlipMsg):
            async def runner(self):
                # pass one message from the new session
                async with seq.sync(1):
                    # There might be a missing case
                    # handled for when the confirmed
                    # message is generated, but lost.
                    await self.flipmsg()

                    # and the old session is still active
                    await l.waitfor(70)

                # the confirm, waitfor(40) and terminate of the
                # new session
                async with seq.sync(2):
                    for i in range(3):
                        await self.flipmsg()

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # Create test fixture
        tsd = CCodeSD1()
        l = LORANode(tsd, shared=shared_key)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        # Ensure that a new one can take over
        tsd2 = CCodeSD2()
        l2 = LORANode(tsd2, shared=shared_key)

        # Send various messages
        await l2.start()

        await l2.waitfor(40)

        await l2.terminate()

        await tsd.drain()
        await tsd2.drain()

        # Make sure all messages have been processed
        self.assertTrue(tsd.sendq.empty())
        self.assertTrue(tsd.recvq.empty())

        self.assertTrue(tsd2.sendq.empty())
        self.assertTrue(tsd2.recvq.empty())

        # Make sure all the expected messages have been
        # processed.
        self.assertFalse(exptmsgs)
# End to end test of MulticastSyncDatagram: a LORANode talks over
# multicast to a task that feeds the packets to lora_comms.
class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase):
    # see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
    maddr = ('224.0.0.198', 48542)

    @timeout(2)
    async def test_multisyncdgram(self):
        # Test the implementation of the multicast version of
        # SyncDatagram

        _self = self
        from ctypes import c_uint8

        # seed the RNG
        prngseed = b'abc123'
        lora_comms.strobe_seed_prng((c_uint8 *
            len(prngseed))(*prngseed), len(prngseed))

        # Create the state for testing
        commstate = lora_comms.CommsState()

        # These are the expected messages and their arguments
        exptmsgs = [
            (CMD_WAITFOR, [ 30 ]),
            (CMD_PING, [ ]),
            (CMD_TERMINATE, [ ]),
        ]
        def procmsg(msg, outbuf):
            msgbuf = msg._from()
            cmd = msgbuf[0]

            # everything after the command byte is a run of 4 byte
            # little endian integers
            args = [ int.from_bytes(msgbuf[x:x + 4],
                byteorder='little') for x in range(1, len(msgbuf),
                4) ]

            if exptmsgs[0] == (cmd, args):
                exptmsgs.pop(0)
                # reply with just the command byte
                outbuf[0].pkt[0] = cmd
                outbuf[0].pktlen = 1
            else: #pragma: no cover
                raise RuntimeError('cmd not found')

        # wrap the callback function
        cb = lora_comms.process_msgfunc_t(procmsg)

        # Generate shared key
        shared_key = os.urandom(32)

        # Initialize everything
        lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))

        # create the object we are testing
        msd = MulticastSyncDatagram(self.maddr)

        seq = AsyncSequence()

        # the remote end: hand every received packet to lora_comms,
        # and transmit whatever reply it produces
        async def clienttask():
            mr = await multicast.create_multicast_receiver(
                self.maddr)
            mt = await multicast.create_multicast_transmitter(
                self.maddr)

            try:
                # make sure the above threads are running
                await seq.simpsync(0)

                while True:
                    pkt = await mr.recv()
                    msg = pkt[0]

                    out = lora_comms.comms_process_wrap(
                        commstate, msg)

                    # no reply is produced for packets that get
                    # rejected
                    if out:
                        await mt.send(out)
            finally:
                mr.close()
                mt.close()

        task = asyncio.create_task(clienttask())

        # start it
        await msd.start()

        # pass it to a node
        l = LORANode(msd, shared=shared_key)

        # wait until clienttask has its sockets set up
        await seq.simpsync(1)

        # Send various messages
        await l.start()

        await l.waitfor(30)

        await l.ping()

        await l.terminate()

        # shut things down
        # NOTE(review): ln is not used anywhere else in this test;
        # possibly l = None was intended -- confirm.
        ln = None
        msd.close()

        task.cancel()

        with self.assertRaises(asyncio.CancelledError):
            await task