Browse Source

check point for some work... Changing directions, but don't want

to lose the work of mocking the client connection..
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
dfcd36caab
2 changed files with 102 additions and 15 deletions
  1. +1
    -0
      requirements.txt
  2. +101
    -15
      twistednoise.py

+ 1
- 0
requirements.txt View File

@@ -1,3 +1,4 @@
coverage
-e git+https://github.com/jmgurney/noiseprotocol.git@ab6f8ebe0e28f5a4105928c13baddcfdc43b7e82#egg=noiseprotocol
twisted
mock

+ 101
- 15
twistednoise.py View File

@@ -2,17 +2,19 @@ from twisted.trial import unittest
from twisted.test import proto_helpers
from noise.connection import NoiseConnection, Keypair
from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor,defer
# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives import serialization

import twisted.internet.protocol
import mock

# Notes:
# Using XK, so that the connecting party's identity is hidden and that the
# server's party's key is known.

import twisted.internet.protocol

def genkeypair():
'''Generates a keypair, and returns a tuple of (public, private).
They are encoded as raw bytes, and sutible for use w/ Noise.'''
@@ -30,6 +32,19 @@ def genkeypair():
return pub, priv

class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
'''This class acts as a Noise Protocol responder. The factory that
creates this Protocol is required to have the properties server_key
and endpoint.

The server_key propery is the key for the server that the clients are
required to have (due to Noise XK protocol used) to authenticate the
server.

The endpoint property contains the endpoint as a string that will be
used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
for information on how to use this property.'''

def connectionMade(self):
# Initialize Noise
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
@@ -45,29 +60,81 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
self.noise.read_message(data)
if not self.noise.handshake_finished:
self.transport.write(self.noise.write_message())
if self.noise.handshake_finished:
self.transport.pauseProducing()

# start the connection to the endpoint
ep = endpoints.clientFromString(reactor, self.factory.endpoint)
epdef = ep.connect(ClientProxyFactory(self))
epdef.addCallback(self.proxyConnected)
else:
r = self.noise.decrypt(data)
# echo it
self.transport.write(self.noise.encrypt(r))

self.endpoint.write(r)

def proxyConnected(self, endpoint):
print('pc')
self.endpoint = endpoint
self.transport.resumeProducing()

class ClientProxyProtocol(twisted.internet.protocol.Protocol):
pass

class ClientProxyFactory(Factory):
protocol = ClientProxyProtocol

def __init__(self, noiseproto):
self.noiseproto = noiseproto

class TwistedNoiseServerFactory(Factory):
protocol = TwistedNoiseServerProtocol

def __init__(self, server_key):
def __init__(self, server_key, endpoint):
self.server_key = server_key
self.endpoint = endpoint

class TNServerTest(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.server_key_pair = genkeypair()
self.protos = []
class AccProtFactory(Factory):
protocol = proto_helpers.AccumulatingProtocol

def __init__(self, tc):
self.__tc = tc
Factory.__init__(self)

def buildProtocol(self, addr):
r = Factory.buildProtocol(addr)
self.__tc.append(r)
return r

for i in range(10000, 20000):
ep = endpoints.TCP4ServerEndpoint(reactor, i)
try:
lpobj = yield ep.listen(AccProtFactory(self))
except Exception:
continue
break
else:
raise RuntimeError('all ports occupied')

factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1])
self.testserv = ep
self.listenportobj = lpobj
self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i
factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
self.proto = factory.buildProtocol(('127.0.0.1', 0))
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)

self.client_key_pair = genkeypair()

def test_testprotocol(self):
def tearDown(self):
self.listenportobj.stopListening()

@mock.patch('twisted.internet.endpoints.clientFromString')
def test_testprotocol(self, cfs):
# Create client
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_as_initiator()
@@ -90,22 +157,41 @@ class TNServerTest(unittest.TestCase):
# And process it
proto.read_message(resp)

clientconnection = defer.Deferred()
cfs().connect.return_value = clientconnection

# Send second message
message = proto.write_message()
self.proto.dataReceived(message)

# Finish handshake
# assert handshake finished
self.assertTrue(proto.handshake_finished)

# Make sure incoming data is paused till we establish client
# connection, otherwise no place to write the data
self.assertEqual(self.tr.producerState, 'paused')

# Make sure that clientFromString is called properly
cfs.assert_called_with(reactor, self.endpoint)

# And that it was connect'ed
cfs().connect.assert_called()

# and that ClientProxyFactory was called properly
args = cfs().connect.call_args.args
self.assertIsInstance(args[0], ClientProxyFactory)
self.assertIs(args[0].noiseproto, self.proto)

# Simulate that a connection has happened
remoteend = proto_helpers.StringTransport()
remoteproto = args[0].buildProtocol(None)
remoteproto.makeConnection(remoteend)

# Encrypt the message
ptmsg = b'this is a test message'
encmsg = proto.encrypt(ptmsg)

# Feed it into the protocol
self.proto.dataReceived(encmsg)

# Get echo
resp = self.tr.value()
self.tr.clear()

ptresp = proto.decrypt(resp)

self.assertEqual(ptresp, ptmsg)
self.assertEqual(remoteend.value(), ptmsg)

Loading…
Cancel
Save