from noise.connection import NoiseConnection, Keypair
from twisted.trial import unittest
from twisted.test import proto_helpers
from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor, defer, task

# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives import serialization

import mock
import os.path
import shutil
import tempfile
import twisted.internet.protocol

__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2019 John-Mark Gurney.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.

# Notes:
# Using XK, so that the connecting party's identity is hidden and the
# server's static key is known to the connecting party in advance.
#
# The client and server class names are used to refer to the initiator and
# responder sides, even though both client and server each have a server
# component (listen on a socket to start comms) and a client component
# (create a connection).
#
# Noise packets are 16 bytes + length of data
#
# Proposed method to hide message lengths:
# Immediately after handshake completes, each side generates and sends
# an n byte key that will be used for encrypting (algo tbd) their own
# byte counts.  The length field will be encrypted via
# E(pktnum, key) XOR 2 byte length.
#
# Note that authenticating the message length is NOT needed.  This is
# because the noise message blocks themselves are authenticated.  The
# worst that could happen is that a larger read (64k) is done, and then
# the connection aborts because of decryption failure.
#


def _makeunix(path):
    '''Make a properly formed unix path socket string.'''

    return 'unix:%s' % path


def genkeypair():
    '''Generates a keypair, and returns a tuple of (public, private).
    They are encoded as raw bytes, and suitable for use w/ Noise.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

    return pub, priv


class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
    '''Base protocol that runs a Noise XK session
    (Noise_XK_448_ChaChaPoly_SHA256) over its transport.

    Subclasses must set the class attribute mode to either 'resp'
    (responder) or 'init' (initiator), and must implement
    handshakeFinished.

    The factory that creates this Protocol is required to have the
    property priv_key, the raw static private key of this side.  It may
    also have the property pub_key, the raw static public key of the
    remote side.  Because Noise XK is used, the initiator must know the
    responder's public key in advance, so initiator factories need
    pub_key.

    Once the handshake completes and plaintextConnected has been called,
    decrypted data is written to the connected plain text endpoint's
    transport.'''

    def connectionMade(self):
        # Initialize Noise
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        if self.mode == 'resp':
            noise.set_as_responder()
        elif self.mode == 'init':
            noise.set_as_initiator()

        noise.set_keypair_from_private_bytes(Keypair.STATIC,
            self.factory.priv_key)
        if hasattr(self.factory, 'pub_key'):
            noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                self.factory.pub_key)

        # Start Handshake
        noise.start_handshake()

        # In Noise the initiator always sends the first handshake
        # message.  Without this, an initiator would wait forever for a
        # message the responder will never send.  The rest of the
        # initiator's handshake (read msg 2, write msg 3) is driven by
        # dataReceived.
        if self.mode == 'init':
            self.transport.write(noise.write_message())

    def encData(self, data):
        '''Receive plain text data, encrypt it, and send it down the wire.'''

        self.transport.write(self.noise.encrypt(data))

    def dataReceived(self, data):
        '''Receive encrypted data, and write it to the endpoint that was
        connected via the plaintextConnected method.

        While the handshake is in progress, the data is fed to the Noise
        handshake instead, and any reply is written back.

        NOTE(review): each call is assumed to carry exactly one complete
        Noise message.  Stream transports do not preserve message
        boundaries, so length-prefixed framing will be needed before this
        is robust over real sockets.'''

        if not self.noise.handshake_finished:
            self.noise.read_message(data)
            if not self.noise.handshake_finished:
                self.transport.write(self.noise.write_message())
            # Checked again (not elif): the initiator finishes right after
            # writing its final handshake message above.
            if self.noise.handshake_finished:
                self.handshakeFinished()
        else:
            r = self.noise.decrypt(data)

            self.endpoint.transport.write(r)

    def handshakeFinished(self): # pragma: no cover
        '''This function is called when the handshake has been completed.

        This is used to start data flowing, and to do any necessary
        connection work.'''

        raise NotImplementedError

    def plaintextConnected(self, endpoint):
        '''Connect the plain text endpoint to this protocol, and resume
        reading encrypted data.

        All the decrypted data will be written to this protocol
        (specifically, its transport).'''

        self.endpoint = endpoint
        self.transport.resumeProducing()


class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
    '''Responder side.  After the handshake, connects to the factory's
    endpoint and proxies decrypted data to it.'''

    mode = 'resp'

    def handshakeFinished(self):
        # Stop reading encrypted data until there is somewhere to write
        # the decrypted result; plaintextConnected resumes it.
        self.transport.pauseProducing()

        # start the connection to the endpoint
        ep = endpoints.clientFromString(reactor, self.factory.endpoint)
        epdef = ep.connect(ServerPTProxyFactory(self))
        epdef.addCallback(self.plaintextConnected)
        epdef.addErrback(self._plaintextFailed)

    def _plaintextFailed(self, failure):
        '''Drop the paused Noise connection when the plain text side
        cannot be established.  The failure is returned so it is still
        reported.'''

        self.transport.loseConnection()

        return failure


class TwistedNoiseServerFactory(Factory):
    '''Factory for TwistedNoiseServerProtocol.

    priv_key is the server's raw static private key.  endpoint is a
    client endpoint string, passed to
    twisted.internet.endpoints.clientFromString, naming where decrypted
    traffic is forwarded.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, priv_key, endpoint):
        self.priv_key = priv_key
        self.endpoint = endpoint


# Supporting classes for TwistedNoiseServer
class PTProxyProtocol(twisted.internet.protocol.Protocol):
    '''Simple protocol that, when data is received, encrypts the data w/
    the connected noise protocol and sends it on.'''

    # NOTE(review): connection loss on this plain text side is not
    # propagated to the Noise connection (and vice versa).

    def dataReceived(self, data):
        self.factory.noiseproto.encData(data)


class ServerPTProxyFactory(Factory):
    '''Factory that binds PTProxyProtocol instances to one Noise
    protocol.'''

    protocol = PTProxyProtocol

    def __init__(self, noiseproto):
        self.noiseproto = noiseproto


class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
    '''Initiator side.  handshakeFinished is not implemented yet, so it
    raises NotImplementedError from the base class.'''

    mode = 'init'


class ClientPTFactory(Factory):
    '''Factory for TwistedNoiseClientProtocol.

    priv_key is the client's raw static private key; servpub is the
    server's raw static public key (required up front by Noise XK).'''

    protocol = TwistedNoiseClientProtocol

    def __init__(self, priv_key, servpub, sockstr):
        self.priv_key = priv_key
        self.pub_key = servpub
        # NOTE(review): sockstr is stored but not used by any code in
        # this file yet; presumably the server's endpoint string.
        self.sockstr = sockstr


class TNServerTest(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        # Server's PT client will be here
        self.protos = []
        self.connectionmade = defer.Deferred()

        class AccProtFactory(Factory):
            '''Builds AccumulatingProtocols and records each one in the
            test case's protos list.'''

            protocol = proto_helpers.AccumulatingProtocol

            def __init__(self, tc):
                self.__tc = tc
                Factory.__init__(self)

            # Class-level attribute: self here is the enclosing test case.
            # AccumulatingProtocol fires this deferred on its first
            # connection.
            protocolConnectionMade = self.connectionmade

            def buildProtocol(self, addr):
                r = Factory.buildProtocol(self, addr)
                self.__tc.protos.append(r)

                return r

        self.AccProtFactory = AccProtFactory

        # Setup PT client endpoint
        sockpath = os.path.join(self.tempdir, 'servptsock')
        ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
        lpobj = yield ep.listen(AccProtFactory(self))

        self.testserv = ep
        self.listenportobj = lpobj
        self.endpoint = _makeunix(sockpath)

        # Setup server, and configure where to connect to.
        self.servfactory = TwistedNoiseServerFactory(
            priv_key=self.server_key_pair[1], endpoint=self.endpoint)

    @defer.inlineCallbacks
    def tearDown(self):
        yield self.listenportobj.stopListening()

        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @defer.inlineCallbacks
    def test_testserver(self):
        #
        # How this test is plumbed:
        #
        # proto (NoiseConnection) -> self.tr (StringTransport) ->
        #    self.proto (TwistedNoiseServerProtocol) ->
        #    self.proto.endpoint (PTProxyProtocol) -> unix sock ->
        #    self.protos[0] (AccumulatingProtocol)
        #

        # Generate a server protocol, and bind it to a string
        # transport for testing
        self.proto = self.servfactory.buildProtocol(None)
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # assert handshake finished
        self.assertTrue(proto.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Wait for the connection to be made
        yield self.connectionmade
        yield task.deferLater(reactor, .1, bool, 1)

        # How to make this ready?
        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        # XXX - fix
        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # fetch remote end out
        clientend = self.protos[0]
        self.assertEqual(clientend.data, ptmsg)

        # send a message the other direction
        rptmsg = b'this is a different test message going the other way'
        clientend.transport.write(rptmsg)

        # XXX - fix
        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # receive it and decrypt it
        resp = self.tr.value()
        self.assertEqual(proto.decrypt(resp), rptmsg)

        # Disabled throughput benchmark; change to True to measure
        # encryption speed.
        if False:
            import time
            s = time.time()
            cnt = 40000
            blksz = 1024
            rnd = os.urandom(blksz)
            for i in range(0, cnt):
                proto.encrypt(rnd)
            e = time.time()
            print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024))

        # clean up connection
        clientend.transport.loseConnection()

    @defer.inlineCallbacks
    def test_clientserver(self):
        # NOTE(review): work in progress -- this test makes no assertions
        # yet, and the client side is not wired to the server (see
        # ClientPTFactory.sockstr).

        # Path that the client "listener" will sit on.
        cptsockpath = os.path.join(self.tempdir, 'clientptsock')

        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')

        # Start up the server
        servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
        servlpobj = yield servep.listen(self.servfactory)

        # Start up the client half
        clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath))
        clientlpobj = yield clientep.listen(ClientPTFactory(
            self.client_key_pair[1], self.server_key_pair[0],
            _makeunix(servsockpath)))

        # Connect to the client
        clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath))
        yield clptep.connect(self.AccProtFactory(self))

        # The client plain text connection
        clptproto = self.protos[-1]

        # Transports only accept bytes.
        clptproto.transport.write(b'this is a test')

        # Clean up
        clptproto.transport.loseConnection()
        yield servlpobj.stopListening()
        yield clientlpobj.stopListening()