from twisted.trial import unittest
from twisted.test import proto_helpers
from noise.connection import NoiseConnection, Keypair
from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor, defer

# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives import serialization

import twisted.internet.protocol

import mock

# Notes:
# Using XK, so that the connecting party's identity is hidden and that the
# server's key is known to the client in advance.

# Noise protocol name used by the responder and by the test initiator.
NOISE_PROTOCOL_NAME = b'Noise_XK_448_ChaChaPoly_SHA256'

# Largest plaintext that fits in one Noise transport message: Noise caps
# messages at 65535 bytes, and ChaChaPoly appends a 16-byte tag.
MAX_NOISE_PLAINTEXT = 65535 - 16


def genkeypair():
    '''Generate an X448 keypair and return a tuple of (public, private).

    Both keys are encoded as raw bytes, the form accepted by the noise
    module's set_keypair_from_public_bytes / set_keypair_from_private_bytes.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
                             encryption_algorithm=encalgo)

    return pub, priv


class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
    '''This class acts as a Noise Protocol responder.

    The factory that creates this Protocol is required to have the
    properties server_key and endpoint.

    The server_key property is the server's private key.  Clients are
    required to already have the matching public key (due to the Noise XK
    pattern used) in order to authenticate the server.

    The endpoint property contains the endpoint as a string that will be
    used with clientFromString, see
    https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
    and
    https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
    for information on how to use this property.

    Once the handshake has finished, decrypted client data is written to the
    proxied endpoint connection, and data from that connection is encrypted
    back to the client (see ClientProxyProtocol).

    NOTE(review): each dataReceived chunk is treated as exactly one Noise
    message.  TCP does not preserve message boundaries, so real traffic
    presumably needs a length-prefixed framing layer on both sides.'''

    def connectionMade(self):
        # Initialize Noise
        noise = NoiseConnection.from_name(NOISE_PROTOCOL_NAME)
        self.noise = noise
        noise.set_as_responder()
        noise.set_keypair_from_private_bytes(Keypair.STATIC,
                                             self.factory.server_key)

        # Protocol connected to factory.endpoint; set by proxyConnected.
        self.endpoint = None
        # Plaintext decrypted before the proxied connection exists.
        self._pending = []
        # Set once the Noise client has disconnected.
        self._lost = False

        # Start Handshake
        noise.start_handshake()

    def dataReceived(self, data):
        '''Drive the handshake, then decrypt client data and forward it.'''

        if not self.noise.handshake_finished:
            self.noise.read_message(data)
            if not self.noise.handshake_finished:
                self.transport.write(self.noise.write_message())
            if self.noise.handshake_finished:
                # Nowhere to put decrypted data until the proxied
                # connection exists, so stop reading from the client.
                self.transport.pauseProducing()

                # start the connection to the endpoint
                ep = endpoints.clientFromString(reactor,
                                                self.factory.endpoint)
                epdef = ep.connect(ClientProxyFactory(self))
                epdef.addCallbacks(self.proxyConnected, self.proxyFailed)
        else:
            r = self.noise.decrypt(data)
            if self.endpoint is None:
                # Pausing is advisory; keep anything that still arrives.
                self._pending.append(r)
            else:
                self.endpoint.transport.write(r)

    def proxyConnected(self, endpoint):
        '''Callback for the connect Deferred.

        endpoint is the ClientProxyProtocol connected to factory.endpoint.
        Any buffered plaintext is flushed to it, and reading from the Noise
        client resumes.'''

        if self._lost:
            # The client went away while we were connecting.
            endpoint.transport.loseConnection()
            return

        self.endpoint = endpoint
        pending, self._pending = self._pending, []
        for r in pending:
            endpoint.transport.write(r)
        self.transport.resumeProducing()

    def proxyFailed(self, reason):
        '''Errback for the connect Deferred: drop the Noise client.'''

        self.transport.loseConnection()
        # Pass the failure on so it is still reported instead of swallowed.
        return reason

    def connectionLost(self, reason):
        '''The Noise client disconnected; close the proxied connection.'''

        self._lost = True
        if self.endpoint is not None:
            self.endpoint.transport.loseConnection()


class ClientProxyProtocol(twisted.internet.protocol.Protocol):
    '''Connection from the Noise server to the proxied endpoint.

    Data received from the endpoint is encrypted and sent back to the Noise
    client.  Losing this connection also drops the Noise client.'''

    def dataReceived(self, data):
        noiseproto = self.factory.noiseproto
        # Split so that every encrypted message stays within Noise's limit.
        for start in range(0, len(data), MAX_NOISE_PLAINTEXT):
            chunk = data[start:start + MAX_NOISE_PLAINTEXT]
            noiseproto.transport.write(noiseproto.noise.encrypt(chunk))

    def connectionLost(self, reason):
        self.factory.noiseproto.transport.loseConnection()


class ClientProxyFactory(Factory):
    '''Builds ClientProxyProtocol instances bound to one Noise connection.

    noiseproto is the TwistedNoiseServerProtocol whose decrypted traffic
    is being proxied.'''

    protocol = ClientProxyProtocol

    def __init__(self, noiseproto):
        self.noiseproto = noiseproto


class TwistedNoiseServerFactory(Factory):
    '''Factory for TwistedNoiseServerProtocol.

    server_key is the server's raw private key, and endpoint is the
    clientFromString description of where decrypted traffic is sent.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, server_key, endpoint):
        self.server_key = server_key
        self.endpoint = endpoint


class TNServerTest(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        self.server_key_pair = genkeypair()

        self.protos = []

        class AccProtFactory(Factory):
            '''Records every protocol it builds in the given list.'''

            protocol = proto_helpers.AccumulatingProtocol
            # AccumulatingProtocol.connectionMade reads this attribute.
            protocolConnectionMade = None

            def __init__(self, protos):
                super().__init__()
                self.__protos = protos

            def buildProtocol(self, addr):
                r = Factory.buildProtocol(self, addr)
                self.__protos.append(r)
                return r

        # Port 0 lets the OS pick a free port; bind to loopback only.
        ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface='127.0.0.1')
        lpobj = yield ep.listen(AccProtFactory(self.protos))

        self.testserv = ep
        self.listenportobj = lpobj

        self.endpoint = 'tcp:host=127.0.0.1:port=%d' % lpobj.getHost().port

        factory = TwistedNoiseServerFactory(
            server_key=self.server_key_pair[1], endpoint=self.endpoint)
        self.proto = factory.buildProtocol(('127.0.0.1', 0))
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        self.client_key_pair = genkeypair()

    def tearDown(self):
        # Return the Deferred so trial waits for the port to close.
        return self.listenportobj.stopListening()

    @mock.patch('twisted.internet.endpoints.clientFromString')
    def test_testprotocol(self, cfs):
        # Create client
        proto = NoiseConnection.from_name(NOISE_PROTOCOL_NAME)
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
                                             self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                                            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Configure the mock without calling it, so that its recorded calls
        # are only the ones made by the protocol under test.
        clientconnection = defer.Deferred()
        connect = cfs.return_value.connect
        connect.return_value = clientconnection

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # assert handshake finished
        self.assertTrue(proto.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Make sure that clientFromString is called properly
        cfs.assert_called_once_with(reactor, self.endpoint)

        # And that it was connect'ed
        connect.assert_called_once()

        # and that ClientProxyFactory was called properly
        args = connect.call_args.args
        self.assertIsInstance(args[0], ClientProxyFactory)
        self.assertIs(args[0].noiseproto, self.proto)

        # Simulate that a connection has happened
        remoteend = proto_helpers.StringTransport()
        remoteproto = args[0].buildProtocol(None)
        remoteproto.makeConnection(remoteend)
        clientconnection.callback(remoteproto)

        # Reading from the client resumes once the proxy is connected
        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        self.assertEqual(remoteend.value(), ptmsg)

        # Data from the endpoint is encrypted back to the client
        reply = b'this is a reply'
        remoteproto.dataReceived(reply)
        self.assertEqual(proto.decrypt(self.tr.value()), reply)