|
- from twisted.trial import unittest
- from twisted.test import proto_helpers
- from noise.connection import NoiseConnection, Keypair
- from twisted.internet.protocol import Factory
- from twisted.internet import endpoints, reactor, defer, task
- # XXX - shouldn't need to access the underlying primitives, but that's what
- # noiseprotocol module requires.
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives import serialization
-
- import mock
- import os.path
- import shutil
- import tempfile
- import twisted.internet.protocol
-
- # Notes:
- # Using the Noise XK pattern, so that the connecting party's identity is
- # hidden and the server's static key is known to clients in advance.
-
def genkeypair():
    '''Generate a fresh X448 keypair and return it as ``(public, private)``.

    Both halves are raw, unencrypted key bytes, which is the form passed to
    the ``noise`` package's ``set_keypair_from_public_bytes`` /
    ``set_keypair_from_private_bytes`` helpers elsewhere in this module.
    '''
    raw = serialization.Encoding.Raw
    private_key = x448.X448PrivateKey.generate()

    private_bytes = private_key.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_bytes = private_key.public_key().public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )

    return public_bytes, private_bytes
-
class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
    '''This class acts as a Noise Protocol responder. The factory that
    creates this Protocol is required to have the properties server_key
    and endpoint.

    The server_key property is the server's raw static private key. Clients
    are required to already know the matching public key (due to the Noise
    XK pattern used) in order to authenticate the server.

    The endpoint property contains the endpoint as a string that will be
    used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
    and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
    for information on how to use this property.

    Once the handshake completes, an outbound connection is made to that
    endpoint and every decrypted client message is written to it.

    NOTE(review): each dataReceived() call is treated as exactly one Noise
    message. There is no length framing, so this relies on the transport
    never splitting or coalescing messages. TODO: confirm, or add framing.
    '''

    # Protocol instance of the upstream (proxied) connection; set by
    # proxyConnected() once the outbound connect succeeds.
    endpoint = None

    # Set by connectionLost() so that a late proxyConnected() can tell the
    # Noise client has already gone away.
    _clientGone = False

    def connectionMade(self):
        # Initialize Noise as the XK responder with the server's static key.
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        noise.set_as_responder()
        noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)

        # Start Handshake
        noise.start_handshake()

    def dataReceived(self, data):
        if self.noise.handshake_finished:
            # Handshake already done: this is transport data. Reading stays
            # paused until proxyConnected() has set self.endpoint.
            r = self.noise.decrypt(data)
            self.endpoint.transport.write(r)
            return

        self.noise.read_message(data)
        if not self.noise.handshake_finished:
            # Still mid-handshake: it is our turn to send a message.
            self.transport.write(self.noise.write_message())
        # Re-checked (not elif) so that a handshake completed by either the
        # read or the write above starts the proxy.
        if self.noise.handshake_finished:
            self._startProxy()

    def _startProxy(self):
        '''Pause the client and open the upstream connection.'''
        # Stop reading from the client until the upstream connection exists;
        # otherwise decrypted data would have nowhere to go.
        self.transport.pauseProducing()

        # start the connection to the endpoint
        ep = endpoints.clientFromString(reactor, self.factory.endpoint)
        epdef = ep.connect(ClientProxyFactory(self))
        epdef.addCallback(self.proxyConnected)
        # Without an errback, a failed connect would leave the client paused
        # forever and only show up as an unhandled error in a Deferred.
        epdef.addErrback(self._proxyFailed)

    def proxyConnected(self, endpoint):
        '''Record the upstream protocol and resume reading from the client.

        :param endpoint: the protocol instance that ep.connect()'s Deferred
            fired with (a ClientProxyProtocol), not an endpoint object.
        '''
        self.endpoint = endpoint
        if self._clientGone:
            # The Noise client disconnected while we were still connecting;
            # don't leave the upstream leg open.
            endpoint.transport.loseConnection()
            return
        self.transport.resumeProducing()

    def _proxyFailed(self, failure):
        '''Log an upstream connect failure and drop the Noise client.'''
        from twisted.python import log
        log.err(failure, 'TwistedNoiseServerProtocol: upstream connection failed')
        self.transport.loseConnection()

    def connectionLost(self, reason):
        # Tear down the upstream leg too so it is not leaked.
        self._clientGone = True
        upstream = self.endpoint
        if upstream is not None and upstream.transport is not None:
            upstream.transport.loseConnection()
-
class ClientProxyProtocol(twisted.internet.protocol.Protocol):
    '''Protocol for the outbound (upstream) leg of the proxy.

    Instances are built by ClientProxyFactory, so ``self.factory.noiseproto``
    is the TwistedNoiseServerProtocol this connection was opened for.

    NOTE(review): no Protocol handlers are overridden yet, so data arriving
    from the upstream side is not relayed back to the Noise client.
    '''
-
class ClientProxyFactory(Factory):
    '''Builds the protocol for the upstream leg of a proxied connection.

    ``noiseproto`` is the TwistedNoiseServerProtocol that requested this
    connection. It is stored so that protocols built here can reach it
    through ``self.factory.noiseproto``.
    '''

    def __init__(self, noiseproto):
        # Remember which Noise session this upstream connection serves.
        self.noiseproto = noiseproto

    protocol = ClientProxyProtocol
-
class TwistedNoiseServerFactory(Factory):
    '''Factory producing TwistedNoiseServerProtocol instances.

    :param server_key: the server's raw static private key bytes, handed to
        the Noise responder of every connection.
    :param endpoint: a clientFromString() description of where decrypted
        client traffic is proxied to.
    '''

    def __init__(self, server_key, endpoint):
        self.server_key, self.endpoint = server_key, endpoint

    protocol = TwistedNoiseServerProtocol
-
class TNServerTest(unittest.TestCase):
    '''End-to-end test of TwistedNoiseServerProtocol.

    setUp() starts a UNIX-socket server to act as the proxy target; it
    records every protocol it builds in self.protos. It also connects a
    TwistedNoiseServerProtocol to an in-memory StringTransport so that the
    test can drive the Noise handshake by hand.
    '''

    @defer.inlineCallbacks
    def setUp(self):
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        self.server_key_pair = genkeypair()
        self.protos = []
        # Handed to AccumulatingProtocol via the factory's
        # protocolConnectionMade attribute; fired when the proxy's upstream
        # connection is accepted by the test server.
        self.connectionmade = defer.Deferred()

        class AccProtFactory(Factory):
            protocol = proto_helpers.AccumulatingProtocol

            def __init__(self, tc):
                self.__tc = tc
                Factory.__init__(self)

            protocolConnectionMade = self.connectionmade

            def buildProtocol(self, addr):
                r = Factory.buildProtocol(self, addr)
                self.__tc.protos.append(r)
                return r

        sockpath = os.path.join(self.tempdir, 'clientsock')
        ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
        lpobj = yield ep.listen(AccProtFactory(self))

        self.testserv = ep
        self.listenportobj = lpobj
        self.endpoint = 'unix:path=%s' % sockpath

        factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
        self.proto = factory.buildProtocol(None)
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        self.client_key_pair = genkeypair()

    @defer.inlineCallbacks
    def tearDown(self):
        try:
            # stopListening() may return a Deferred that fires only once the
            # port has really closed (the UNIX port cleans up its socket file
            # then). Wait for it before deleting the directory holding the
            # socket; otherwise that later cleanup finds the file gone.
            yield self.listenportobj.stopListening()
        finally:
            shutil.rmtree(self.basetempdir)
            self.tempdir = None

    @defer.inlineCallbacks
    def test_testprotocol(self):
        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys: our static key, plus the server's public key,
        # which XK requires the initiator to know up front.
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        self.proto.dataReceived(proto.write_message())

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Send final handshake message
        self.proto.dataReceived(proto.write_message())

        # assert handshake finished on both ends
        self.assertTrue(proto.handshake_finished)
        self.assertTrue(self.proto.noise.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Wait for the upstream connection to reach the test server
        yield self.connectionmade

        # The proxy's own connect() callback may run on a later reactor
        # iteration than the server-side accept, so give it a moment.
        # TODO: replace this sleep with an explicit readiness signal.
        yield task.deferLater(reactor, .1, lambda: None)

        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        # Give the write time to cross the UNIX socket.
        yield task.deferLater(reactor, .1, lambda: None)

        clientend = self.protos[0]
        self.assertEqual(clientend.data, ptmsg)

        clientend.transport.loseConnection()
|