|
|
@@ -2,17 +2,19 @@ from twisted.trial import unittest |
|
|
|
from twisted.test import proto_helpers |
|
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
|
from twisted.internet.protocol import Factory |
|
|
|
from twisted.internet import endpoints, reactor,defer |
|
|
|
# XXX - shouldn't need to access the underlying primitives, but that's what |
|
|
|
# noiseprotocol module requires. |
|
|
|
from cryptography.hazmat.primitives.asymmetric import x448 |
|
|
|
from cryptography.hazmat.primitives import serialization |
|
|
|
|
|
|
|
import twisted.internet.protocol |
|
|
|
import mock |
|
|
|
|
|
|
|
# Notes: |
|
|
|
# Using XK, so that the connecting party's identity is hidden and that the |
|
|
|
# server's party's key is known. |
|
|
|
|
|
|
|
import twisted.internet.protocol |
|
|
|
|
|
|
|
def genkeypair(): |
|
|
|
'''Generates a keypair, and returns a tuple of (public, private). |
|
|
|
They are encoded as raw bytes, and sutible for use w/ Noise.''' |
|
|
@@ -30,6 +32,19 @@ def genkeypair(): |
|
|
|
return pub, priv |
|
|
|
|
|
|
|
class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): |
|
|
|
'''This class acts as a Noise Protocol responder. The factory that |
|
|
|
creates this Protocol is required to have the properties server_key |
|
|
|
and endpoint. |
|
|
|
|
|
|
|
The server_key propery is the key for the server that the clients are |
|
|
|
required to have (due to Noise XK protocol used) to authenticate the |
|
|
|
server. |
|
|
|
|
|
|
|
The endpoint property contains the endpoint as a string that will be |
|
|
|
used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString |
|
|
|
and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients |
|
|
|
for information on how to use this property.''' |
|
|
|
|
|
|
|
def connectionMade(self): |
|
|
|
# Initialize Noise |
|
|
|
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
@@ -45,29 +60,81 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): |
|
|
|
self.noise.read_message(data) |
|
|
|
if not self.noise.handshake_finished: |
|
|
|
self.transport.write(self.noise.write_message()) |
|
|
|
if self.noise.handshake_finished: |
|
|
|
self.transport.pauseProducing() |
|
|
|
|
|
|
|
# start the connection to the endpoint |
|
|
|
ep = endpoints.clientFromString(reactor, self.factory.endpoint) |
|
|
|
epdef = ep.connect(ClientProxyFactory(self)) |
|
|
|
epdef.addCallback(self.proxyConnected) |
|
|
|
else: |
|
|
|
r = self.noise.decrypt(data) |
|
|
|
# echo it |
|
|
|
self.transport.write(self.noise.encrypt(r)) |
|
|
|
|
|
|
|
self.endpoint.write(r) |
|
|
|
|
|
|
|
def proxyConnected(self, endpoint): |
|
|
|
print('pc') |
|
|
|
self.endpoint = endpoint |
|
|
|
self.transport.resumeProducing() |
|
|
|
|
|
|
|
class ClientProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
pass |
|
|
|
|
|
|
|
class ClientProxyFactory(Factory): |
|
|
|
protocol = ClientProxyProtocol |
|
|
|
|
|
|
|
def __init__(self, noiseproto): |
|
|
|
self.noiseproto = noiseproto |
|
|
|
|
|
|
|
class TwistedNoiseServerFactory(Factory): |
|
|
|
protocol = TwistedNoiseServerProtocol |
|
|
|
|
|
|
|
def __init__(self, server_key): |
|
|
|
def __init__(self, server_key, endpoint): |
|
|
|
self.server_key = server_key |
|
|
|
self.endpoint = endpoint |
|
|
|
|
|
|
|
class TNServerTest(unittest.TestCase): |
|
|
|
@defer.inlineCallbacks |
|
|
|
def setUp(self): |
|
|
|
self.server_key_pair = genkeypair() |
|
|
|
self.protos = [] |
|
|
|
class AccProtFactory(Factory): |
|
|
|
protocol = proto_helpers.AccumulatingProtocol |
|
|
|
|
|
|
|
def __init__(self, tc): |
|
|
|
self.__tc = tc |
|
|
|
Factory.__init__(self) |
|
|
|
|
|
|
|
def buildProtocol(self, addr): |
|
|
|
r = Factory.buildProtocol(addr) |
|
|
|
self.__tc.append(r) |
|
|
|
return r |
|
|
|
|
|
|
|
for i in range(10000, 20000): |
|
|
|
ep = endpoints.TCP4ServerEndpoint(reactor, i) |
|
|
|
try: |
|
|
|
lpobj = yield ep.listen(AccProtFactory(self)) |
|
|
|
except Exception: |
|
|
|
continue |
|
|
|
break |
|
|
|
else: |
|
|
|
raise RuntimeError('all ports occupied') |
|
|
|
|
|
|
|
factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1]) |
|
|
|
self.testserv = ep |
|
|
|
self.listenportobj = lpobj |
|
|
|
self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i |
|
|
|
factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) |
|
|
|
self.proto = factory.buildProtocol(('127.0.0.1', 0)) |
|
|
|
self.tr = proto_helpers.StringTransport() |
|
|
|
self.proto.makeConnection(self.tr) |
|
|
|
|
|
|
|
self.client_key_pair = genkeypair() |
|
|
|
|
|
|
|
def test_testprotocol(self): |
|
|
|
def tearDown(self): |
|
|
|
self.listenportobj.stopListening() |
|
|
|
|
|
|
|
@mock.patch('twisted.internet.endpoints.clientFromString') |
|
|
|
def test_testprotocol(self, cfs): |
|
|
|
# Create client |
|
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
proto.set_as_initiator() |
|
|
@@ -90,22 +157,41 @@ class TNServerTest(unittest.TestCase): |
|
|
|
# And process it |
|
|
|
proto.read_message(resp) |
|
|
|
|
|
|
|
clientconnection = defer.Deferred() |
|
|
|
cfs().connect.return_value = clientconnection |
|
|
|
|
|
|
|
# Send second message |
|
|
|
message = proto.write_message() |
|
|
|
self.proto.dataReceived(message) |
|
|
|
|
|
|
|
# Finish handshake |
|
|
|
# assert handshake finished |
|
|
|
self.assertTrue(proto.handshake_finished) |
|
|
|
|
|
|
|
# Make sure incoming data is paused till we establish client |
|
|
|
# connection, otherwise no place to write the data |
|
|
|
self.assertEqual(self.tr.producerState, 'paused') |
|
|
|
|
|
|
|
# Make sure that clientFromString is called properly |
|
|
|
cfs.assert_called_with(reactor, self.endpoint) |
|
|
|
|
|
|
|
# And that it was connect'ed |
|
|
|
cfs().connect.assert_called() |
|
|
|
|
|
|
|
# and that ClientProxyFactory was called properly |
|
|
|
args = cfs().connect.call_args.args |
|
|
|
self.assertIsInstance(args[0], ClientProxyFactory) |
|
|
|
self.assertIs(args[0].noiseproto, self.proto) |
|
|
|
|
|
|
|
# Simulate that a connection has happened |
|
|
|
remoteend = proto_helpers.StringTransport() |
|
|
|
remoteproto = args[0].buildProtocol(None) |
|
|
|
remoteproto.makeConnection(remoteend) |
|
|
|
|
|
|
|
# Encrypt the message |
|
|
|
ptmsg = b'this is a test message' |
|
|
|
encmsg = proto.encrypt(ptmsg) |
|
|
|
|
|
|
|
# Feed it into the protocol |
|
|
|
self.proto.dataReceived(encmsg) |
|
|
|
|
|
|
|
# Get echo |
|
|
|
resp = self.tr.value() |
|
|
|
self.tr.clear() |
|
|
|
|
|
|
|
ptresp = proto.decrypt(resp) |
|
|
|
|
|
|
|
self.assertEqual(ptresp, ptmsg) |
|
|
|
self.assertEqual(remoteend.value(), ptmsg) |