| @@ -2,7 +2,7 @@ from twisted.trial import unittest | |||||
| from twisted.test import proto_helpers | from twisted.test import proto_helpers | ||||
| from noise.connection import NoiseConnection, Keypair | from noise.connection import NoiseConnection, Keypair | ||||
| from twisted.internet.protocol import Factory | from twisted.internet.protocol import Factory | ||||
| from twisted.internet import endpoints, reactor,defer | |||||
| from twisted.internet import endpoints, reactor, defer, task | |||||
| # XXX - shouldn't need to access the underlying primitives, but that's what | # XXX - shouldn't need to access the underlying primitives, but that's what | ||||
| # noiseprotocol module requires. | # noiseprotocol module requires. | ||||
| from cryptography.hazmat.primitives.asymmetric import x448 | from cryptography.hazmat.primitives.asymmetric import x448 | ||||
| @@ -70,10 +70,9 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): | |||||
| else: | else: | ||||
| r = self.noise.decrypt(data) | r = self.noise.decrypt(data) | ||||
| self.endpoint.write(r) | |||||
| self.endpoint.transport.write(r) | |||||
| def proxyConnected(self, endpoint): | def proxyConnected(self, endpoint): | ||||
| print('pc') | |||||
| self.endpoint = endpoint | self.endpoint = endpoint | ||||
| self.transport.resumeProducing() | self.transport.resumeProducing() | ||||
| @@ -98,6 +97,8 @@ class TNServerTest(unittest.TestCase): | |||||
| def setUp(self): | def setUp(self): | ||||
| self.server_key_pair = genkeypair() | self.server_key_pair = genkeypair() | ||||
| self.protos = [] | self.protos = [] | ||||
| self.connectionmade = defer.Deferred() | |||||
| class AccProtFactory(Factory): | class AccProtFactory(Factory): | ||||
| protocol = proto_helpers.AccumulatingProtocol | protocol = proto_helpers.AccumulatingProtocol | ||||
| @@ -105,19 +106,21 @@ class TNServerTest(unittest.TestCase): | |||||
| self.__tc = tc | self.__tc = tc | ||||
| Factory.__init__(self) | Factory.__init__(self) | ||||
| protocolConnectionMade = self.connectionmade | |||||
| def buildProtocol(self, addr): | def buildProtocol(self, addr): | ||||
| r = Factory.buildProtocol(addr) | |||||
| self.__tc.append(r) | |||||
| r = Factory.buildProtocol(self, addr) | |||||
| self.__tc.protos.append(r) | |||||
| return r | return r | ||||
| for i in range(10000, 20000): | for i in range(10000, 20000): | ||||
| ep = endpoints.TCP4ServerEndpoint(reactor, i) | |||||
| ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1') | |||||
| try: | try: | ||||
| lpobj = yield ep.listen(AccProtFactory(self)) | lpobj = yield ep.listen(AccProtFactory(self)) | ||||
| except Exception: | |||||
| except Exception: # pragma: no cover | |||||
| continue | continue | ||||
| break | break | ||||
| else: | |||||
| else: # pragma: no cover | |||||
| raise RuntimeError('all ports occupied') | raise RuntimeError('all ports occupied') | ||||
| self.testserv = ep | self.testserv = ep | ||||
| @@ -133,8 +136,8 @@ class TNServerTest(unittest.TestCase): | |||||
| def tearDown(self): | def tearDown(self): | ||||
| self.listenportobj.stopListening() | self.listenportobj.stopListening() | ||||
| @mock.patch('twisted.internet.endpoints.clientFromString') | |||||
| def test_testprotocol(self, cfs): | |||||
| @defer.inlineCallbacks | |||||
| def test_testprotocol(self): | |||||
| # Create client | # Create client | ||||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | ||||
| proto.set_as_initiator() | proto.set_as_initiator() | ||||
| @@ -157,9 +160,6 @@ class TNServerTest(unittest.TestCase): | |||||
| # And process it | # And process it | ||||
| proto.read_message(resp) | proto.read_message(resp) | ||||
| clientconnection = defer.Deferred() | |||||
| cfs().connect.return_value = clientconnection | |||||
| # Send second message | # Send second message | ||||
| message = proto.write_message() | message = proto.write_message() | ||||
| self.proto.dataReceived(message) | self.proto.dataReceived(message) | ||||
| @@ -171,21 +171,13 @@ class TNServerTest(unittest.TestCase): | |||||
| # connection, otherwise no place to write the data | # connection, otherwise no place to write the data | ||||
| self.assertEqual(self.tr.producerState, 'paused') | self.assertEqual(self.tr.producerState, 'paused') | ||||
| # Make sure that clientFromString is called properly | |||||
| cfs.assert_called_with(reactor, self.endpoint) | |||||
| # And that it was connect'ed | |||||
| cfs().connect.assert_called() | |||||
| # Wait for the connection to be made | |||||
| d = yield self.connectionmade | |||||
| # and that ClientProxyFactory was called properly | |||||
| args = cfs().connect.call_args.args | |||||
| self.assertIsInstance(args[0], ClientProxyFactory) | |||||
| self.assertIs(args[0].noiseproto, self.proto) | |||||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||||
| # Simulate that a connection has happened | |||||
| remoteend = proto_helpers.StringTransport() | |||||
| remoteproto = args[0].buildProtocol(None) | |||||
| remoteproto.makeConnection(remoteend) | |||||
| # How to make this ready? | |||||
| self.assertEqual(self.tr.producerState, 'producing') | |||||
| # Encrypt the message | # Encrypt the message | ||||
| ptmsg = b'this is a test message' | ptmsg = b'this is a test message' | ||||
| @@ -194,4 +186,9 @@ class TNServerTest(unittest.TestCase): | |||||
| # Feed it into the protocol | # Feed it into the protocol | ||||
| self.proto.dataReceived(encmsg) | self.proto.dataReceived(encmsg) | ||||
| self.assertEqual(remoteend.value(), ptmsg) | |||||
| d = yield task.deferLater(reactor, .1, bool, 1) | |||||
| clientend = self.protos[0] | |||||
| self.assertEqual(clientend.data, ptmsg) | |||||
| clientend.transport.loseConnection() | |||||