from noise.connection import NoiseConnection, Keypair
from twisted.trial import unittest
from twisted.test import proto_helpers
from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor, defer, task
# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives import serialization

import mock
import os.path
import shutil
import tempfile
import twisted.internet.protocol

__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2019 John-Mark Gurney.  All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2019 John-Mark Gurney.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.

# Notes:
# Using XK, so that the connecting party's identity is hidden and the
# server party's key is known.
#
# The client and server class names are used to refer to the initiator and
# responder sides, even though both client and server each have a server
# component (listening on a socket to start comms) and a client component
# (creating a connection).
#
# Noise packets are 16 bytes + length of data
#
# Proposed method to hide message lengths:
# Immediately after handshake completes, each side generates and sends
# an n byte key that will be used for encrypting (algo tbd) their own
# byte counts.  The length field will be encrypted via
# E(pktnum, key) XOR 2 byte length.
#
# Note that authenticating the message length is NOT needed.  This is
# because the noise message blocks themselves are authenticated.  The
# worst that could happen is that a larger read (64k) is done, and then
# the connection aborts because of decryption failure.
#

def _makeunix(path):
	'''Make a properly formed unix path socket string.'''

	return 'unix:%s' % path

def genkeypair():
	'''Generate a new X448 keypair.

	Returns a tuple of (public, private), each encoded as raw bytes
	and suitable for use w/ Noise.'''

	privkey = x448.X448PrivateKey.generate()

	# Both halves are exported raw, the private half unencrypted.
	rawenc = serialization.Encoding.Raw

	pubbytes = privkey.public_key().public_bytes(
		encoding=rawenc,
		format=serialization.PublicFormat.Raw,
	)
	privbytes = privkey.private_bytes(
		encoding=rawenc,
		format=serialization.PrivateFormat.Raw,
		encryption_algorithm=serialization.NoEncryption(),
	)

	return pubbytes, privbytes

class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
	'''Common base class for both sides of a Noise XK connection.

	Subclasses set the class attribute mode to 'resp' to act as the
	responder, or to 'init' to act as the initiator, and implement
	handshakeFinished.

	The factory that creates this Protocol is required to have the
	property priv_key, and may have the property pub_key.

	The priv_key property is the raw bytes of the static private key
	for this side of the connection.

	The pub_key property, when present, is the raw bytes of the static
	public key of the remote side.  This is the key for the server that
	the clients are required to have (due to Noise XK protocol used) to
	authenticate the server.

	TwistedNoiseServerProtocol additionally requires the factory to
	have the property endpoint, which contains the endpoint as a string
	that will be used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
	and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
	for information on how to use this property.'''

	def connectionMade(self):
		'''Set up the Noise state for this connection and start the
		handshake.  When acting as the initiator, the first handshake
		message is sent as well.'''

		# Initialize Noise
		noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
		self.noise = noise
		if self.mode == 'resp':
			noise.set_as_responder()
		elif self.mode == 'init':
			noise.set_as_initiator()

		noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)
		if hasattr(self.factory, 'pub_key'):
			noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key)

		# Start Handshake
		noise.start_handshake()

		# The initiator has to send the first handshake message.
		# dataReceived only ever replies to a message from the peer,
		# so without this the handshake would never get started.
		if self.mode == 'init':
			self.transport.write(noise.write_message())

	def encData(self, data):
		'''Receive plain text data, encrypt it, and send it down the
		wire.'''

		self.transport.write(self.noise.encrypt(data))

	def dataReceived(self, data):
		'''Receive data from the wire.

		While the handshake is in progress, the data is processed as
		a handshake message, and the next handshake message (if any)
		is written back.  handshakeFinished is called once the
		handshake completes.

		After the handshake, the data is decrypted and written to the
		endpoint that was connected via the plaintextConnected
		method.'''

		if not self.noise.handshake_finished:
			self.noise.read_message(data)

			# Reading the message may have completed the
			# handshake, in which case there is nothing to
			# send back.
			if not self.noise.handshake_finished:
				self.transport.write(self.noise.write_message())

			# Either the read or the write above may have
			# completed the handshake.
			if self.noise.handshake_finished:
				self.handshakeFinished()
		else:
			r = self.noise.decrypt(data)

			self.endpoint.transport.write(r)

	def handshakeFinished(self):	# pragma: no cover
		'''This function is called when the handshake has been
		completed.  This is used to start data flowing, and to
		do any necessary connection work.

		Subclasses must override this; this implementation raises
		NotImplementedError.'''

		raise NotImplementedError

	def plaintextConnected(self, endpoint):
		'''Connect the plain text endpoint to this protocol.  All the
		decrypted data will be written to the passed in protocol,
		(specifically, its transport).

		This also resumes reading from this protocol's transport.'''

		self.endpoint = endpoint
		self.transport.resumeProducing()

class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
	'''Responder side of the Noise connection.  Once the handshake
	has completed, a connection is made to the endpoint specified
	by the factory's endpoint property.'''

	mode = 'resp'

	def handshakeFinished(self):
		'''Pause incoming data and start connecting to the plain text
		endpoint.'''

		# There is no place to write decrypted data until the plain
		# text connection is up.  plaintextConnected resumes the
		# transport.
		self.transport.pauseProducing()

		# start the connection to the endpoint
		ep = endpoints.clientFromString(reactor, self.factory.endpoint)
		epdef = ep.connect(ServerPTProxyFactory(self))
		epdef.addCallbacks(self.plaintextConnected, self._plaintextFailed)

	def _plaintextFailed(self, failure):
		'''Called when the connection to the plain text endpoint
		could not be made.  Closes the Noise connection, as otherwise
		it would stay paused and open forever.'''

		self.transport.loseConnection()

		# Pass the failure along so that it still gets reported.
		return failure

class TwistedNoiseServerFactory(Factory):
	'''Factory for the responder side, TwistedNoiseServerProtocol.

	priv_key is the raw bytes of the server's static private key.

	endpoint is the string, as used w/ clientFromString, describing
	where the server connects to once the handshake has finished.'''

	protocol = TwistedNoiseServerProtocol

	def __init__(self, priv_key, endpoint):
		# The protocol instances read both of these via self.factory.
		self.endpoint = endpoint
		self.priv_key = priv_key

# Supporting classes for TwistedNoiseServer
class PTProxyProtocol(twisted.internet.protocol.Protocol):
	'''Plain text side of the proxy.  Any data that is received is
	handed to the noise protocol attached to the factory, which
	encrypts it and sends it on.'''

	def dataReceived(self, data):
		noiseproto = self.factory.noiseproto
		noiseproto.encData(data)

class ServerPTProxyFactory(Factory):
	'''Factory for PTProxyProtocol.  It holds the noise protocol that
	the PTProxyProtocol passes received data to for encryption.'''

	protocol = PTProxyProtocol

	def __init__(self, noiseproto):
		# PTProxyProtocol reads this as self.factory.noiseproto
		self.noiseproto = noiseproto

class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
	'''Initiator side of the Noise connection.'''

	# NOTE(review): handshakeFinished is not overridden here, so the
	# base class version, which raises NotImplementedError, is used.
	mode = 'init'

class ClientPTFactory(Factory):
	'''Factory for the initiator side, TwistedNoiseClientProtocol.

	priv_key is the raw bytes of the client's static private key.

	servpub is the raw bytes of the server's static public key, and
	is made available to the protocol as the pub_key property.

	sockstr is stored as is.  Presumably it is the endpoint string of
	the server to connect to, but nothing in this file reads it yet.'''

	protocol = TwistedNoiseClientProtocol

	def __init__(self, priv_key, servpub, sockstr):
		self.sockstr = sockstr
		self.pub_key = servpub
		self.priv_key = priv_key

class TNServerTest(unittest.TestCase):
	'''Tests for the Noise protocol classes, using unix domain sockets
	created in a temporary directory.'''

	@defer.inlineCallbacks
	def setUp(self):
		'''Create the temporary directory, the key pairs, a listening
		plain text endpoint, and the server factory that connects to
		it.'''

		# setup temporary directory
		d = os.path.realpath(tempfile.mkdtemp())
		self.basetempdir = d
		self.tempdir = os.path.join(d, 'subdir')
		os.mkdir(self.tempdir)

		# Generate key pairs
		self.server_key_pair = genkeypair()
		self.client_key_pair = genkeypair()

		# Server's PT client will be here
		self.protos = []
		self.connectionmade = defer.Deferred()

		class AccProtFactory(Factory):
			'''Factory that records every protocol that it
			builds in the test case's protos list.'''

			protocol = proto_helpers.AccumulatingProtocol

			def __init__(self, tc):
				self.__tc = tc
				Factory.__init__(self)

			# self here is the test case (from the enclosing
			# setUp), not the factory.
			protocolConnectionMade = self.connectionmade

			def buildProtocol(self, addr):
				r = Factory.buildProtocol(self, addr)
				self.__tc.protos.append(r)
				return r

		self.AccProtFactory = AccProtFactory

		# Setup PT client endpoint
		sockpath = os.path.join(self.tempdir, 'servptsock')
		ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
		lpobj = yield ep.listen(AccProtFactory(self))

		self.testserv = ep
		self.listenportobj = lpobj
		self.endpoint = _makeunix(sockpath)

		# Setup server, and configure where to connect to.
		self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)

	@defer.inlineCallbacks
	def tearDown(self):
		'''Stop the listener created by setUp and remove the
		temporary directory.'''

		try:
			yield self.listenportobj.stopListening()
		finally:
			# Remove the temporary directory even when stopping
			# the listener fails.
			shutil.rmtree(self.basetempdir)
			self.tempdir = None

	@defer.inlineCallbacks
	def test_testserver(self):
		'''Run a handshake against the server protocol, and pass a
		message in each direction.'''

		#
		# How this test is plumbed:
		#
		#  proto (NoiseConnection) -> self.tr (StringTransport) ->
		#    self.proto (TwistedNoiseServerProtocol) ->
		#    self.proto.endpoint (PTProxyProtocol) -> unix sock ->
		#    self.protos[0] (AccumulatingProtocol)
		#

		# Generate a server protocol, and bind it to a string
		# transport for testing
		self.proto = self.servfactory.buildProtocol(None)
		self.tr = proto_helpers.StringTransport()
		self.proto.makeConnection(self.tr)

		# Create client
		proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
		proto.set_as_initiator()

		# Setup required keys
		proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
		proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])

		proto.start_handshake()

		# Send first message
		message = proto.write_message()
		self.proto.dataReceived(message)

		# Get response
		resp = self.tr.value()
		self.tr.clear()

		# And process it
		proto.read_message(resp)

		# Send second message
		message = proto.write_message()
		self.proto.dataReceived(message)

		# assert handshake finished
		self.assertTrue(proto.handshake_finished)

		# Make sure incoming data is paused till we establish client
		# connection, otherwise no place to write the data
		self.assertEqual(self.tr.producerState, 'paused')

		# Wait for the connection to be made
		yield self.connectionmade

		yield task.deferLater(reactor, .1, bool, 1)

		# How to make this ready?
		self.assertEqual(self.tr.producerState, 'producing')

		# Encrypt the message
		ptmsg = b'this is a test message'
		encmsg = proto.encrypt(ptmsg)

		# Feed it into the protocol
		self.proto.dataReceived(encmsg)

		# XXX - fix
		# wait to pass it through
		yield task.deferLater(reactor, .1, bool, 1)

		# fetch remote end out
		clientend = self.protos[0]
		self.assertEqual(clientend.data, ptmsg)

		# send a message the other direction
		rptmsg = b'this is a different test message going the other way'
		clientend.transport.write(rptmsg)

		# XXX - fix
		# wait to pass it through
		yield task.deferLater(reactor, .1, bool, 1)

		# receive it and decrypt it
		resp = self.tr.value()
		self.assertEqual(proto.decrypt(resp), rptmsg)

		# Disabled throughput benchmark, change to True to run it.
		if False:
			import time
			s = time.time()
			cnt = 40000
			blksz = 1024
			rnd = os.urandom(blksz)
			for i in range(0, cnt):
				proto.encrypt(rnd)
			e = time.time()

			print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024))

		# clean up connection
		clientend.transport.loseConnection()

	@defer.inlineCallbacks
	def test_clientserver(self):
		'''Start up both the server and the client half, and write to
		the client's plain text side.'''

		# Path that the client "listener" will sit on.
		cptsockpath = os.path.join(self.tempdir, 'clientptsock')

		# Path that the server will sit on
		servsockpath = os.path.join(self.tempdir, 'servsock')

		# Start up the server
		servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
		servlpobj = yield servep.listen(self.servfactory)

		# Start up the client half
		clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath))
		clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath)))

		# Connect to the client
		clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath))
		yield clptep.connect(self.AccProtFactory(self))

		# The client plain text connection
		clptproto = self.protos[-1]

		# Transports only accept bytes, a str raises TypeError on
		# Python 3.
		clptproto.transport.write(b'this is a test')

		# Clean up
		yield servlpobj.stopListening()
		yield clientlpobj.stopListening()