|
- from twisted.trial import unittest
- from twisted.test import proto_helpers
- from noise.connection import NoiseConnection, Keypair
- from twisted.internet.protocol import Factory
- from twisted.internet import endpoints, reactor,defer
- # XXX - shouldn't need to access the underlying primitives, but that's what
- # noiseprotocol module requires.
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives import serialization
-
- import twisted.internet.protocol
- import mock
-
# Notes:
# Noise XK is used so that the connecting party's (initiator's) static
# identity stays hidden, while the server's (responder's) static public key
# must already be known to the connecting party.
-
def genkeypair():
    '''Generate a fresh X448 keypair for use with Noise.

    Returns a (public, private) tuple.  Both halves are raw, unencrypted
    byte strings, suitable for NoiseConnection's
    set_keypair_from_public_bytes / set_keypair_from_private_bytes.'''

    private_key = x448.X448PrivateKey.generate()
    raw = serialization.Encoding.Raw

    public_bytes = private_key.public_key().public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )
    private_bytes = private_key.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )

    return public_bytes, private_bytes
-
class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
    '''Noise Protocol responder that proxies decrypted traffic to an endpoint.

    The factory that creates this Protocol is required to have the
    attributes server_key and endpoint.

    The server_key attribute is the server's raw static private key.  Because
    the Noise XK pattern is used, clients must already know the matching
    public key in order to authenticate the server.

    The endpoint attribute is a client endpoint description string that is
    passed to clientFromString; see
    https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
    and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
    for its syntax.

    NOTE(review): each chunk delivered to dataReceived is handed to Noise as
    exactly one message, with no length framing.  TCP does not preserve
    write boundaries, so this only works if chunks line up with messages.
    Confirm, or add framing (the Noise spec suggests a 2-byte length prefix
    over stream transports).'''

    # Transport of the outbound (proxied) connection; None until
    # proxyConnected runs.
    endpoint = None

    def connectionMade(self):
        # Initialize Noise as the XK responder using the server's static key
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        noise.set_as_responder()
        noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)

        self.endpoint = None
        # Decrypted data received before the outbound connection exists.
        self._pending = []

        # Start Handshake
        noise.start_handshake()

    def dataReceived(self, data):
        if self.noise.handshake_finished:
            self._forward(self.noise.decrypt(data))
            return

        self.noise.read_message(data)
        if not self.noise.handshake_finished:
            # Our turn in the handshake: reply with the next message.
            self.transport.write(self.noise.write_message())
        # Checked again because writing a message may also finish the
        # handshake, depending on the pattern.
        if self.noise.handshake_finished:
            self._startProxy()

    def _startProxy(self):
        '''Pause the Noise client and open the outbound connection.'''
        # Stop reading from the client until the outbound connection exists,
        # otherwise decrypted data has nowhere to go.
        self.transport.pauseProducing()

        # start the connection to the endpoint
        ep = endpoints.clientFromString(reactor, self.factory.endpoint)
        epdef = ep.connect(ClientProxyFactory(self))
        epdef.addCallbacks(self.proxyConnected, self._proxyFailed)

    def _forward(self, plaintext):
        '''Write plaintext to the outbound connection, or queue it if that
        connection is not established yet.'''
        if self.endpoint is None:
            self._pending.append(plaintext)
        else:
            self.endpoint.write(plaintext)

    def proxyConnected(self, endpoint):
        '''Callback for the outbound connect() Deferred.

        endpoint is the protocol built by ClientProxyFactory, which is what a
        client endpoint's connect() Deferred fires with.  Its transport is
        where decrypted client data is written.'''
        self.endpoint = endpoint.transport

        pending, self._pending = self._pending, []
        if pending:
            self.endpoint.writeSequence(pending)

        self.transport.resumeProducing()

    def _proxyFailed(self, failure):
        '''Errback for the outbound connect(): log it and drop the client.

        Without this the client would stay paused forever and the failure
        would surface only as an unhandled-Deferred error.'''
        from twisted.python import log
        log.err(failure, 'connection to proxied endpoint failed')
        self.transport.loseConnection()

    def connectionLost(self, reason):
        '''Close the outbound connection when the Noise client goes away.'''
        if self.endpoint is not None:
            self.endpoint.loseConnection()
-
class ClientProxyProtocol(twisted.internet.protocol.Protocol):
    '''Protocol for the outbound connection to the proxied endpoint.

    Built by ClientProxyFactory, whose noiseproto attribute is the
    TwistedNoiseServerProtocol that requested this connection.  Data received
    from the proxied endpoint is encrypted with that protocol's Noise session
    and written back to the Noise client.'''

    def dataReceived(self, data):
        # Previously this class was empty and replies were silently dropped.
        # NOTE(review): no framing is added here, and Noise messages are
        # size-limited (65535 bytes including the tag), so very large chunks
        # may be rejected by encrypt - confirm.
        noiseproto = self.factory.noiseproto
        noiseproto.transport.write(noiseproto.noise.encrypt(data))

    def connectionLost(self, reason):
        # The proxied endpoint went away; drop the Noise client as well.
        self.factory.noiseproto.transport.loseConnection()
-
class ClientProxyFactory(Factory):
    '''Factory for the outbound connection opened after a Noise handshake.

    Builds ClientProxyProtocol instances and remembers which
    TwistedNoiseServerProtocol asked for the connection.'''

    protocol = ClientProxyProtocol

    def __init__(self, noiseproto):
        # The TwistedNoiseServerProtocol this outbound connection serves.
        self.noiseproto = noiseproto
-
class TwistedNoiseServerFactory(Factory):
    '''Factory for TwistedNoiseServerProtocol instances.

    server_key: the server's raw static private key (bytes).
    endpoint: client endpoint description string handed to clientFromString
    once a client completes the handshake.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, server_key, endpoint):
        self.server_key, self.endpoint = server_key, endpoint
-
class TNServerTest(unittest.TestCase):
    '''Tests for TwistedNoiseServerProtocol.

    setUp starts a throwaway TCP listener on the loopback interface, whose
    address becomes the proxy endpoint string.  It also connects a
    TwistedNoiseServerProtocol to an in-memory StringTransport.'''

    @defer.inlineCallbacks
    def setUp(self):
        self.server_key_pair = genkeypair()
        # Protocols accepted by the throwaway listener, in accept order.
        self.protos = []

        class AccProtFactory(Factory):
            protocol = proto_helpers.AccumulatingProtocol
            # AccumulatingProtocol.connectionMade reads this from its factory.
            protocolConnectionMade = None

            def __init__(self, accumulator):
                self.__accumulator = accumulator
                Factory.__init__(self)

            def buildProtocol(self, addr):
                # Unbound call: self must be passed explicitly.
                r = Factory.buildProtocol(self, addr)
                self.__accumulator.append(r)
                return r

        # Port 0 lets the OS pick a free port atomically, instead of probing
        # a port range (which races with other processes).
        ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface='127.0.0.1')
        lpobj = yield ep.listen(AccProtFactory(self.protos))
        port = lpobj.getHost().port

        self.testserv = ep
        self.listenportobj = lpobj
        self.endpoint = 'tcp:host=127.0.0.1:port=%d' % port
        factory = TwistedNoiseServerFactory(
            server_key=self.server_key_pair[1], endpoint=self.endpoint)
        self.proto = factory.buildProtocol(('127.0.0.1', 0))
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        self.client_key_pair = genkeypair()

    def tearDown(self):
        # stopListening may return a Deferred; returning it makes trial wait
        # for the port to close instead of leaving the reactor dirty.
        return self.listenportobj.stopListening()

    @mock.patch('twisted.internet.endpoints.clientFromString')
    def test_testprotocol(self, cfs):
        # Create the client (initiator) side of the Noise session
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # XK: the initiator has its own static key and must already know the
        # server's static public key.
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
        proto.start_handshake()

        # Send first message
        self.proto.dataReceived(proto.write_message())

        # Get the response and process it
        resp = self.tr.value()
        self.tr.clear()
        proto.read_message(resp)

        # Deferred returned by the (mocked) endpoint's connect(); it is fired
        # by hand below once the outbound connection is "made".  Use
        # return_value rather than calling cfs(), so the mock's call list
        # only contains calls made by the code under test.
        clientconnection = defer.Deferred()
        connect = cfs.return_value.connect
        connect.return_value = clientconnection

        # Send the final handshake message
        self.proto.dataReceived(proto.write_message())

        # Both sides consider the handshake finished
        self.assertTrue(proto.handshake_finished)
        self.assertTrue(self.proto.noise.handshake_finished)

        # Make sure incoming data is paused until we establish the client
        # connection, otherwise there is no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Make sure that clientFromString is called properly
        cfs.assert_called_once_with(reactor, self.endpoint)

        # And that it was connect'ed
        connect.assert_called_once()

        # and that ClientProxyFactory was called properly
        args = connect.call_args[0]
        self.assertIsInstance(args[0], ClientProxyFactory)
        self.assertIs(args[0].noiseproto, self.proto)

        # Simulate that a connection has happened: build and connect the
        # protocol, then fire connect()'s Deferred with it, as a real
        # endpoint would.
        remoteend = proto_helpers.StringTransport()
        remoteproto = args[0].buildProtocol(None)
        remoteproto.makeConnection(remoteend)
        clientconnection.callback(remoteproto)

        # Reading from the Noise client resumes once the proxy is connected
        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message and feed it into the protocol
        ptmsg = b'this is a test message'
        self.proto.dataReceived(proto.encrypt(ptmsg))

        self.assertEqual(remoteend.value(), ptmsg)
|