From dfcd36caab0921a22569e2317512419f32fcb66b Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Mon, 21 Oct 2019 14:59:15 -0700 Subject: [PATCH] check point for some work... Changing directions, but don't want to lose the work of mocking the client connection.. --- requirements.txt | 1 + twistednoise.py | 116 +++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/requirements.txt b/requirements.txt index 41e7216..54f5c36 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ coverage -e git+https://github.com/jmgurney/noiseprotocol.git@ab6f8ebe0e28f5a4105928c13baddcfdc43b7e82#egg=noiseprotocol twisted +mock diff --git a/twistednoise.py b/twistednoise.py index 3ca9c73..65f8f9f 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -2,17 +2,19 @@ from twisted.trial import unittest from twisted.test import proto_helpers from noise.connection import NoiseConnection, Keypair from twisted.internet.protocol import Factory +from twisted.internet import endpoints, reactor,defer # XXX - shouldn't need to access the underlying primitives, but that's what # noiseprotocol module requires. from cryptography.hazmat.primitives.asymmetric import x448 from cryptography.hazmat.primitives import serialization +import twisted.internet.protocol +import mock + # Notes: # Using XK, so that the connecting party's identity is hidden and that the # server's party's key is known. -import twisted.internet.protocol - def genkeypair(): '''Generates a keypair, and returns a tuple of (public, private). They are encoded as raw bytes, and sutible for use w/ Noise.''' @@ -30,6 +32,19 @@ def genkeypair(): return pub, priv class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): + '''This class acts as a Noise Protocol responder. The factory that + creates this Protocol is required to have the properties server_key + and endpoint. + + The server_key propery is the key for the server that the clients are + required to have (due to Noise XK protocol used) to authenticate the + server. + + The endpoint property contains the endpoint as a string that will be + used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString + and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients + for information on how to use this property.''' + def connectionMade(self): # Initialize Noise noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') @@ -45,29 +60,81 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): self.noise.read_message(data) if not self.noise.handshake_finished: self.transport.write(self.noise.write_message()) + if self.noise.handshake_finished: + self.transport.pauseProducing() + + # start the connection to the endpoint + ep = endpoints.clientFromString(reactor, self.factory.endpoint) + epdef = ep.connect(ClientProxyFactory(self)) + epdef.addCallback(self.proxyConnected) else: r = self.noise.decrypt(data) - # echo it - self.transport.write(self.noise.encrypt(r)) + + self.endpoint.write(r) + + def proxyConnected(self, endpoint): + print('pc') + self.endpoint = endpoint + self.transport.resumeProducing() + +class ClientProxyProtocol(twisted.internet.protocol.Protocol): + pass + +class ClientProxyFactory(Factory): + protocol = ClientProxyProtocol + + def __init__(self, noiseproto): + self.noiseproto = noiseproto class TwistedNoiseServerFactory(Factory): protocol = TwistedNoiseServerProtocol - def __init__(self, server_key): + def __init__(self, server_key, endpoint): self.server_key = server_key + self.endpoint = endpoint class TNServerTest(unittest.TestCase): + @defer.inlineCallbacks def setUp(self): self.server_key_pair = genkeypair() + self.protos = [] + class AccProtFactory(Factory): + protocol = proto_helpers.AccumulatingProtocol + + def __init__(self, tc): + self.__tc = tc + Factory.__init__(self) + + def buildProtocol(self, addr): + r = Factory.buildProtocol(addr) + self.__tc.append(r) + return r + + for i in range(10000, 20000): + ep = endpoints.TCP4ServerEndpoint(reactor, i) + try: + lpobj = yield ep.listen(AccProtFactory(self)) + except Exception: + continue + break + else: + raise RuntimeError('all ports occupied') - factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1]) + self.testserv = ep + self.listenportobj = lpobj + self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i + factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) self.proto = factory.buildProtocol(('127.0.0.1', 0)) self.tr = proto_helpers.StringTransport() self.proto.makeConnection(self.tr) self.client_key_pair = genkeypair() - def test_testprotocol(self): + def tearDown(self): + self.listenportobj.stopListening() + + @mock.patch('twisted.internet.endpoints.clientFromString') + def test_testprotocol(self, cfs): # Create client proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_as_initiator() @@ -90,22 +157,41 @@ class TNServerTest(unittest.TestCase): # And process it proto.read_message(resp) + clientconnection = defer.Deferred() + cfs().connect.return_value = clientconnection + # Send second message message = proto.write_message() self.proto.dataReceived(message) - # Finish handshake + # assert handshake finished self.assertTrue(proto.handshake_finished) + # Make sure incoming data is paused till we establish client + # connection, otherwise no place to write the data + self.assertEqual(self.tr.producerState, 'paused') + + # Make sure that clientFromString is called properly + cfs.assert_called_with(reactor, self.endpoint) + + # And that it was connect'ed + cfs().connect.assert_called() + + # and that ClientProxyFactory was called properly + args = cfs().connect.call_args.args + self.assertIsInstance(args[0], ClientProxyFactory) + self.assertIs(args[0].noiseproto, self.proto) + + # Simulate that a connection has happened + remoteend = proto_helpers.StringTransport() + remoteproto = args[0].buildProtocol(None) + remoteproto.makeConnection(remoteend) + + # Encrypt the message ptmsg = b'this is a test message' encmsg = proto.encrypt(ptmsg) + # Feed it into the protocol self.proto.dataReceived(encmsg) - # Get echo - resp = self.tr.value() - self.tr.clear() - - ptresp = proto.decrypt(resp) - - self.assertEqual(ptresp, ptmsg) + self.assertEqual(remoteend.value(), ptmsg)