import os.path
import shutil
import tempfile

import mock
import twisted.internet.protocol
from noise.connection import NoiseConnection, Keypair
from twisted.internet import endpoints, reactor, defer, task
from twisted.internet.protocol import Factory
from twisted.test import proto_helpers
from twisted.trial import unittest

# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x448

__author__ = 'John-Mark Gurney'
__copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
__license__ = '2-clause BSD license'

# Copyright 2019 John-Mark Gurney.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.
# IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.

# Notes:
# Using XK, so that the connecting party's identity is hidden and that the
# server party's key is known in advance.
#
# Noise packets are 16 bytes + length of data
#
# Proposed method to hide message lengths:
# Immediately after handshake completes, each side generates and sends
# an n byte key that will be used for encrypting (algo tbd) their own
# byte counts.  The length field will be encrypted via
# E(pktnum, key) XOR 2 byte length.
#
# Note that authenticating the message length is NOT needed.  This is
# because the noise message blocks themselves are authenticated.  The
# worst that could happen is that a larger read (64k) is done, and then
# the connection aborts because of decryption failure.
#

def genkeypair():
    '''Generates a keypair, and returns a tuple of (public, private).
    They are encoded as raw bytes, and suitable for use w/ Noise.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

    return pub, priv

class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
    '''This class acts as a Noise Protocol responder.  The factory that
    creates this Protocol is required to have the properties server_key
    and endpoint.

    The server_key property is the server's raw private key.  Clients
    are required to already know the matching public key (due to the
    Noise XK protocol used) in order to authenticate the server.

    The endpoint property contains the endpoint as a string that will be
    used w/ clientFromString, see
    https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
    and
    https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
    for information on how to use this property.

    Subclasses must implement handshakeFinished().'''

    # Protocol instance for the plaintext side of the proxy; set by
    # plaintextConnected() once that connection is up.
    endpoint = None

    # Set once the Noise-side connection has been lost, so that a
    # plaintext connection that completes afterwards gets discarded.
    _noiseLost = False

    def connectionMade(self):
        # Initialize Noise
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        noise.set_as_responder()
        noise.set_keypair_from_private_bytes(Keypair.STATIC,
            self.factory.server_key)

        # Start Handshake
        noise.start_handshake()

    def connectionLost(self, reason):
        '''The Noise side went away; close the plaintext side too (if it
        has been established) so it does not leak.'''

        self._noiseLost = True
        if self.endpoint is not None:
            self.endpoint.transport.loseConnection()

    def encData(self, data):
        '''Encrypt data and write it to the Noise-side transport.'''

        # NOTE(review): the Noise spec caps a single message at 65535
        # bytes; confirm whether large writes need to be chunked here.
        self.transport.write(self.noise.encrypt(data))

    def dataReceived(self, data):
        '''Drive the handshake; once it is finished, decrypt incoming
        data and forward it to the plaintext side.'''

        # NOTE(review): each call is assumed to carry exactly one whole
        # Noise message.  Stream transports may split or coalesce writes,
        # so a framing layer is still needed (see Notes above).
        if not self.noise.handshake_finished:
            self.noise.read_message(data)
            if not self.noise.handshake_finished:
                # Handshake still in progress, send our next message.
                self.transport.write(self.noise.write_message())
            if self.noise.handshake_finished:
                self.handshakeFinished()
        else:
            r = self.noise.decrypt(data)
            self.endpoint.transport.write(r)

    def handshakeFinished(self):
        '''Called once the Noise handshake has completed.'''

        raise NotImplementedError

    def plaintextConnected(self, endpoint):
        '''Callback for when the plaintext connection is established.
        endpoint is the protocol instance of that connection.'''

        self.endpoint = endpoint
        if self._noiseLost:
            # The Noise side disconnected while we were connecting;
            # there is nobody to proxy for, so drop the new connection.
            endpoint.transport.loseConnection()
            return

        self.transport.resumeProducing()

    def _plaintextConnectFailed(self, failure):
        '''Errback for when the plaintext connection could not be made.'''

        from twisted.python import log

        # Without a plaintext side there is nowhere to forward data, so
        # record why and close the Noise connection instead of leaving it
        # open (and possibly paused) forever.
        log.err(failure, 'connection to plaintext endpoint failed')
        self.transport.loseConnection()

class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
    '''Noise responder that, once the handshake completes, connects to
    the factory's endpoint and proxies traffic between the two.'''

    def handshakeFinished(self):
        # Stop reading Noise data until the plaintext connection is up,
        # otherwise decrypted data would have nowhere to go.
        self.transport.pauseProducing()

        # start the connection to the endpoint
        ep = endpoints.clientFromString(reactor, self.factory.endpoint)
        epdef = ep.connect(ClientProxyFactory(self))
        # addCallbacks (not addCallback + addErrback) so an exception in
        # plaintextConnected is not mistaken for a connection failure.
        epdef.addCallbacks(self.plaintextConnected,
            self._plaintextConnectFailed)

class ClientProxyProtocol(twisted.internet.protocol.Protocol):
    '''Plaintext side of the proxy: everything received is encrypted and
    sent back over the factory's Noise protocol.'''

    def dataReceived(self, data):
        self.factory.noiseproto.encData(data)

    def connectionLost(self, reason):
        # The plaintext service hung up; close the Noise side as well.
        self.factory.noiseproto.transport.loseConnection()

class ClientProxyFactory(Factory):
    '''Factory for the plaintext connection, tied to one Noise protocol
    instance (noiseproto).'''

    protocol = ClientProxyProtocol

    def __init__(self, noiseproto):
        self.noiseproto = noiseproto

class TwistedNoiseServerFactory(Factory):
    '''Factory for TwistedNoiseServerProtocol.  server_key is the
    server's raw private key, endpoint is a clientFromString endpoint
    description of the plaintext service to proxy to.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, server_key, endpoint):
        self.server_key = server_key
        self.endpoint = endpoint

class TNServerTest(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        self.server_key_pair = genkeypair()

        self.protos = []
        self.connectionmade = defer.Deferred()

        class AccProtFactory(Factory):
            protocol = proto_helpers.AccumulatingProtocol

            def __init__(self, tc):
                self.__tc = tc
                Factory.__init__(self)

            # Note: `self` here is the enclosing test case.  The test
            # waits on this Deferred for the plaintext connection.
            protocolConnectionMade = self.connectionmade

            def buildProtocol(self, addr):
                r = Factory.buildProtocol(self, addr)
                self.__tc.protos.append(r)
                return r

        sockpath = os.path.join(self.tempdir, 'clientsock')
        ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
        lpobj = yield ep.listen(AccProtFactory(self))
        self.testserv = ep
        self.listenportobj = lpobj
        self.endpoint = 'unix:path=%s' % sockpath

        factory = TwistedNoiseServerFactory(
            server_key=self.server_key_pair[1], endpoint=self.endpoint)
        self.proto = factory.buildProtocol(None)
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        self.client_key_pair = genkeypair()

    @defer.inlineCallbacks
    def tearDown(self):
        # Wait for the listening port to actually close before removing
        # the directory holding its socket.
        yield self.listenportobj.stopListening()
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @defer.inlineCallbacks
    def test_testserver(self):
        #
        # How this test is plumbed:
        #
        # proto (NoiseConnection) -> self.tr (StringTransport) ->
        #     self.proto (TwistedNoiseServerProtocol) ->
        #     self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
        #     self.protos[0] (AccumulatingProtocol)
        #

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # assert handshake finished
        self.assertTrue(proto.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Wait for the connection to be made
        yield self.connectionmade
        yield task.deferLater(reactor, .1, bool, 1)

        # How to make this ready?  (The short wait above gives
        # plaintextConnected a chance to run; there is no explicit signal.)
        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # fetch remote end out
        clientend = self.protos[0]
        self.assertEqual(clientend.data, ptmsg)

        # send a message the other direction
        rptmsg = b'this is a different test message going the other way'
        clientend.transport.write(rptmsg)

        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # receive it and decrypt it
        resp = self.tr.value()
        self.assertEqual(proto.decrypt(resp), rptmsg)

        # clean up connection
        clientend.transport.loseConnection()