|
- from noise.connection import NoiseConnection, Keypair
- from twisted.trial import unittest
- from twisted.test import proto_helpers
- from twisted.internet.protocol import Factory
- from twisted.internet import endpoints, reactor, defer, task
- # XXX - shouldn't need to access the underlying primitives, but that's what
- # noiseprotocol module requires.
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives import serialization
-
- import mock
- import os.path
- import shutil
- import tempfile
- import twisted.internet.protocol
-
# Module metadata (mirrors the license header below).
__license__ = '2-clause BSD license'
__copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
__author__ = 'John-Mark Gurney'
-
- # Copyright 2019 John-Mark Gurney.
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions
- # are met:
- # 1. Redistributions of source code must retain the above copyright
- # notice, this list of conditions and the following disclaimer.
- # 2. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- #
- # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
- # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
- # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
- # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
- # SUCH DAMAGE.
-
# Notes:
# Noise XK is used, so that the connecting (initiating) party's identity
# is hidden and the server (responder) party's static key is known to
# the client in advance.
#
# The client and server class names refer to the initiator and responder
# sides, even though the client and the server each have both a server
# component (listen on a socket to start comms) and a client component
# (create a connection).
#
# A Noise transport packet is the length of the data plus 16 bytes (the
# AEAD authentication tag).
#
# Proposed method to hide message lengths:
# Immediately after the handshake completes, each side generates and
# sends an n-byte key that will be used for encrypting (algo tbd) its own
# byte counts. The length field will be encrypted as
# E(pktnum, key) XOR 2-byte length.
#
# Note that authenticating the message length is NOT needed, because the
# Noise message blocks themselves are authenticated. The worst that could
# happen is that a larger read (64k) is done, and then the connection
# aborts because of a decryption failure.
#
-
def _makeunix(path):
    '''Make a properly formed unix path socket string.

    The result is meant for endpoints.clientFromString and
    endpoints.serverFromString.  Those parsers split arguments on
    colons, split keyword arguments on equals signs, and treat a
    backslash as an escape for the following character.  A socket path
    containing any of those characters would otherwise be mangled, so
    each of them is backslash escaped (backslashes first, so the
    escapes added afterwards are not doubled).'''

    quoted = path.replace('\\', '\\\\')
    for special in (':', '='):
        quoted = quoted.replace(special, '\\' + special)

    return 'unix:%s' % quoted
-
def genkeypair():
    '''Generate a fresh X448 keypair suitable for use w/ Noise.

    Returns a (public, private) tuple.  Both halves are raw byte
    encodings (no PEM/DER wrapping, private half unencrypted), which is
    the form handed to NoiseConnection's set_keypair_from_*_bytes
    methods in this file.'''

    raw = serialization.Encoding.Raw
    secret = x448.X448PrivateKey.generate()

    public_bytes = secret.public_key().public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )
    private_bytes = secret.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )

    return public_bytes, private_bytes
-
class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
    '''Base protocol that speaks Noise_XK_448_ChaChaPoly_SHA256 on its
    transport and relays decrypted data to a plain text endpoint.

    Subclasses must set the class attribute mode to 'resp' (Noise
    responder) or 'init' (Noise initiator) and must implement
    handshakeFinished.

    The factory that creates this Protocol is required to have the
    property priv_key: the raw bytes of this side's static private key.
    It may also have the property pub_key: the raw bytes of the remote
    side's static public key.  Because the Noise XK pattern is used,
    clients (initiators) are required to know the server's key in
    advance in order to authenticate the server, so pub_key must be
    present in 'init' mode.

    Server factories additionally carry an endpoint property, which
    contains the endpoint as a string that will be used w/
    clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
    and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
    for information on how to use this property.'''

    # Plain text protocol that decrypted data is written to.  It is set
    # by plaintextConnected and stays None until then.
    endpoint = None

    def connectionMade(self):
        '''Configure the Noise state for this connection and start the
        handshake.  In every Noise pattern the initiator sends the
        first handshake message, so in 'init' mode it is sent here.'''

        # Initialize Noise
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        if self.mode == 'resp':
            noise.set_as_responder()
        elif self.mode == 'init':
            noise.set_as_initiator()

        noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)
        if hasattr(self.factory, 'pub_key'):
            noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key)

        # Start Handshake
        noise.start_handshake()

        if self.mode == 'init':
            # Without this the responder waits forever for message one.
            self.transport.write(noise.write_message())

    def encData(self, data):
        '''Receive plain text data, encrypt it, and send it down the
        wire.'''

        self.transport.write(self.noise.encrypt(data))

    def dataReceived(self, data):
        '''Receive encrypted data.

        While the handshake is in progress, the data is processed as a
        handshake message, and a reply is sent if the pattern requires
        one.  Once the handshake has completed, the data is decrypted
        and written to the endpoint that was connected via the
        plaintextConnected method.

        NOTE(review): each call is assumed to carry exactly one complete
        Noise message.  Stream transports do not preserve message
        boundaries, so some length framing is needed before this is
        reliable on a real socket.'''

        if self.noise.handshake_finished:
            self.endpoint.transport.write(self.noise.decrypt(data))
            return

        self.noise.read_message(data)
        if not self.noise.handshake_finished:
            self.transport.write(self.noise.write_message())

        # Re-checked because read_message (responder) or write_message
        # (initiator) may have completed the handshake above.
        if self.noise.handshake_finished:
            self.handshakeFinished()

    def connectionLost(self, reason):
        '''Close the plain text side when the encrypted connection goes
        away, so the proxied connection is not left open.'''

        if self.endpoint is not None:
            self.endpoint.transport.loseConnection()

    def handshakeFinished(self): # pragma: no cover
        '''This function is called when the handshake has been
        completed. This is used to start data flowing, and to
        do any necessary connection work.'''

        raise NotImplementedError

    def plaintextConnected(self, endpoint):
        '''Connect the plain text endpoint to this protocol.  All the
        decrypted data will be written to that endpoint
        (specifically, its transport), and reading from the
        encrypted transport is resumed.'''

        self.endpoint = endpoint
        self.transport.resumeProducing()
-
class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
    '''Noise responder.  Once the handshake completes, it opens a plain
    text connection to the factory's endpoint and relays traffic between
    the two connections.'''

    mode = 'resp'

    def handshakeFinished(self):
        '''Pause the encrypted side and connect to the plain text
        endpoint.  Reading resumes in plaintextConnected.  If the
        connection attempt fails, the encrypted side is closed.'''

        # There is nowhere to deliver decrypted data until the plain
        # text connection exists, so stop reading from the transport.
        self.transport.pauseProducing()

        # start the connection to the endpoint
        ep = endpoints.clientFromString(reactor, self.factory.endpoint)
        epdef = ep.connect(ServerPTProxyFactory(self))
        epdef.addCallbacks(self.plaintextConnected, self._plaintextFailed)

    def _plaintextFailed(self, failure):
        '''Close the encrypted connection when the plain text endpoint
        cannot be reached; otherwise it would stay paused forever.  The
        failure is passed on so it is still reported.'''

        self.transport.loseConnection()
        return failure
-
class TwistedNoiseServerFactory(Factory):
    '''Factory for the Noise responder (server) side.

    priv_key is the server's raw static private key.  endpoint is the
    clientFromString description of where decrypted traffic is sent.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, priv_key, endpoint):
        self.priv_key, self.endpoint = priv_key, endpoint
-
- # Supporting classes for TwistedNoiseServer
class PTProxyProtocol(twisted.internet.protocol.Protocol):
    '''Simple protocol that, when data is received, encrypts the data
    w/ the connected noise protocol (factory.noiseproto) and sends it
    on the encrypted connection.'''

    def dataReceived(self, data):
        self.factory.noiseproto.encData(data)

    def connectionLost(self, reason):
        '''Close the encrypted connection when the plain text side goes
        away, so the proxied connection is not left half open.'''

        self.factory.noiseproto.transport.loseConnection()
-
class ServerPTProxyFactory(Factory):
    '''Builds the PTProxyProtocol for one server-side plain text
    connection, bound to the noise protocol that it relays for.'''

    protocol = PTProxyProtocol

    def __init__(self, noiseproto):
        # The noise protocol whose encData receives the plain text read
        # by PTProxyProtocol.
        setattr(self, 'noiseproto', noiseproto)
-
class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
    '''Noise initiator (client) side of the proxy.'''

    # Selects set_as_initiator() in TwistedNoiseProtocol.connectionMade.
    mode = 'init'
-
class ClientPTFactory(Factory):
    '''Factory for the Noise initiator (client) side.

    priv_key is the client's raw static private key.  servpub is the
    server's raw static public key, which the XK pattern requires the
    client to know.  sockstr is an endpoint string; in this file it is
    the server's listening socket.'''

    protocol = TwistedNoiseClientProtocol

    def __init__(self, priv_key, servpub, sockstr):
        # priv_key/pub_key are the names TwistedNoiseProtocol reads.
        self.priv_key, self.pub_key = priv_key, servpub
        # NOTE(review): not read anywhere in this file yet.
        self.sockstr = sockstr
-
class TNServerTest(unittest.TestCase):
    '''Tests for the Noise proxy protocols.

    setUp creates a scratch directory, fresh server and client X448
    keypairs, and a UNIX socket listener standing in for the server's
    plain text destination.  Every protocol built by that listener's
    factory (AccProtFactory) is appended to self.protos.'''

    @defer.inlineCallbacks
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        # Server's PT client will be here
        self.protos = []
        self.connectionmade = defer.Deferred()

        class AccProtFactory(Factory):
            '''Builds AccumulatingProtocols and records each one in the
            test case's protos list.'''

            protocol = proto_helpers.AccumulatingProtocol

            def __init__(self, tc):
                self.__tc = tc
                Factory.__init__(self)

            # The tests wait on this Deferred for the plain text
            # connection; AccumulatingProtocol is expected to fire it
            # from connectionMade.
            protocolConnectionMade = self.connectionmade

            def buildProtocol(self, addr):
                r = Factory.buildProtocol(self, addr)
                self.__tc.protos.append(r)
                return r

        self.AccProtFactory = AccProtFactory

        # Setup PT client endpoint
        sockpath = os.path.join(self.tempdir, 'servptsock')
        ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
        lpobj = yield ep.listen(AccProtFactory(self))

        self.testserv = ep
        self.listenportobj = lpobj
        self.endpoint = _makeunix(sockpath)

        # Setup server, and configure where to connect to.
        self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)

    @defer.inlineCallbacks
    def tearDown(self):
        # Stop listening before removing the directory holding the socket.
        yield self.listenportobj.stopListening()

        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @defer.inlineCallbacks
    def test_testserver(self):
        #
        # How this test is plumbed:
        #
        # proto (NoiseConnection) -> self.tr (StringTransport) ->
        # self.proto (TwistedNoiseServerProtocol) ->
        # self.proto.endpoint (PTProxyProtocol) -> unix sock ->
        # self.protos[0] (AccumulatingProtocol)
        #

        # Generate a server protocol, and bind it to a string
        # transport for testing
        self.proto = self.servfactory.buildProtocol(None)
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # assert handshake finished
        self.assertTrue(proto.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Wait for the connection to be made
        yield self.connectionmade

        yield task.deferLater(reactor, .1, bool, 1)

        # How to make this ready?
        self.assertEqual(self.tr.producerState, 'producing')

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        # XXX - fix
        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # fetch remote end out
        clientend = self.protos[0]
        self.assertEqual(clientend.data, ptmsg)

        # send a message the other direction
        rptmsg = b'this is a different test message going the other way'
        clientend.transport.write(rptmsg)

        # XXX - fix
        # wait to pass it through
        yield task.deferLater(reactor, .1, bool, 1)

        # receive it and decrypt it
        resp = self.tr.value()
        self.assertEqual(proto.decrypt(resp), rptmsg)

        # Disabled throughput benchmark of proto.encrypt; change to
        # True to print the result.
        if False:
            import time
            s = time.time()
            cnt = 40000
            blksz = 1024
            rnd = os.urandom(blksz)
            for i in range(0, cnt):
                proto.encrypt(rnd)
            e = time.time()

            print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024))

        # clean up connection
        clientend.transport.loseConnection()

    @defer.inlineCallbacks
    def test_clientserver(self):
        # End to end: plain text client -> client half (initiator) ->
        # server half (responder).  Work in progress: nothing is
        # asserted yet.

        # Path that the client "listener" will sit on.
        cptsockpath = os.path.join(self.tempdir, 'clientptsock')

        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')

        # Start up the server
        servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
        servlpobj = yield servep.listen(self.servfactory)

        # The listeners are stopped even if the test body fails, so
        # tearDown does not remove their sockets out from under them.
        try:
            # Start up the client half
            clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath))
            clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath)))

            try:
                # Connect to the client
                clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath))
                yield clptep.connect(self.AccProtFactory(self))

                # The client plain text connection
                clptproto = self.protos[-1]

                # Transports only accept bytes.
                clptproto.transport.write(b'this is a test')

                # clean up connection
                clptproto.transport.loseConnection()
            finally:
                yield clientlpobj.stopListening()
        finally:
            yield servlpobj.stopListening()
|