|
@@ -2,7 +2,7 @@ from twisted.trial import unittest |
|
|
from twisted.test import proto_helpers |
|
|
from twisted.test import proto_helpers |
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
from twisted.internet.protocol import Factory |
|
|
from twisted.internet.protocol import Factory |
|
|
from twisted.internet import endpoints, reactor,defer |
|
|
|
|
|
|
|
|
from twisted.internet import endpoints, reactor, defer, task |
|
|
# XXX - shouldn't need to access the underlying primitives, but that's what |
|
|
# XXX - shouldn't need to access the underlying primitives, but that's what |
|
|
# noiseprotocol module requires. |
|
|
# noiseprotocol module requires. |
|
|
from cryptography.hazmat.primitives.asymmetric import x448 |
|
|
from cryptography.hazmat.primitives.asymmetric import x448 |
|
@@ -70,10 +70,9 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): |
|
|
else: |
|
|
else: |
|
|
r = self.noise.decrypt(data) |
|
|
r = self.noise.decrypt(data) |
|
|
|
|
|
|
|
|
self.endpoint.write(r) |
|
|
|
|
|
|
|
|
self.endpoint.transport.write(r) |
|
|
|
|
|
|
|
|
def proxyConnected(self, endpoint): |
|
|
def proxyConnected(self, endpoint): |
|
|
print('pc') |
|
|
|
|
|
self.endpoint = endpoint |
|
|
self.endpoint = endpoint |
|
|
self.transport.resumeProducing() |
|
|
self.transport.resumeProducing() |
|
|
|
|
|
|
|
@@ -98,6 +97,8 @@ class TNServerTest(unittest.TestCase): |
|
|
def setUp(self): |
|
|
def setUp(self): |
|
|
self.server_key_pair = genkeypair() |
|
|
self.server_key_pair = genkeypair() |
|
|
self.protos = [] |
|
|
self.protos = [] |
|
|
|
|
|
self.connectionmade = defer.Deferred() |
|
|
|
|
|
|
|
|
class AccProtFactory(Factory): |
|
|
class AccProtFactory(Factory): |
|
|
protocol = proto_helpers.AccumulatingProtocol |
|
|
protocol = proto_helpers.AccumulatingProtocol |
|
|
|
|
|
|
|
@@ -105,19 +106,21 @@ class TNServerTest(unittest.TestCase): |
|
|
self.__tc = tc |
|
|
self.__tc = tc |
|
|
Factory.__init__(self) |
|
|
Factory.__init__(self) |
|
|
|
|
|
|
|
|
|
|
|
protocolConnectionMade = self.connectionmade |
|
|
|
|
|
|
|
|
def buildProtocol(self, addr): |
|
|
def buildProtocol(self, addr): |
|
|
r = Factory.buildProtocol(addr) |
|
|
|
|
|
self.__tc.append(r) |
|
|
|
|
|
|
|
|
r = Factory.buildProtocol(self, addr) |
|
|
|
|
|
self.__tc.protos.append(r) |
|
|
return r |
|
|
return r |
|
|
|
|
|
|
|
|
for i in range(10000, 20000): |
|
|
for i in range(10000, 20000): |
|
|
ep = endpoints.TCP4ServerEndpoint(reactor, i) |
|
|
|
|
|
|
|
|
ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1') |
|
|
try: |
|
|
try: |
|
|
lpobj = yield ep.listen(AccProtFactory(self)) |
|
|
lpobj = yield ep.listen(AccProtFactory(self)) |
|
|
except Exception: |
|
|
|
|
|
|
|
|
except Exception: # pragma: no cover |
|
|
continue |
|
|
continue |
|
|
break |
|
|
break |
|
|
else: |
|
|
|
|
|
|
|
|
else: # pragma: no cover |
|
|
raise RuntimeError('all ports occupied') |
|
|
raise RuntimeError('all ports occupied') |
|
|
|
|
|
|
|
|
self.testserv = ep |
|
|
self.testserv = ep |
|
@@ -133,8 +136,8 @@ class TNServerTest(unittest.TestCase): |
|
|
def tearDown(self): |
|
|
def tearDown(self): |
|
|
self.listenportobj.stopListening() |
|
|
self.listenportobj.stopListening() |
|
|
|
|
|
|
|
|
@mock.patch('twisted.internet.endpoints.clientFromString') |
|
|
|
|
|
def test_testprotocol(self, cfs): |
|
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
|
|
def test_testprotocol(self): |
|
|
# Create client |
|
|
# Create client |
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
proto.set_as_initiator() |
|
|
proto.set_as_initiator() |
|
@@ -157,9 +160,6 @@ class TNServerTest(unittest.TestCase): |
|
|
# And process it |
|
|
# And process it |
|
|
proto.read_message(resp) |
|
|
proto.read_message(resp) |
|
|
|
|
|
|
|
|
clientconnection = defer.Deferred() |
|
|
|
|
|
cfs().connect.return_value = clientconnection |
|
|
|
|
|
|
|
|
|
|
|
# Send second message |
|
|
# Send second message |
|
|
message = proto.write_message() |
|
|
message = proto.write_message() |
|
|
self.proto.dataReceived(message) |
|
|
self.proto.dataReceived(message) |
|
@@ -171,21 +171,13 @@ class TNServerTest(unittest.TestCase): |
|
|
# connection, otherwise no place to write the data |
|
|
# connection, otherwise no place to write the data |
|
|
self.assertEqual(self.tr.producerState, 'paused') |
|
|
self.assertEqual(self.tr.producerState, 'paused') |
|
|
|
|
|
|
|
|
# Make sure that clientFromString is called properly |
|
|
|
|
|
cfs.assert_called_with(reactor, self.endpoint) |
|
|
|
|
|
|
|
|
|
|
|
# And that it was connect'ed |
|
|
|
|
|
cfs().connect.assert_called() |
|
|
|
|
|
|
|
|
# Wait for the connection to be made |
|
|
|
|
|
d = yield self.connectionmade |
|
|
|
|
|
|
|
|
# and that ClientProxyFactory was called properly |
|
|
|
|
|
args = cfs().connect.call_args.args |
|
|
|
|
|
self.assertIsInstance(args[0], ClientProxyFactory) |
|
|
|
|
|
self.assertIs(args[0].noiseproto, self.proto) |
|
|
|
|
|
|
|
|
d = yield task.deferLater(reactor, .1, bool, 1) |
|
|
|
|
|
|
|
|
# Simulate that a connection has happened |
|
|
|
|
|
remoteend = proto_helpers.StringTransport() |
|
|
|
|
|
remoteproto = args[0].buildProtocol(None) |
|
|
|
|
|
remoteproto.makeConnection(remoteend) |
|
|
|
|
|
|
|
|
# How to make this ready? |
|
|
|
|
|
self.assertEqual(self.tr.producerState, 'producing') |
|
|
|
|
|
|
|
|
# Encrypt the message |
|
|
# Encrypt the message |
|
|
ptmsg = b'this is a test message' |
|
|
ptmsg = b'this is a test message' |
|
@@ -194,4 +186,9 @@ class TNServerTest(unittest.TestCase): |
|
|
# Feed it into the protocol |
|
|
# Feed it into the protocol |
|
|
self.proto.dataReceived(encmsg) |
|
|
self.proto.dataReceived(encmsg) |
|
|
|
|
|
|
|
|
self.assertEqual(remoteend.value(), ptmsg) |
|
|
|
|
|
|
|
|
d = yield task.deferLater(reactor, .1, bool, 1) |
|
|
|
|
|
|
|
|
|
|
|
clientend = self.protos[0] |
|
|
|
|
|
self.assertEqual(clientend.data, ptmsg) |
|
|
|
|
|
|
|
|
|
|
|
clientend.transport.loseConnection() |