# Copyright 2021 John-Mark Gurney.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#

import asyncio
import contextlib
import functools
import itertools
import os
import sys
import unittest

from Strobe.Strobe import Strobe, KeccakF
from Strobe.Strobe import AuthenticationFailed

import syote_comms
from syote_comms import make_pktbuf, X25519
import multicast
from util import *

# Each command is sent as a single unsigned byte, followed by any
# arguments, each encoded as a 4 byte little-endian integer.  The
# response to a command will be the CMD byte and any arguments if needed.
CMD_TERMINATE = 1	# no args: terminate the session, reply confirms

# The following commands are queued up, and are acknowledged when queued
CMD_WAITFOR = 2		# arg: (length): waits for length seconds
CMD_RUNFOR = 3		# arg: (chan, length): turns on chan for length seconds
CMD_PING = 4		# arg: (): a no op command
CMD_SETUNSET = 5	# arg: (chan, val): sets chan to val
CMD_ADV = 6		# arg: ([cnt]): advances to the next cnt (default 1) command
CMD_CLEAR = 7		# arg: (): clears all future commands, but keeps current running

class LORANode(object):
	'''Implement a LORANode initiator.

	There are currently two implemented modes.  The first is shared
	mode, in which a shared key must be provided via the shared
	keyword argument.

	The other is ecdhe mode, which requires an X25519 key to be passed
	in to init_key, and the respondent's public key to be passed in to
	resp_pub.

	After construction, await the start coroutine to perform the
	handshake, then use the command coroutines (waitfor, runfor, ...).
	'''

	SHARED_DOMAIN = b'com.funkthat.lora.irrigation.shared.v0.0.1'
	ECDHE_DOMAIN = b'com.funkthat.lora.irrigation.ecdhe.v0.0.1'

	# Length, in bytes, of the MAC appended to every packet.
	MAC_LEN = 8

	def __init__(self, syncdatagram, shared=None, init_key=None, resp_pub=None):
		'''syncdatagram is the SyncDatagram used to exchange packets.

		Either shared must be provided (shared mode), or both
		init_key and resp_pub (ecdhe mode).  Any other combination
		raises RuntimeError.'''

		self.sd = syncdatagram
		if shared is not None:
			self.st = Strobe(self.SHARED_DOMAIN, F=KeccakF(800))
			self.st.key(shared)
			self.start = self.shared_start
		elif init_key is not None and resp_pub is not None:
			self.st = Strobe(self.ECDHE_DOMAIN, F=KeccakF(800))
			self.key = init_key
			self.resp_pub = resp_pub
			self.st.key(init_key.getpub() + resp_pub)
			self.start = self.ecdhe_start
		else:
			raise RuntimeError('invalid combination of keys provided')

	async def shared_start(self):
		'''Perform the shared key handshake.  Raises RuntimeError if
		the respondent does not confirm.'''

		# the contents of the reply are not used, only that it
		# authenticated
		await self.sendrecvvalid(os.urandom(16) + b'reqreset')

		self.st.ratchet()

		pkt = await self.sendrecvvalid(b'confirm')

		if pkt != b'confirmed':
			raise RuntimeError('got invalid response: %s' %
			    repr(pkt))

	async def ecdhe_start(self):
		'''Perform the ecdhe handshake.  Raises RuntimeError if
		the respondent does not confirm.'''

		ephkey = X25519.gen()

		# After our request is encrypted, mix in the DH results
		# against the respondent's static key so its reply is
		# validated under the new key.  The reply is the
		# respondent's ephemeral public key.
		resp = await self.sendrecvvalid(ephkey.getpub() + b'reqreset',
		    fun=lambda: self.st.key(ephkey.dh(self.resp_pub) + self.key.dh(self.resp_pub)))

		self.st.key(ephkey.dh(resp) + self.key.dh(resp))

		pkt = await self.sendrecvvalid(b'confirm')

		if pkt != b'confirmed':
			raise RuntimeError('got invalid response: %s' %
			    repr(pkt))

	async def sendrecvvalid(self, msg, fun=None):
		'''Encrypt and MAC msg, and keep sending it until a reply
		that decrypts and authenticates is received.  Return the
		decrypted payload of that reply.

		If fun is not None, it is called after msg has been
		encrypted, but before the state used to validate replies
		is saved, allowing the key to be changed for the reply.

		Replies that are too short to hold a full MAC, or that fail
		authentication, are ignored, and the crypto state is
		restored before trying the next reply.'''

		msg = self.st.send_enc(msg) + self.st.send_mac(self.MAC_LEN)

		if fun is not None:
			fun()

		origstate = self.st.copy()

		while True:
			resp = await self.sd.sendtillrecv(msg, .50)
			#_debprint('got:', resp)

			# Skip anything that cannot hold a full MAC (this
			# includes empty messages).  Otherwise a short packet
			# would be checked against a MAC only as long as the
			# packet, e.g. a 1 byte MAC, which is easily forged.
			if len(resp) < self.MAC_LEN:
				continue

			try:
				decmsg = self.st.recv_enc(resp[:-self.MAC_LEN])
				self.st.recv_mac(resp[-self.MAC_LEN:])
				break
			except AuthenticationFailed:
				# didn't get a valid packet, restore
				# state and retry

				#_debprint('failed')
				self.st.set_state_from(origstate)

		#_debprint('got rep:', repr(resp), repr(decmsg))
		return decmsg

	@staticmethod
	def _encodeargs(*args):
		'''Encode each integer argument as 4 little-endian bytes and
		concatenate them.'''

		r = []
		for i in args:
			r.append(i.to_bytes(4, byteorder='little'))

		return b''.join(r)

	async def _sendcmd(self, cmd, *args):
		'''Send command byte cmd with integer args.  Raises
		RuntimeError if the reply does not start with the same
		command byte.'''

		cmdbyte = cmd.to_bytes(1, byteorder='little')
		resp = await self.sendrecvvalid(cmdbyte + self._encodeargs(*args))

		if resp[0:1] != cmdbyte:
			raise RuntimeError(
			    'response does not match, got: %s, expected: %s' %
			    (repr(resp[0:1]), repr(cmdbyte)))

	async def waitfor(self, length):
		return await self._sendcmd(CMD_WAITFOR, length)

	async def runfor(self, chan, length):
		return await self._sendcmd(CMD_RUNFOR, chan, length)

	async def setunset(self, chan, val):
		return await self._sendcmd(CMD_SETUNSET, chan, val)

	async def ping(self):
		return await self._sendcmd(CMD_PING)

	async def adv(self, cnt=None):
		args = ()
		if cnt is not None:
			args = (cnt, )
		return await self._sendcmd(CMD_ADV, *args)

	async def clear(self):
		return await self._sendcmd(CMD_CLEAR)

	async def terminate(self):
		return await self._sendcmd(CMD_TERMINATE)

class SyncDatagram(object):
	'''Base interface for a more simple synchronous interface.

	Subclasses implement recv and send; sendtillrecv is built on
	top of them.'''

	async def recv(self, timeout=None): #pragma: no cover
		'''Receive a datagram.  If timeout is not None, wait that many
		seconds, and if nothing is received in that time, raise an
		asyncio.TimeoutError exception.'''

		raise NotImplementedError

	async def send(self, data): #pragma: no cover
		'''Send the datagram in data.'''

		raise NotImplementedError

	async def sendtillrecv(self, data, freq, timeout=None):
		'''Send the datagram in data, every freq seconds until a datagram
		is received, and return the received datagram.

		If timeout is not None, and timeout seconds pass w/o receiving
		a datagram, then raise an asyncio.TimeoutError exception.  With
		the default of None, this retries forever.'''

		deadline = None
		if timeout is not None:
			loop = asyncio.get_running_loop()
			deadline = loop.time() + timeout

		while True:
			wait = freq
			if deadline is not None:
				remaining = deadline - loop.time()
				if remaining <= 0:
					raise asyncio.TimeoutError()

				# never wait past the overall deadline
				wait = min(freq, remaining)

			#_debprint('sending:', repr(data))
			await self.send(data)
			try:
				return await self.recv(wait)
			except asyncio.TimeoutError:
				pass

class MulticastSyncDatagram(SyncDatagram):
	'''
	An implementation of SyncDatagram that uses the provided
	multicast address maddr as the source/sink of the packets.

	Note that once created, the start coroutine needs to be
	await'd before being passed to a LORANode so that everything
	is running.
	'''

	# Note: sent packets will be received by our own receiver.  Each
	# sent packet is recorded, and one received copy is dropped for
	# each copy sent.  A count is kept rather than a set because
	# sendtillrecv retransmits the identical packet, and every
	# transmission comes back as its own copy.

	def __init__(self, maddr):
		self.maddr = maddr
		self.mr = None
		self.mt = None

		# packet bytes -> number of looped back copies still to drop
		self._ignpkts = {}

	async def start(self):
		self.mr = await multicast.create_multicast_receiver(self.maddr)
		self.mt = await multicast.create_multicast_transmitter(
		    self.maddr)

	async def _recv(self):
		'''Return the next received packet that is not a looped
		back copy of one we sent.'''

		while True:
			pkt = await self.mr.recv()
			pkt = pkt[0]

			cnt = self._ignpkts.get(pkt, 0)
			if not cnt:
				return pkt

			# one of our own transmissions, consume it
			if cnt == 1:
				del self._ignpkts[pkt]
			else:
				self._ignpkts[pkt] = cnt - 1

	async def recv(self, timeout=None): #pragma: no cover
		r = await asyncio.wait_for(self._recv(), timeout=timeout)

		return r

	async def send(self, data): #pragma: no cover
		pkt = bytes(data)
		self._ignpkts[pkt] = self._ignpkts.get(pkt, 0) + 1
		await self.mt.send(data)

	def close(self):
		'''Shutdown communications.  Safe to call if start was never
		awaited, or if already closed.'''

		if self.mr is not None:
			self.mr.close()
			self.mr = None
		if self.mt is not None:
			self.mt.close()
			self.mt = None

def listsplit(lst, item):
	'''Split lst on the first occurrence of item, returning the
	elements before it and the elements after it.  If item is not
	present, return lst itself and an empty list.'''

	if item not in lst:
		return lst, []

	pos = lst.index(item)
	before = lst[:pos]
	after = lst[pos + 1:]

	return before, after

async def _run_client(clientspec, maddr, shared_key):
	'''Act as a respondent on multicast address maddr using the
	shared key shared_key.  Each decoded command is passed to the
	function named by clientspec (module:function), and its return
	value is sent back as the reply.  Runs until an exception is
	raised.'''

	from ctypes import c_uint8

	import util_load

	mr = await multicast.create_multicast_receiver(maddr)
	mt = await multicast.create_multicast_transmitter(maddr)

	try:
		# seed the RNG
		prngseed = os.urandom(64)
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the comms state
		commstate = syote_comms.CommsState()

		client_func = util_load.load_application(clientspec)

		def client_call(msg, outbuf):
			ret = client_func(msg._from())

			if len(ret) > outbuf[0].pktlen:
				ret = b'error, too long buffer: %d' % len(ret)

			outbuf[0].pktlen = min(len(ret), outbuf[0].pktlen)
			for i in range(outbuf[0].pktlen):
				outbuf[0].pkt[i] = ret[i]

		# cb stays referenced by this frame for as long as the C
		# code may call it (ctypes callbacks must be kept alive).
		cb = syote_comms.process_msgfunc_t(client_call)

		# Initialize everything.  Shared key mode, so the two
		# ecdhe key arguments are None.
		syote_comms.comms_init(commstate, cb,
		    make_pktbuf(shared_key), None, None)

		while True:
			pkt = await mr.recv()
			msg = pkt[0]

			out = syote_comms.comms_process_wrap(
			    commstate, msg)

			if out:
				await mt.send(out)
	finally:
		mr.close()
		mt.close()

def _parse_cmds(parser, args, valid_cmds):
	'''Build the list of (cmd, intargs) tuples to send, from either
	the positional arguments (groups separated by --) or the lines
	of the -f schedule file.  Empty groups and blank lines are
	skipped.

	Usage errors are reported via parser.error.  An unknown command
	or a non-integer argument prints a message and exits w/ 1.'''

	if args.args and args.schedfile:
		parser.error('only one of -f or arguments can be specified.')

	if args.args:
		cmds = list(args.args)
		cmdargs = []
		while cmds:
			a, cmds = listsplit(cmds, '--')
			cmdargs.append(a)
	elif args.schedfile:
		with open(args.schedfile) as fp:
			cmdargs = [ x.split() for x in fp ]
	else:
		parser.error('one of -f or arguments must be specified.')

	ret = []
	for cmdargv in cmdargs:
		# blank line, or an empty -- group
		if not cmdargv:
			continue

		cmd, *cmdparams = cmdargv

		if cmd not in valid_cmds:
			print('invalid command:', repr(cmd))
			sys.exit(1)

		try:
			intparams = [ int(x) for x in cmdparams ]
		except ValueError:
			print('invalid arguments for %s:' % cmd,
			    repr(cmdparams))
			sys.exit(1)

		ret.append((cmd, intparams))

	return ret

async def main():
	'''Command line entry point.  Either runs a respondent (-r), or
	performs the shared key handshake and sends the commands given
	as arguments or in the -f file.'''

	import argparse

	from loraserv import DEFAULT_MADDR as maddr

	parser = argparse.ArgumentParser()

	parser.add_argument('-f', dest='schedfile', metavar='filename', type=str,
	    help='Use commands from the file.  One command per line.')
	parser.add_argument('-r', dest='client', metavar='module:function', type=str,
	    help='Create a respondant instead of sending commands.  Commands will be passed to the function.')
	parser.add_argument('-s', dest='shared_key', metavar='shared_key', type=str, required=True,
	    help='The shared key (encoded as UTF-8) to use.')
	parser.add_argument('args', metavar='CMD_ARG', type=str, nargs='*',
	    help='Various commands to send to the device.')

	args = parser.parse_args()

	shared_key = args.shared_key.encode('utf-8')

	if args.client:
		# Run a client
		await _run_client(args.client, maddr, shared_key)
		sys.exit(0)

	valid_cmds = {
	    'waitfor', 'setunset', 'runfor', 'ping', 'adv', 'clear',
	    'terminate',
	}

	# Validate everything before touching the network, so a bad
	# command does not leave the schedule partially sent.
	cmds = _parse_cmds(parser, args, valid_cmds)

	msd = MulticastSyncDatagram(maddr)
	await msd.start()

	try:
		l = LORANode(msd, shared=shared_key)

		await l.start()

		for cmd, cmdparams in cmds:
			fun = getattr(l, cmd)

			await fun(*cmdparams)
	finally:
		msd.close()

# Run the command line interface only when executed as a script.  When
# this file is imported (e.g. by a test runner), the classes below are
# simply defined.
if __name__ == '__main__':
	asyncio.run(main())

class MockSyncDatagram(SyncDatagram):
	'''A SyncDatagram for testing.  Subclasses define a runner
	coroutine that plays the other side of the exchange: it awaits
	self.get() to obtain what the code under test sent, and awaits
	self.put(data) to supply what the code under test receives next.'''

	def __init__(self):
		# datagrams sent by the code under test, read by runner
		self.sendq = asyncio.Queue()
		# datagrams queued by runner, returned by recv
		self.recvq = asyncio.Queue()

		self.get = self.sendq.get
		self.put = self.recvq.put

		# runner does not execute until the event loop next gets
		# control, so the attributes above are already in place
		self.task = asyncio.create_task(self.runner())

	async def drain(self):
		'''Wait for runner to finish, returning its result.'''

		result = await self.task
		return result

	async def runner(self): #pragma: no cover
		raise NotImplementedError

	async def recv(self, timeout=None):
		# timeout is ignored; the mock never times out
		item = await self.recvq.get()
		return item

	async def send(self, data):
		await self.sendq.put(data)

	def __del__(self): #pragma: no cover
		task = self.task
		if task is None or task.done():
			return

		task.cancel()

class TestSyncData(unittest.IsolatedAsyncioTestCase):
	async def test_syncsendtillrecv(self):
		'''sendtillrecv resends after a timeout, and returns the
		first datagram that is received.'''

		class MySync(SyncDatagram):
			def __init__(self):
				self.sendq = []
				# first recv times out, second one succeeds
				self.resp = [ asyncio.TimeoutError(), b'a' ]

			async def recv(self, timeout=None):
				assert timeout == 1
				item = self.resp.pop(0)
				if not isinstance(item, Exception):
					return item

				raise item

			async def send(self, data):
				self.sendq += [ data ]

		sd = MySync()

		got = await sd.sendtillrecv(b'foo', 1)

		self.assertEqual(got, b'a')
		self.assertEqual(sd.sendq, [ b'foo', b'foo' ])

class AsyncSequence(object):
	'''
	Object used for sequencing async functions.  To use, use the
	asynchronous context manager created by the sync method.  For
	example:
	seq = AsyncSequence()
	async def func1():
		async with seq.sync(1):
			second_fun()

	async def func2():
		async with seq.sync(0):
			first_fun()

	This will make sure that function first_fun is run before running
	the function second_fun.  If a previous block raises an Exception,
	it will be passed up, and all remaining blocks (and future ones)
	will raise a CancelledError to help ensure that any tasks are
	properly cleaned up.
	'''

	def __init__(self, positerfactory=lambda: itertools.count()):
		'''The argument positerfactory, is a factory that will
		create an iterator that will be used for the values that
		are passed to the sync method.'''

		self.positer = positerfactory()
		# Sentinel stored in waiting for a position whose turn has
		# arrived, but that nobody is blocked on yet.
		self.token = object()
		# Set once any block raises; later sync calls are cancelled.
		self.die = False
		# Maps position -> self.token (may run immediately) or the
		# Future the block for that position is waiting on.
		self.waiting = {
			next(self.positer): self.token
		}

	async def simpsync(self, pos):
		'''Wait for pos's turn, then immediately let the next
		position run.'''
		async with self.sync(pos):
			pass

	@contextlib.asynccontextmanager
	async def sync(self, pos):
		'''An async context manager that will be run when it's
		turn arrives.  It will only run when all the previous
		items in the iterator has been successfully run.

		Raises RuntimeError if another block is already waiting
		on pos, and CancelledError if the sequence has died.'''

		if self.die:
			raise asyncio.CancelledError('seq cancelled')

		if pos in self.waiting:
			# only the token (our turn already) is acceptable,
			# a Future means someone else is waiting on pos
			if self.waiting[pos] is not self.token:
				raise RuntimeError('pos already waiting!')
		else:
			# not our turn yet, wait to be woken (or cancelled)
			fut = asyncio.Future()
			self.waiting[pos] = fut
			await fut

		# our time to shine!
		del self.waiting[pos]

		try:
			yield None
		except Exception as e:
			# if we got an exception, things went pear shaped,
			# shut everything down, and any future calls.

			# NOTE(review): CancelledError (and other
			# BaseExceptions) are not caught here, so a cancelled
			# block neither advances nor kills the sequence.

			#_debprint('dieing...', repr(e))
			self.die = True

			# cancel existing blocks
			while self.waiting:
				k, v = self.waiting.popitem()
				#_debprint('canceling: %s' % repr(k))
				if v is self.token:
					continue

				# for Python 3.9:
				# msg='pos %s raised exception: %s' %
				#     (repr(pos), repr(e))
				v.cancel()

			# populate real exception up
			raise
		else:
			# handle next
			nextpos = next(self.positer)

			if nextpos in self.waiting:
				# the next block is already waiting, wake it
				#_debprint('np:', repr(self), nextpos,
				#    repr(self.waiting[nextpos]))
				self.waiting[nextpos].set_result(None)
			else:
				# nobody there yet, mark that it may run
				# as soon as it arrives
				self.waiting[nextpos] = self.token

class TestSequencing(unittest.IsolatedAsyncioTestCase):
	'''Tests for AsyncSequence.'''

	@timeout(2)
	async def test_seq_alreadywaiting(self):
		'''A second block syncing on a position that already has a
		waiter raises RuntimeError.'''

		waitseq = AsyncSequence()

		seq = AsyncSequence()

		async def fun1():
			async with waitseq.sync(1):
				pass

		async def fun2():
			async with seq.sync(1):
				async with waitseq.sync(1): # pragma: no cover
					pass

		task1 = asyncio.create_task(fun1())
		task2 = asyncio.create_task(fun2())

		# spin things to make sure things advance
		await asyncio.sleep(0)

		# let fun2 run, which then collides w/ fun1 on waitseq 1
		async with seq.sync(0):
			pass

		with self.assertRaises(RuntimeError):
			await task2

		# fun1 is still able to run normally
		async with waitseq.sync(0):
			pass

		await task1

	@timeout(2)
	async def test_seqexc(self):
		'''An exception in one block is passed up, and the other
		blocks, both already waiting and arriving later, are
		cancelled.'''

		seq = AsyncSequence()

		excseq = AsyncSequence()

		async def excfun1():
			async with seq.sync(1):
				pass

			async with excseq.sync(0):
				raise ValueError('foo')

		# that a block that enters first, but runs after
		# raises an exception
		async def excfun2():
			async with seq.sync(0):
				pass

			async with excseq.sync(1): # pragma: no cover
				pass

		# that a block that enters after, raises an
		# exception
		async def excfun3():
			async with seq.sync(2):
				pass

			async with excseq.sync(2): # pragma: no cover
				pass

		task1 = asyncio.create_task(excfun1())
		task2 = asyncio.create_task(excfun2())
		task3 = asyncio.create_task(excfun3())

		with self.assertRaises(ValueError):
			await task1

		with self.assertRaises(asyncio.CancelledError):
			await task2

		with self.assertRaises(asyncio.CancelledError):
			await task3

	@timeout(2)
	async def test_seq(self):
		# test that a seq object when created
		seq = AsyncSequence(lambda: itertools.count(1))

		col = []

		async def fun1():
			async with seq.sync(1):
				col.append(1)

			async with seq.sync(2):
				col.append(2)

			async with seq.sync(4):
				col.append(4)

		async def fun2():
			async with seq.sync(3):
				col.append(3)

			async with seq.sync(6):
				col.append(6)

		async def fun3():
			async with seq.sync(5):
				col.append(5)

		# and various functions are run
		task1 = asyncio.create_task(fun1())
		task2 = asyncio.create_task(fun2())
		task3 = asyncio.create_task(fun3())

		# and the functions complete
		await task3
		await task2
		await task1

		# that the order they ran in was correct
		self.assertEqual(col, list(range(1, 7)))

class TestLORANode(unittest.IsolatedAsyncioTestCase):
	'''Tests for LORANode, against a Python Strobe respondent run in
	a MockSyncDatagram, and against the C code via syote_comms.'''

	# Expected to match LORANode.SHARED_DOMAIN and ECDHE_DOMAIN.
	shared_domain = b'com.funkthat.lora.irrigation.shared.v0.0.1'
	ecdhe_domain = b'com.funkthat.lora.irrigation.ecdhe.v0.0.1'

	def test_initparams(self):
		# make sure no keys fails
		with self.assertRaises(RuntimeError):
			l = LORANode(None)

	@timeout(2)
	async def test_lora_ecdhe(self):
		'''Run the ecdhe handshake and some commands against a
		Python respondent, including junk and resent packets.'''

		_self = self
		initkey = X25519.gen()
		respkey = X25519.gen()

		class TestSD(MockSyncDatagram):
			async def sendgettest(self, msg):
				'''Send the message, but make sure that if a
				bad message is sent afterward, that it replies
				w/ the same previous message.
				'''

				await self.put(msg)
				resp = await self.get()

				await self.put(b'bogusmsg' * 5)

				resp2 = await self.get()

				_self.assertEqual(resp, resp2)

				return resp

			async def runner(self):
				# act as the respondent

				l = Strobe(_self.ecdhe_domain, F=KeccakF(800))

				l.key(initkey.getpub() + respkey.getpub())

				# start handshake
				r = await self.get()

				# get eph key w/ reqreset
				pkt = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert pkt.endswith(b'reqreset')

				ephpub = pkt[:-len(b'reqreset')]

				# make sure junk gets ignored
				await self.put(b'sdlfkj')

				# and that the packet remains the same
				_self.assertEqual(r, await self.get())

				# and a couple more times
				await self.put(b'0' * 24)
				_self.assertEqual(r, await self.get())
				await self.put(b'0' * 32)
				_self.assertEqual(r, await self.get())

				# update the keys
				l.key(respkey.dh(ephpub) + respkey.dh(initkey.getpub()))

				# generate our eph key
				ephkey = X25519.gen()

				# send the response
				await self.put(l.send_enc(ephkey.getpub()) +
				    l.send_mac(8))

				l.key(ephkey.dh(ephpub) + ephkey.dh(initkey.getpub()))

				# get the confirmation message
				r = await self.get()

				# test the resend capabilities
				await self.put(b'0' * 24)
				_self.assertEqual(r, await self.get())

				# decode confirmation message
				c = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				# assert that we got it
				_self.assertEqual(c, b'confirm')

				# send confirmed reply
				r = await self.sendgettest(l.send_enc(
				    b'confirmed') + l.send_mac(8))

				# test and decode remaining command messages
				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_WAITFOR
				assert int.from_bytes(cmd[1:],
				     byteorder='little') == 30

				r = await self.sendgettest(l.send_enc(
				    cmd[0:1]) + l.send_mac(8))

				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_RUNFOR
				assert int.from_bytes(cmd[1:5],
				     byteorder='little') == 1
				assert int.from_bytes(cmd[5:],
				     byteorder='little') == 50

				r = await self.sendgettest(l.send_enc(
				    cmd[0:1]) + l.send_mac(8))

				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_TERMINATE

				await self.put(l.send_enc(cmd[0:1]) +
				    l.send_mac(8))

		tsd = TestSD()

		# make sure it fails w/o both specified
		with self.assertRaises(RuntimeError):
			l = LORANode(tsd, init_key=initkey)

		with self.assertRaises(RuntimeError):
			l = LORANode(tsd, resp_pub=respkey.getpub())

		l = LORANode(tsd, init_key=initkey, resp_pub=respkey.getpub())

		await l.start()

		await l.waitfor(30)

		await l.runfor(1, 50)

		await l.terminate()

		await tsd.drain()

		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())
		#_debprint('done')

	@timeout(2)
	async def test_lora_shared(self):
		'''Run the shared key handshake and some commands against a
		Python respondent, including junk and resent packets.'''

		_self = self
		shared_key = os.urandom(32)

		class TestSD(MockSyncDatagram):
			async def sendgettest(self, msg):
				'''Send the message, but make sure that if a
				bad message is sent afterward, that it replies
				w/ the same previous message.
				'''

				await self.put(msg)
				resp = await self.get()

				await self.put(b'bogusmsg' * 5)

				resp2 = await self.get()

				_self.assertEqual(resp, resp2)

				return resp

			async def runner(self):
				l = Strobe(TestLORANode.shared_domain, F=KeccakF(800))

				l.key(shared_key)

				# start handshake
				r = await self.get()

				pkt = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert pkt.endswith(b'reqreset')

				# make sure junk gets ignored
				await self.put(b'sdlfkj')

				# and that the packet remains the same
				_self.assertEqual(r, await self.get())

				# and a couple more times
				await self.put(b'0' * 24)
				_self.assertEqual(r, await self.get())
				await self.put(b'0' * 32)
				_self.assertEqual(r, await self.get())

				# send the response
				await self.put(l.send_enc(os.urandom(16)) +
				    l.send_mac(8))

				# require no more back tracking at this point
				l.ratchet()

				# get the confirmation message
				r = await self.get()

				# test the resend capabilities
				await self.put(b'0' * 24)
				_self.assertEqual(r, await self.get())

				# decode confirmation message
				c = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				# assert that we got it
				_self.assertEqual(c, b'confirm')

				# send confirmed reply
				r = await self.sendgettest(l.send_enc(
				    b'confirmed') + l.send_mac(8))

				# test and decode remaining command messages
				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_WAITFOR
				assert int.from_bytes(cmd[1:],
				     byteorder='little') == 30

				r = await self.sendgettest(l.send_enc(
				    cmd[0:1]) + l.send_mac(8))

				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_RUNFOR
				assert int.from_bytes(cmd[1:5],
				     byteorder='little') == 1
				assert int.from_bytes(cmd[5:],
				     byteorder='little') == 50

				r = await self.sendgettest(l.send_enc(
				    cmd[0:1]) + l.send_mac(8))

				cmd = l.recv_enc(r[:-8])
				l.recv_mac(r[-8:])

				assert cmd[0] == CMD_TERMINATE

				await self.put(l.send_enc(cmd[0:1]) +
				    l.send_mac(8))

		tsd = TestSD()
		l = LORANode(tsd, shared=shared_key)

		await l.start()

		await l.waitfor(30)

		await l.runfor(1, 50)

		await l.terminate()

		await tsd.drain()

		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())
		#_debprint('done')

	@timeout(2)
	async def test_ccode_badmsgs(self):
		# Test to make sure that various bad messages in the
		# handshake process are rejected even if the attacker
		# has the correct key.  This just keeps the protocol
		# tight allowing for variations in the future.

		# seed the RNG
		prngseed = b'abc123'
		from ctypes import c_uint8
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the state for testing
		commstate = syote_comms.CommsState()

		cb = syote_comms.process_msgfunc_t(lambda msg, outbuf: None)

		# Generate shared key
		shared_key = os.urandom(32)

		# Initialize everything
		syote_comms.comms_init(commstate, cb, make_pktbuf(shared_key), None, None)

		# Create test fixture, only use it to init crypto state
		tsd = SyncDatagram()
		l = LORANode(tsd, shared=shared_key)

		# copy the crypto state
		cstate = l.st.copy()

		# compose an incorrect init message
		msg = os.urandom(16) + b'othre'
		msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

		out = syote_comms.comms_process_wrap(commstate, msg)

		self.assertFalse(out)

		# that various short messages don't cause problems
		for i in range(10):
			out = syote_comms.comms_process_wrap(commstate, b'0' * i)

			self.assertFalse(out)

		# copy the crypto state
		cstate = l.st.copy()

		# compose an incorrect init message
		msg = os.urandom(16) + b' eqreset'
		msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

		out = syote_comms.comms_process_wrap(commstate, msg)

		self.assertFalse(out)

		# compose the correct init message
		msg = os.urandom(16) + b'reqreset'
		msg = l.st.send_enc(msg) + l.st.send_mac(l.MAC_LEN)

		out = syote_comms.comms_process_wrap(commstate, msg)

		l.st.recv_enc(out[:-l.MAC_LEN])
		l.st.recv_mac(out[-l.MAC_LEN:])

		l.st.ratchet()

		# copy the crypto state
		cstate = l.st.copy()

		# compose an incorrect confirm message
		msg = b'onfirm'
		msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

		out = syote_comms.comms_process_wrap(commstate, msg)

		self.assertFalse(out)

		# copy the crypto state
		cstate = l.st.copy()

		# compose an incorrect confirm message
		msg = b' onfirm'
		msg = cstate.send_enc(msg) + cstate.send_mac(l.MAC_LEN)

		out = syote_comms.comms_process_wrap(commstate, msg)

		self.assertFalse(out)

	@timeout(2)
	async def test_ccode_ecdhe(self):
		'''Run the ecdhe handshake and some commands against the C
		respondent, checking that a repeated request gets an
		identical reply.'''

		_self = self
		from ctypes import c_uint8

		# seed the RNG
		prngseed = b'abc123'
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the state for testing
		commstate = syote_comms.CommsState()

		# These are the expected messages and their arguments
		exptmsgs = [
			(CMD_WAITFOR, [ 30 ]),
			(CMD_RUNFOR, [ 1, 50 ]),
			(CMD_PING, [ ]),
			(CMD_TERMINATE, [ ]),
		]
		def procmsg(msg, outbuf):
			msgbuf = msg._from()
			cmd = msgbuf[0]
			args = [ int.from_bytes(msgbuf[x:x + 4],
			    byteorder='little') for x in range(1, len(msgbuf),
			    4) ]

			if exptmsgs[0] == (cmd, args):
				exptmsgs.pop(0)
				outbuf[0].pkt[0] = cmd
				outbuf[0].pktlen = 1
			else: #pragma: no cover
				raise RuntimeError('cmd not found')

		# wrap the callback function
		cb = syote_comms.process_msgfunc_t(procmsg)

		class CCodeSD(MockSyncDatagram):
			async def runner(self):
				# Expected reply lengths, each including the 8
				# byte MAC: presumably the 32 byte ephemeral
				# public key, b'confirmed', then the 1 byte
				# command echo for each of the four commands.
				for expectlen in [ 40, 17, 9, 9, 9, 9 ]:
					# get message
					inmsg = await self.get()

					# process the test message
					out = syote_comms.comms_process_wrap(
					    commstate, inmsg)

					# make sure the reply matches length
					_self.assertEqual(expectlen, len(out))

					# save what was originally replied
					origmsg = out

					# pretend that the reply didn't make it
					out = syote_comms.comms_process_wrap(
					    commstate, inmsg)

					# make sure that the reply matches
					# the previous
					_self.assertEqual(origmsg, out)

					# pass the reply back
					await self.put(out)

		# Generate keys
		initkey = X25519.gen()
		respkey = X25519.gen()

		# Initialize everything
		syote_comms.comms_init(commstate, cb, None, make_pktbuf(respkey.getpriv()), make_pktbuf(initkey.getpub()))

		# Create test fixture
		tsd = CCodeSD()
		l = LORANode(tsd, init_key=initkey, resp_pub=respkey.getpub())

		# Send various messages
		await l.start()

		await l.waitfor(30)

		await l.runfor(1, 50)

		await l.ping()

		await l.terminate()

		await tsd.drain()

		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())

		# Make sure all the expected messages have been
		# processed.
		self.assertFalse(exptmsgs)
		#_debprint('done')

	@timeout(2)
	async def test_ccode(self):
		'''Run the shared key handshake and some commands against
		the C respondent, checking that a repeated request gets an
		identical reply.'''

		_self = self
		from ctypes import c_uint8

		# seed the RNG
		prngseed = b'abc123'
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the state for testing
		commstate = syote_comms.CommsState()

		# These are the expected messages and their arguments
		exptmsgs = [
			(CMD_WAITFOR, [ 30 ]),
			(CMD_RUNFOR, [ 1, 50 ]),
			(CMD_PING, [ ]),
			(CMD_TERMINATE, [ ]),
		]
		def procmsg(msg, outbuf):
			msgbuf = msg._from()
			cmd = msgbuf[0]
			args = [ int.from_bytes(msgbuf[x:x + 4],
			    byteorder='little') for x in range(1, len(msgbuf),
			    4) ]

			if exptmsgs[0] == (cmd, args):
				exptmsgs.pop(0)
				outbuf[0].pkt[0] = cmd
				outbuf[0].pktlen = 1
			else: #pragma: no cover
				raise RuntimeError('cmd not found')

		# wrap the callback function
		cb = syote_comms.process_msgfunc_t(procmsg)

		class CCodeSD(MockSyncDatagram):
			async def runner(self):
				# Expected reply lengths, each including the 8
				# byte MAC: presumably 16 random bytes for the
				# reset reply, b'confirmed', then the 1 byte
				# command echo for each of the four commands.
				for expectlen in [ 24, 17, 9, 9, 9, 9 ]:
					# get message
					inmsg = await self.get()

					# process the test message
					out = syote_comms.comms_process_wrap(
					    commstate, inmsg)

					# make sure the reply matches length
					_self.assertEqual(expectlen, len(out))

					# save what was originally replied
					origmsg = out

					# pretend that the reply didn't make it
					out = syote_comms.comms_process_wrap(
					    commstate, inmsg)

					# make sure that the reply matches
					# the previous
					_self.assertEqual(origmsg, out)

					# pass the reply back
					await self.put(out)

		# Generate shared key
		shared_key = os.urandom(32)

		# Initialize everything
		syote_comms.comms_init(commstate, cb, make_pktbuf(shared_key), None, None)

		# Create test fixture
		tsd = CCodeSD()
		l = LORANode(tsd, shared=shared_key)

		# Send various messages
		await l.start()

		await l.waitfor(30)

		await l.runfor(1, 50)

		await l.ping()

		await l.terminate()

		await tsd.drain()

		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())

		# Make sure all the expected messages have been
		# processed.
		self.assertFalse(exptmsgs)
		#_debprint('done')

	@timeout(2)
	async def test_ccode_newsession(self):
		'''This test is to make sure that if an existing session
		is running, that a new session can be established, and that
		when it does, the old session becomes inactive.
		'''

		_self = self
		from ctypes import c_uint8

		seq = AsyncSequence()

		# seed the RNG
		prngseed = b'abc123'
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the state for testing
		commstate = syote_comms.CommsState()

		# These are the expected messages and their arguments
		exptmsgs = [
			(CMD_WAITFOR, [ 30 ]),
			(CMD_WAITFOR, [ 70 ]),
			(CMD_WAITFOR, [ 40 ]),
			(CMD_TERMINATE, [ ]),
		]
		def procmsg(msg, outbuf):
			msgbuf = msg._from()
			cmd = msgbuf[0]
			args = [ int.from_bytes(msgbuf[x:x + 4],
			    byteorder='little') for x in range(1, len(msgbuf),
			    4) ]

			if exptmsgs[0] == (cmd, args):
				exptmsgs.pop(0)
				outbuf[0].pkt[0] = cmd
				outbuf[0].pktlen = 1
			else: #pragma: no cover
				raise RuntimeError('cmd not found: %d' % cmd)

		# wrap the callback function
		cb = syote_comms.process_msgfunc_t(procmsg)

		class FlipMsg(object):
			# Mixin: pass one message to the C code, and send
			# its reply back.
			async def flipmsg(self):
				# get message
				inmsg = await self.get()

				# process the test message
				out = syote_comms.comms_process_wrap(
				    commstate, inmsg)

				# pass the reply back
				await self.put(out)

		# this class always passes messages, this is
		# used for the first session.
		class CCodeSD1(MockSyncDatagram, FlipMsg):
			async def runner(self):
				for i in range(3):
					await self.flipmsg()

				async with seq.sync(0):
					# create bogus message
					inmsg = b'0'*24

					# process the bogus message
					out = syote_comms.comms_process_wrap(
					    commstate, inmsg)

					# make sure there was not a response
					_self.assertFalse(out)

				await self.flipmsg()

		# this one is special in that it will pause after the first
		# message to ensure that the previous session will continue
		# to work, AND that if a new "new" session comes along, it
		# will override the previous new session that hasn't been
		# confirmed yet.
		class CCodeSD2(MockSyncDatagram, FlipMsg):
			async def runner(self):
				# pass one message from the new session
				async with seq.sync(1):
					# There might be a missing case
					# handled for when the confirmed
					# message is generated, but lost.
					await self.flipmsg()

					# and the old session is still active
					await l.waitfor(70)

				async with seq.sync(2):
					for i in range(3):
						await self.flipmsg()

		# Generate shared key
		shared_key = os.urandom(32)

		# Initialize everything
		syote_comms.comms_init(commstate, cb, make_pktbuf(shared_key), None, None)

		# Create test fixture
		tsd = CCodeSD1()
		l = LORANode(tsd, shared=shared_key)

		# Send various messages
		await l.start()

		await l.waitfor(30)

		# Ensure that a new one can take over
		tsd2 = CCodeSD2()

		l2 = LORANode(tsd2, shared=shared_key)

		# Send various messages
		await l2.start()

		await l2.waitfor(40)

		await l2.terminate()

		await tsd.drain()
		await tsd2.drain()

		# Make sure all messages have been processed
		self.assertTrue(tsd.sendq.empty())
		self.assertTrue(tsd.recvq.empty())
		self.assertTrue(tsd2.sendq.empty())
		self.assertTrue(tsd2.recvq.empty())

		# Make sure all the expected messages have been
		# processed.
		self.assertFalse(exptmsgs)

class TestLoRaNodeMulticast(unittest.IsolatedAsyncioTestCase):
	'''Tests for MulticastSyncDatagram, run against the C respondent.'''

	# see: https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1
	maddr = ('224.0.0.198', 48542)

	@timeout(2)
	async def test_multisyncdgram(self):
		# Test the implementation of the multicast version of
		# SyncDatagram, by running a LORANode over it against the
		# C respondent.

		from ctypes import c_uint8

		# seed the RNG
		prngseed = b'abc123'
		syote_comms.strobe_seed_prng((c_uint8 *
		    len(prngseed))(*prngseed), len(prngseed))

		# Create the state for testing
		commstate = syote_comms.CommsState()

		# These are the expected messages and their arguments
		exptmsgs = [
			(CMD_WAITFOR, [ 30 ]),
			(CMD_PING, [ ]),
			(CMD_TERMINATE, [ ]),
		]
		def procmsg(msg, outbuf):
			msgbuf = msg._from()
			cmd = msgbuf[0]
			args = [ int.from_bytes(msgbuf[x:x + 4],
			    byteorder='little') for x in range(1, len(msgbuf),
			    4) ]

			if exptmsgs[0] == (cmd, args):
				exptmsgs.pop(0)
				outbuf[0].pkt[0] = cmd
				outbuf[0].pktlen = 1
			else: #pragma: no cover
				raise RuntimeError('cmd not found')

		# wrap the callback function
		cb = syote_comms.process_msgfunc_t(procmsg)

		# Generate shared key
		shared_key = os.urandom(32)

		# Initialize everything
		syote_comms.comms_init(commstate, cb, make_pktbuf(shared_key), None, None)

		# create the object we are testing
		msd = MulticastSyncDatagram(self.maddr)

		seq = AsyncSequence()

		async def clienttask():
			mr = await multicast.create_multicast_receiver(
			    self.maddr)
			mt = await multicast.create_multicast_transmitter(
			    self.maddr)

			try:
				# make sure the above threads are running
				await seq.simpsync(0)

				while True:
					pkt = await mr.recv()
					msg = pkt[0]

					out = syote_comms.comms_process_wrap(
					    commstate, msg)

					if out:
						await mt.send(out)
			finally:
				mr.close()
				mt.close()

		task = asyncio.create_task(clienttask())

		# start it
		await msd.start()

		try:
			# pass it to a node
			l = LORANode(msd, shared=shared_key)

			# wait until the client task is listening
			await seq.simpsync(1)

			# Send various messages
			await l.start()

			await l.waitfor(30)

			await l.ping()

			await l.terminate()
		finally:
			# shut things down, even if the exchange failed
			msd.close()

			task.cancel()

		with self.assertRaises(asyncio.CancelledError):
			await task

		# Make sure all the expected messages have been
		# processed.
		self.assertFalse(exptmsgs)