|  | # Copyright 2021 John-Mark Gurney.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#
import asyncio
import functools
import os
import unittest
from Strobe.Strobe import Strobe, KeccakF
from Strobe.Strobe import AuthenticationFailed
import lora_comms
from lora_comms import make_pktbuf
# Protocol domain string passed to Strobe() when the session state is
# created (by LORANode and by the test responders below).
domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'
# The response to a command will be the CMD byte and any arguments if needed.
# The command is encoded as an unsigned byte
CMD_TERMINATE = 1	# no args: terminate the session, reply confirms
# The following commands are queued up, and will be acknowledged when queued
CMD_WAITFOR = 2		# arg: (length): waits for length seconds
CMD_RUNFOR = 3		# arg: (chan, length): turns on chan for length seconds
class LORANode(object):
	'''Implement a LORANode initiator.

	All traffic goes through the syncdatagram object's sendtillrecv
	method and is protected by a Strobe session: each outgoing message
	is encrypted with send_enc and followed by an 8 byte MAC from
	send_mac, and each reply is decrypted with recv_enc and its
	trailing 8 byte MAC checked with recv_mac.'''
	def __init__(self, syncdatagram, shared=None):
		'''syncdatagram: object providing an async
		sendtillrecv(data, freq) method (e.g. a SyncDatagram).
		shared: optional shared key; when given it is passed to the
		Strobe state's key() method.'''
		self.sd = syncdatagram
		self.st = Strobe(domain, F=KeccakF(800))
		if shared is not None:
			self.st.key(shared)
	async def start(self):
		'''Run the session handshake.

		Sends 16 random bytes followed by b'reqreset' and expects a
		reply of 16 encrypted bytes plus an 8 byte MAC.  It then
		ratchets the Strobe state, sends b'confirm', and expects
		b'confirmed' back.  Each request is resent every second until
		a reply arrives.

		Raises RuntimeError if the final reply is not b'confirmed'.
		MAC check failures from recv_mac propagate unchanged
		(presumably Strobe's AuthenticationFailed -- confirm).'''
		msg = self.st.send_enc(os.urandom(16) + b'reqreset') + \
		    self.st.send_mac(8)
		resp = await self.sd.sendtillrecv(msg, 1)
		# reply layout: 16 encrypted bytes, then the 8 byte MAC
		self.st.recv_enc(resp[:16])
		self.st.recv_mac(resp[16:])
		self.st.ratchet()
		resp = await self.sd.sendtillrecv(
		    self.st.send_enc(b'confirm') + self.st.send_mac(8), 1)
		# reply layout: len(b'confirmed') == 9 encrypted bytes, then MAC
		pkt = self.st.recv_enc(resp[:9])
		self.st.recv_mac(resp[9:])
		if pkt != b'confirmed':
			raise RuntimeError('handshake failed, got: %s, expected: %s' % (repr(pkt), repr(b'confirmed')))
	@staticmethod
	def _encodeargs(*args):
		'''Encode each integer argument as 4 little-endian bytes and
		concatenate the results.  Values that do not fit in an
		unsigned 32-bit integer raise OverflowError.'''
		return b''.join(i.to_bytes(4, byteorder='little') for i in args)
	async def _sendcmd(self, cmd, *args):
		'''Send command byte cmd followed by its encoded args, resending
		every second until a reply arrives.

		The reply is everything except the trailing 8 byte MAC.  Its
		first byte must echo cmd; otherwise RuntimeError is raised.'''
		cmdbyte = cmd.to_bytes(1, byteorder='little')
		pkt = await self.sd.sendtillrecv(
		    self.st.send_enc(cmdbyte +
		    self._encodeargs(*args)) + self.st.send_mac(8), 1)
		resp = self.st.recv_enc(pkt[:-8])
		self.st.recv_mac(pkt[-8:])
		if resp[0:1] != cmdbyte:
			raise RuntimeError('response does not match, got: %s, expected: %s' % (repr(resp[0:1]), repr(cmdbyte)))
	async def waitfor(self, length):
		'''Queue a CMD_WAITFOR of length seconds on the node.'''
		return await self._sendcmd(CMD_WAITFOR, length)
	async def runfor(self, chan, length):
		'''Queue a CMD_RUNFOR turning on chan for length seconds.'''
		return await self._sendcmd(CMD_RUNFOR, chan, length)
	async def terminate(self):
		'''Send CMD_TERMINATE to end the session.'''
		return await self._sendcmd(CMD_TERMINATE)
class SyncDatagram(object):
	'''Base interface for a more simple synchronous interface.

	Subclasses implement recv and send; sendtillrecv builds a simple
	retransmitting request/response exchange on top of them.'''
	def __init__(self): #pragma: no cover
		pass
	async def recv(self, timeout=None): #pragma: no cover
		'''Receive a datagram.  If timeout is not None, wait that many
		seconds, and if nothing is received in that time, raise a
		TimeoutError exception.'''
		raise NotImplementedError
	async def send(self, data): #pragma: no cover
		'''Send a datagram.'''
		raise NotImplementedError
	async def sendtillrecv(self, data, freq):
		'''Send the datagram in data, every freq seconds until a datagram
		is received, and return the received datagram.

		freq is passed to recv as its timeout; each time recv times
		out, data is sent again.  There is no overall timeout: this
		keeps retrying until a datagram arrives.'''
		while True:
			await self.send(data)
			try:
				return await self.recv(freq)
			except (TimeoutError, asyncio.TimeoutError):
				# Before Python 3.11, asyncio.TimeoutError (raised by
				# asyncio.wait_for) is not the builtin TimeoutError;
				# catch both so either kind of recv timeout resends.
				pass
class MockSyncDatagram(SyncDatagram):
	'''Testing stand-in for SyncDatagram, backed by two asyncio queues.

	Subclasses supply an async runner method that plays the other side
	of the conversation.  Inside runner, await self.get() to receive
	whatever the local side passed to send(), and await self.put(data)
	to queue a datagram that the local side's recv() will return.
	runner is started as an asyncio task when the object is built.'''
	def __init__(self):
		outgoing, incoming = asyncio.Queue(), asyncio.Queue()
		# sendq: filled by send(), drained by runner via get
		# recvq: filled by runner via put, drained by recv()
		self.sendq, self.recvq = outgoing, incoming
		self.get, self.put = outgoing.get, incoming.put
		# Bind task before creating it so __del__ always finds the
		# attribute, even if create_task raises.
		self.task = None
		self.task = asyncio.create_task(self.runner())
	async def drain(self):
		'''Wait for the runner task to finish and return its result.'''
		pending = self.task
		return await pending
	async def runner(self): #pragma: no cover
		raise NotImplementedError
	async def recv(self, timeout=None):
		# timeout is deliberately ignored: the mock always waits for
		# runner to put something.
		datagram = await self.recvq.get()
		return datagram
	async def send(self, data):
		return await self.sendq.put(data)
	def __del__(self): #pragma: no cover
		pending = self.task
		if pending is None or pending.done():
			return
		pending.cancel()
class TestSyncData(unittest.IsolatedAsyncioTestCase):
	async def test_syncsendtillrecv(self):
		# sendtillrecv must resend after a TimeoutError and return the
		# first datagram that recv produces.
		class MySync(SyncDatagram):
			def __init__(self):
				self.sendq = []
				# scripted recv results: one timeout, then b'a'
				self.resp = [ TimeoutError(), b'a' ]
			async def recv(self, timeout=None):
				assert timeout == 1
				item = self.resp.pop(0)
				if not isinstance(item, Exception):
					return item
				raise item
			async def send(self, data):
				self.sendq.append(data)
		sync = MySync()
		got = await sync.sendtillrecv(b'foo', 1)
		self.assertEqual(got, b'a')
		# one send per recv attempt: the timed out one and the good one
		self.assertEqual(sync.sendq, [ b'foo', b'foo' ])
def timeout(timeout):
	'''Decorator factory: make an async function fail if it runs longer
	than timeout seconds.

	The wrapped coroutine is run under asyncio.wait_for, so exceeding
	the limit cancels it and raises asyncio's timeout exception.'''
	limit = timeout
	def decorate(coro_fn):
		@functools.wraps(coro_fn)
		async def guarded(*args, **kwargs):
			pending = coro_fn(*args, **kwargs)
			return await asyncio.wait_for(pending, limit)
		return guarded
	return decorate
class TestLORANode(unittest.IsolatedAsyncioTestCase):
	'''Drive a LORANode initiator through a handshake and three
	commands, first against a scripted Python responder (test_lora)
	and then against the lora_comms C implementation (test_ccode).'''
	# Fail if the exchange takes longer than 2 seconds, e.g. when one
	# side waits forever for a message that never comes.
	@timeout(2)
	async def test_lora(self):
		shared_key = os.urandom(32)
		# Scripted responder: plays the remote node with its own
		# Strobe state keyed with the same shared key, checking each
		# request from LORANode and sending back the expected reply.
		class TestSD(MockSyncDatagram):
			async def runner(self):
				l = Strobe(domain, F=KeccakF(800))
				l.key(shared_key)
				# start handshake
				r = await self.get()
				# every message is the encrypted body + 8 byte MAC
				pkt = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])
				assert pkt.endswith(b'reqreset')
				# reply with 16 encrypted random bytes + MAC, the
				# layout LORANode.start splits at resp[:16]
				await self.put(l.send_enc(os.urandom(16)) +
				    l.send_mac(8))
				l.ratchet()
				r = await self.get()
				c = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])
				assert c == b'confirm'
				await self.put(l.send_enc(b'confirmed') +
				    l.send_mac(8))
				# WAITFOR with a single 4 byte little-endian arg
				r = await self.get()
				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])
				assert cmd[0] == CMD_WAITFOR
				assert int.from_bytes(cmd[1:], byteorder='little') == 30
				# acknowledge by echoing just the command byte
				await self.put(l.send_enc(cmd[0:1]) +
				    l.send_mac(8))
				# RUNFOR with two 4 byte args: chan, then length
				r = await self.get()
				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])
				assert cmd[0] == CMD_RUNFOR
				assert int.from_bytes(cmd[1:5], byteorder='little') == 1
				assert int.from_bytes(cmd[5:], byteorder='little') == 50
				await self.put(l.send_enc(cmd[0:1]) +
				    l.send_mac(8))
				# TERMINATE has no args
				r = await self.get()
				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])
				assert cmd[0] == CMD_TERMINATE
				await self.put(l.send_enc(cmd[0:1]) +
				    l.send_mac(8))
		tsd = TestSD()
		l = LORANode(tsd, shared=shared_key)
		await l.start()
		await l.waitfor(30)
		await l.runfor(1, 50)
		await l.terminate()
		# wait for the runner; an assert failure in it surfaces here
		await tsd.drain()
		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())
	@timeout(2)
	async def test_ccode(self):
		# keep the test case reachable from CCodeSD.runner, where
		# self is the datagram object
		_self = self
		# NOTE(review): pointer and sizeof are not used in this test
		from ctypes import pointer, sizeof, c_uint8
		# seed the RNG
		# (presumably makes the C side's random output reproducible)
		prngseed = b'abc123'
		lora_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))
		# Create the state for testing
		commstate = lora_comms.CommsState()
		# These are the expected messages and their arguments
		exptmsgs = [
			(CMD_WAITFOR, [ 30 ]),
			(CMD_RUNFOR, [ 1, 50 ]),
			(CMD_TERMINATE, [ ]),
		]
		# Called by the C code for each command it receives.  It
		# decodes the command byte and 4 byte little-endian args,
		# checks them against the next entry in exptmsgs, and
		# replies with just the command byte.
		def procmsg(msg, outbuf):
			msgbuf = msg._from()
			#print('procmsg:', repr(msg), repr(msgbuf), repr(outbuf))
			cmd = msgbuf[0]
			args = [ int.from_bytes(msgbuf[x:x + 4],
			    byteorder='little') for x in range(1, len(msgbuf),
			    4) ]
			if exptmsgs[0] == (cmd, args):
				exptmsgs.pop(0)
				outbuf[0].pkt[0] = cmd
				outbuf[0].pktlen = 1
			else: #pragma: no cover
				raise RuntimeError('cmd not found')
		# wrap the callback function
		# NOTE(review): if process_msgfunc_t is a ctypes CFUNCTYPE,
		# cb must stay referenced while C may call it; it does, as a
		# local that lives for the whole test.
		cb = lora_comms.process_msgfunc_t(procmsg)
		# Feeds each request from LORANode into comms_process and
		# hands the reply back to LORANode.
		class CCodeSD(MockSyncDatagram):
			async def runner(self):
				# Expected reply lengths: handshake reply (16 + 8 MAC),
				# 'confirmed' (9 + 8 MAC), then three one byte
				# command acks (1 + 8 MAC).
				for expectlen in [ 24, 17, 9, 9, 9 ]:
					# get message
					gb = await self.get()
					r = make_pktbuf(gb)
					outbytes = bytearray(64)
					outbuf = make_pktbuf(outbytes)
					# process the test message
					lora_comms.comms_process(commstate, r,
					    outbuf)
					# make sure the reply matches length
					_self.assertEqual(expectlen,
					    outbuf.pktlen)
					# save what was originally replied
					origmsg = outbuf._from()
					# pretend that the reply didn't make it
					# (i.e. the initiator retransmits the same packet)
					r = make_pktbuf(gb)
					outbuf = make_pktbuf(outbytes)
					lora_comms.comms_process(commstate, r,
					    outbuf)
					# make sure that the reply matches previous
					_self.assertEqual(origmsg, outbuf._from())
					# pass the reply back
					await self.put(outbytes[:outbuf.pktlen])
		# Generate shared key
		shared_key = os.urandom(32)
		# Initialize everything
		lora_comms.comms_init(commstate, cb, make_pktbuf(shared_key))
		# Create test fixture
		tsd = CCodeSD()
		l = LORANode(tsd, shared=shared_key)
		# Send various messages
		await l.start()
		await l.waitfor(30)
		await l.runfor(1, 50)
		await l.terminate()
		await tsd.drain()
		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())
		# Make sure all the expected messages have been
		# processed.
		self.assertFalse(exptmsgs)
 |