diff --git a/twistednoise.py b/twistednoise.py index 65f8f9f..4343ae4 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -2,7 +2,7 @@ from twisted.trial import unittest from twisted.test import proto_helpers from noise.connection import NoiseConnection, Keypair from twisted.internet.protocol import Factory -from twisted.internet import endpoints, reactor,defer +from twisted.internet import endpoints, reactor, defer, task # XXX - shouldn't need to access the underlying primitives, but that's what # noiseprotocol module requires. from cryptography.hazmat.primitives.asymmetric import x448 @@ -70,10 +70,9 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): else: r = self.noise.decrypt(data) - self.endpoint.write(r) + self.endpoint.transport.write(r) def proxyConnected(self, endpoint): - print('pc') self.endpoint = endpoint self.transport.resumeProducing() @@ -98,6 +97,8 @@ class TNServerTest(unittest.TestCase): def setUp(self): self.server_key_pair = genkeypair() self.protos = [] + self.connectionmade = defer.Deferred() + class AccProtFactory(Factory): protocol = proto_helpers.AccumulatingProtocol @@ -105,19 +106,21 @@ class TNServerTest(unittest.TestCase): self.__tc = tc Factory.__init__(self) + protocolConnectionMade = self.connectionmade + def buildProtocol(self, addr): - r = Factory.buildProtocol(addr) - self.__tc.append(r) + r = Factory.buildProtocol(self, addr) + self.__tc.protos.append(r) return r for i in range(10000, 20000): - ep = endpoints.TCP4ServerEndpoint(reactor, i) + ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1') try: lpobj = yield ep.listen(AccProtFactory(self)) - except Exception: + except Exception: # pragma: no cover continue break - else: + else: # pragma: no cover raise RuntimeError('all ports occupied') self.testserv = ep @@ -133,8 +136,8 @@ class TNServerTest(unittest.TestCase): def tearDown(self): self.listenportobj.stopListening() - @mock.patch('twisted.internet.endpoints.clientFromString') - def test_testprotocol(self, cfs): + @defer.inlineCallbacks + def test_testprotocol(self): # Create client proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_as_initiator() @@ -157,9 +160,6 @@ class TNServerTest(unittest.TestCase): # And process it proto.read_message(resp) - clientconnection = defer.Deferred() - cfs().connect.return_value = clientconnection - # Send second message message = proto.write_message() self.proto.dataReceived(message) @@ -171,21 +171,13 @@ class TNServerTest(unittest.TestCase): # connection, otherwise no place to write the data self.assertEqual(self.tr.producerState, 'paused') - # Make sure that clientFromString is called properly - cfs.assert_called_with(reactor, self.endpoint) - - # And that it was connect'ed - cfs().connect.assert_called() + # Wait for the connection to be made + d = yield self.connectionmade - # and that ClientProxyFactory was called properly - args = cfs().connect.call_args.args - self.assertIsInstance(args[0], ClientProxyFactory) - self.assertIs(args[0].noiseproto, self.proto) + d = yield task.deferLater(reactor, .1, bool, 1) - # Simulate that a connection has happened - remoteend = proto_helpers.StringTransport() - remoteproto = args[0].buildProtocol(None) - remoteproto.makeConnection(remoteend) + # How to make this ready? + self.assertEqual(self.tr.producerState, 'producing') # Encrypt the message ptmsg = b'this is a test message' @@ -194,4 +186,9 @@ class TNServerTest(unittest.TestCase): # Feed it into the protocol self.proto.dataReceived(encmsg) - self.assertEqual(remoteend.value(), ptmsg) + d = yield task.deferLater(reactor, .1, bool, 1) + + clientend = self.protos[0] + self.assertEqual(clientend.data, ptmsg) + + clientend.transport.loseConnection()