from noise.connection import NoiseConnection, Keypair
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import hashes
from twistednoise import genkeypair
from cryptography.hazmat.backends import default_backend

import asyncio
import os.path
import shutil
import socket
import tempfile
import threading
import unittest

# Shared cryptography backend object; passed as backend= to the HKDF and
# Cipher constructors in _genciphfun.
_backend = default_backend()

def _makeunix(path):
	'''Make a properly formed unix path socket string.'''

	return 'unix:%s' % path

def _parsesockstr(sockstr):
	proto, rem = sockstr.split(':', 1)

	return proto, rem

async def connectsockstr(sockstr):
	'''Open a stream connection to the endpoint described by sockstr.

	Only the 'unix' protocol is supported; everything after 'unix:' is
	used as the path of the unix domain socket to connect to.

	Returns the (reader, writer) pair from asyncio.open_unix_connection.

	Raises ValueError if sockstr is malformed or names another protocol.
	'''

	proto, rem = _parsesockstr(sockstr)

	# Fail loudly instead of misinterpreting e.g. 'tcp:host:port' as a
	# unix socket path.
	if proto != 'unix':
		raise ValueError('unsupported sockstr protocol: %r' % (proto,))

	reader, writer = await asyncio.open_unix_connection(rem)

	return reader, writer

async def listensockstr(sockstr, cb):
	'''Wrapper for asyncio.start_x_server.

	The format of sockstr is: 'proto:param=value[,param2=value2]'.
	If the proto has a default parameter, the value can be used
	directly, like: 'proto:value'.  This is only allowed when the
	value can unambiguously be determined not to be a param.

	The characters that define 'param' must be all lower case ascii
	characters and may contain an underscore.  The first character
	must not be an underscore.

	Supported protocols:
		unix:
			Default parameter is path.
			The path parameter specifies the path to the
			unix domain socket.  The path MUST start w/ a
			slash if it is used as a default parameter.

	NOTE: only the default-parameter form (e.g. 'unix:/path/to/sock')
	is currently implemented; everything after 'unix:' is used verbatim
	as the socket path.

	cb is the client-connected callback handed to
	asyncio.start_unix_server.

	Returns the server object from asyncio.start_unix_server.

	Raises ValueError if sockstr is malformed or names a protocol other
	than 'unix'.
	'''

	proto, rem = _parsesockstr(sockstr)

	# Fail loudly instead of treating another protocol's address as a
	# unix socket path.
	if proto != 'unix':
		raise ValueError('unsupported sockstr protocol: %r' % (proto,))

	server = await asyncio.start_unix_server(cb, path=rem)

	return server

# Byte lengths of the three Noise XK handshake messages, in order.
# NoiseForwarder reads messages 0 and 2 with readexactly(); the test's
# initiator reads message 1.  The values are presumably generated by the
# command in the marker comment below (TODO confirm) -- regenerate rather
# than hand-edit if the Noise protocol name changes.
# !!python makemessagelengths.py
_handshakelens = \
[72, 72, 88]

def _genciphfun(hash, ad):
	'''Build a matching pair of length-field obfuscation functions.

	A 32 byte AES key is derived from hash with HKDF-SHA256 (fixed
	salt, ad as the info string).  The 2 byte big-endian length of a
	frame is XORed with a 16 bit mask taken from the AES-ECB encryption
	of the first 16 bytes of that frame's ciphertext, so the length is
	not sent in the clear.

	Returns (encfun, decfun):
	  encfun(data) -> the 2 byte obfuscated length field for data,
	      where data is the ciphertext that will follow the field.
	  decfun(data) -> the length recovered from data, which must start
	      with the 2 byte field followed by at least the first 16 bytes
	      of ciphertext.

	Both functions raise ValueError when fewer than 16 ciphertext bytes
	are supplied; encfun also raises ValueError if len(data) does not
	fit in the 2 byte field.
	'''

	hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
	    salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

	key = hkdf.derive(hash)
	cipher = Cipher(algorithms.AES(key), modes.ECB(),
	    backend=_backend)
	# ECB keeps no state between whole blocks, so one encryptor can
	# serve both functions -- but only if every update() is exactly one
	# block.  A short update would be buffered and desync later calls.
	enctor = cipher.encryptor()

	def _mask(block):
		# 16 bit mask from the encryption of one ciphertext block.
		if len(block) != 16:
			raise ValueError('need 16 bytes of ciphertext for the '
			    'length mask, got %d' % len(block))
		encbytes = enctor.update(block)
		return int.from_bytes(encbytes[:2], byteorder='big')

	def encfun(data):
		# Returns the two bytes for length
		val = len(data)
		if val > 0xffff:
			raise ValueError('message too long for 2 byte length: %d' %
			    val)

		return (val ^ _mask(data[:16])).to_bytes(length=2,
		    byteorder='big')

	def decfun(data):
		# takes off the data and returns the total
		# length
		val = int.from_bytes(data[:2], byteorder='big')

		return val ^ _mask(data[2:2 + 16])

	return encfun, decfun

async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
	'''Run the responder side of a Noise XK tunnel and forward its
	plaintext to ptsockstr.

	mode is currently ignored: this always acts as the responder.
	priv_key is the responder's static private key (passed to
	set_keypair_from_private_bytes).  rdrwrr is the (reader, writer)
	stream pair carrying the encrypted connection.  ptsockstr is a
	sockstr for the plaintext endpoint, opened with connectsockstr.

	After the handshake, frames read from the encrypted side are
	decrypted and written to the plaintext side.  Data read from the
	plaintext side is encrypted and framed back.  Each frame is a 2 byte
	obfuscated length (from _genciphfun) followed by the ciphertext.
	EOF on either side is propagated to the other with write_eof().

	Returns the set of forwarding tasks that finished last.

	Raises RuntimeError if the handshake does not complete.  Stream
	errors during the handshake propagate.  Any exception raised by a
	forwarding direction is re-raised, e.g. a bad frame length
	(ValueError) or whatever proto.decrypt raises on a corrupt frame.
	'''

	rdr, wrr = rdrwrr

	proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

	proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)

	proto.set_as_responder()

	proto.start_handshake()

	proto.read_message(await rdr.readexactly(_handshakelens[0]))

	wrr.write(proto.write_message())

	proto.read_message(await rdr.readexactly(_handshakelens[2]))

	if not proto.handshake_finished:	# pragma: no cover
		raise RuntimeError('failed to finish handshake')

	# generate the keys for lengths
	_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
	enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')

	reader, writer = await connectsockstr(ptsockstr)

	async def decses():
		# encrypted side -> plaintext side
		try:
			while True:
				try:
					msg = await rdr.readexactly(2 + 16)
				except asyncio.IncompleteReadError:
					if rdr.at_eof():
						return 'dec'
					raise

				tlen = declenfun(msg)
				# every frame carries at least the 16 byte tag
				if tlen < 16:
					raise ValueError('invalid frame length: %d' %
					    tlen)
				rmsg = await rdr.readexactly(tlen - 16)
				tmsg = msg[2:] + rmsg
				writer.write(proto.decrypt(tmsg))
				await writer.drain()
		finally:
			# half-close so the plaintext peer sees the EOF
			if writer.can_write_eof():
				writer.write_eof()

	async def encses():
		# plaintext side -> encrypted side
		try:
			while True:
				# largest message whose ciphertext (plus 16 byte
				# tag) still fits in the 2 byte length field
				ptmsg = await reader.read(65535 - 16)
				if not ptmsg:
					# EOF; an empty read must not be framed or
					# this would spin forever
					return 'enc'
				encmsg = proto.encrypt(ptmsg)
				wrr.write(enclenfun(encmsg))
				wrr.write(encmsg)
				await wrr.drain()
		finally:
			if wrr.can_write_eof():
				wrr.write_eof()

	# asyncio.wait() needs tasks/futures, not bare coroutines.
	tasks = {asyncio.ensure_future(decses()),
	    asyncio.ensure_future(encses())}
	try:
		done, pending = await asyncio.wait(tasks,
		    return_when=asyncio.FIRST_COMPLETED)
		for i in done:
			i.result()	# re-raise a failure immediately

		# Both may finish together; waiting on an empty set raises.
		if pending:
			done, pending = await asyncio.wait(pending)
			for i in done:
				i.result()
	finally:
		# On error or cancellation don't leave the other direction
		# running; cancel() is a no-op for finished tasks.
		for i in tasks:
			i.cancel()
		writer.close()

	return done

class TestListenSocket(unittest.TestCase):
	'''Round-trip check of listensockstr together with connectsockstr.'''

	def test_listensockstr(self):
		tmpdir = tempfile.mkdtemp()
		self.addCleanup(shutil.rmtree, tmpdir)
		sockstr = _makeunix(os.path.join(os.path.realpath(tmpdir),
		    'sock'))

		async def echo(reader, writer):
			writer.write(await reader.readexactly(5))
			writer.close()

		async def run():
			server = await listensockstr(sockstr, echo)
			try:
				reader, writer = await connectsockstr(sockstr)
				try:
					writer.write(b'hello')
					return await reader.readexactly(5)
				finally:
					writer.close()
			finally:
				server.close()
				# let the transports' close callbacks run before
				# the loop is torn down
				await asyncio.sleep(0)

		# async_test is defined later in the module, so drive the loop
		# by hand here.
		loop = asyncio.new_event_loop()
		try:
			data = loop.run_until_complete(asyncio.wait_for(run(), 2))
		finally:
			loop.close()

		self.assertEqual(data, b'hello')

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
def async_test(f):
	'''Decorator that runs an ``async def`` test method synchronously.

	Each call runs the coroutine on a fresh event loop, which is closed
	afterwards.  It fails with asyncio.TimeoutError if the coroutine
	takes longer than 2 seconds.  The coroutine's result is returned.
	'''
	import functools

	@functools.wraps(f)
	def wrapper(*args, **kwargs):
		# A private loop avoids the deprecated implicit loop creation
		# of get_event_loop() and isolates tests from each other.
		loop = asyncio.new_event_loop()
		try:
			# timeout after 2 seconds
			return loop.run_until_complete(
			    asyncio.wait_for(f(*args, **kwargs), 2))
		finally:
			loop.close()
	return wrapper

class Tests_misc(unittest.TestCase):
	'''Unit tests for module-level helpers.'''

	def test_genciphfun(self):
		# Both functions come from one call, so they share a key and
		# must invert each other.
		encode, decode = _genciphfun(b'0' * 32, b'foobar')

		fixed = b'this is a bunch of data'
		header = encode(fixed)
		self.assertEqual(len(fixed), decode(header + fixed))

		sizes = (20, 1384, 64000, 23839, 65535)
		for data in map(os.urandom, sizes):
			self.assertEqual(len(data), decode(encode(data) + data))

class Tests(unittest.TestCase):
	'''End-to-end test of NoiseForwarder over unix domain sockets.'''

	def setUp(self):
		# setup temporary directory
		d = os.path.realpath(tempfile.mkdtemp())
		self.basetempdir = d
		self.tempdir = os.path.join(d, 'subdir')
		os.mkdir(self.tempdir)

		# Generate key pairs; below, index 0 is used as the public key
		# and index 1 as the private key.
		self.server_key_pair = genkeypair()
		self.client_key_pair = genkeypair()

	def tearDown(self):
		shutil.rmtree(self.basetempdir)
		self.tempdir = None

	@async_test
	async def test_server(self):
		# Path that the server will sit on
		servsockpath = os.path.join(self.tempdir, 'servsock')
		servarg = _makeunix(servsockpath)

		# Path that the server will send pt data to
		ptsockpath = os.path.join(self.tempdir, 'servptsock')
		pttarg = _makeunix(ptsockpath)

		# Setup pt target listener; ptaccepted is set once the
		# forwarder has connected to it.
		ptsock = []
		ptaccepted = asyncio.Event()

		def ptsockaccept(reader, writer):
			ptsock.append((reader, writer))
			ptaccepted.set()

		# Bind to pt listener
		lsock = await listensockstr(pttarg, ptsockaccept)

		nfs = []
		nfdone = asyncio.Event()

		async def runnf(rdr, wrr):
			# Always signal completion so a failing forwarder shows up
			# as an assertion failure rather than a timeout.
			try:
				nfs.append(await NoiseForwarder('resp',
				    self.server_key_pair[1], (rdr, wrr), pttarg))
			finally:
				nfdone.set()

		# Setup server listener
		ssock = await listensockstr(servarg, runnf)

		try:
			# Connect to server
			reader, writer = await connectsockstr(servarg)

			# Create client
			proto = NoiseConnection.from_name(
			    b'Noise_XK_448_ChaChaPoly_SHA256')
			proto.set_as_initiator()

			# Setup required keys
			proto.set_keypair_from_private_bytes(Keypair.STATIC,
			    self.client_key_pair[1])
			proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
			    self.server_key_pair[0])

			proto.start_handshake()

			# Send first message
			message = proto.write_message()
			self.assertEqual(len(message), _handshakelens[0])
			writer.write(message)

			# Get response
			respmsg = await reader.readexactly(_handshakelens[1])
			proto.read_message(respmsg)

			# Send final reply
			message = proto.write_message()
			self.assertEqual(len(message), _handshakelens[2])
			writer.write(message)

			# Make sure handshake has completed
			self.assertTrue(proto.handshake_finished)

			# generate the keys for lengths
			enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
			    b'toresp')
			_, declenfun = _genciphfun(proto.get_handshake_hash(),
			    b'toinit')

			def sendframe(ptmsg):
				# length field followed by the ciphertext
				encmsg = proto.encrypt(ptmsg)
				writer.write(enclenfun(encmsg))
				writer.write(encmsg)

			# write a test message
			ptmsg = b'this is a test message that should be a little in length'
			sendframe(ptmsg)

			# The forwarder connects to the pt listener only after the
			# handshake; wait for that instead of sleeping.
			await ptaccepted.wait()
			ptrdr, ptwrr = ptsock[0]

			# read the test message; readexactly waits for the data
			rptmsg = await ptrdr.readexactly(len(ptmsg))
			self.assertEqual(rptmsg, ptmsg)

			# write a different message
			ptmsg = os.urandom(2843)
			sendframe(ptmsg)

			rptmsg = await ptrdr.readexactly(len(ptmsg))
			self.assertEqual(rptmsg, ptmsg)

			# now try the other way
			ptmsg = os.urandom(912)
			ptwrr.write(ptmsg)

			# find out how much we need to read
			encmsg = await reader.readexactly(2 + 16)
			tlen = declenfun(encmsg)

			# read the rest of the message
			rencmsg = await reader.readexactly(tlen - 16)
			tmsg = encmsg[2:] + rencmsg
			rptmsg = proto.decrypt(tmsg)

			self.assertEqual(rptmsg, ptmsg)

			# shut everything down: close both directions so each
			# half of the forwarder can finish
			writer.write_eof()
			ptwrr.write_eof()

			await nfdone.wait()
			self.assertEqual(len(nfs), 1)
		finally:
			ssock.close()
			lsock.close()