|
|
@@ -1,6 +1,6 @@ |
|
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
|
from twisted.trial import unittest |
|
|
|
from twisted.test import proto_helpers |
|
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
|
from twisted.internet.protocol import Factory |
|
|
|
from twisted.internet import endpoints, reactor, defer, task |
|
|
|
# XXX - shouldn't need to access the underlying primitives, but that's what |
|
|
@@ -46,6 +46,11 @@ __license__ = '2-clause BSD license' |
|
|
|
# Using XK, so that the connecting party's identity is hidden and that the |
|
|
|
# server's party's key is known. |
|
|
|
# |
|
|
|
# The client and server class names are used to refer to the initiator and |
|
|
|
# responder sides. Even though both client and server each have a server |
|
|
|
# component (listen in on a socket to start comms), and a client component |
|
|
|
# (create connection). |
|
|
|
# |
|
|
|
# Noise packets are 16 bytes + length of data |
|
|
|
# |
|
|
|
# Proposed method to hide message lengths: |
|
|
@@ -101,8 +106,12 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): |
|
|
|
self.noise = noise |
|
|
|
if self.mode == 'resp': |
|
|
|
noise.set_as_responder() |
|
|
|
elif self.mode == 'init': |
|
|
|
noise.set_as_initiator() |
|
|
|
|
|
|
|
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) |
|
|
|
if hasattr(self.factory, 'pub_key'): |
|
|
|
noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key) |
|
|
|
|
|
|
|
# Start Handshake |
|
|
|
noise.start_handshake() |
|
|
@@ -121,6 +130,7 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): |
|
|
|
self.noise.read_message(data) |
|
|
|
if not self.noise.handshake_finished: |
|
|
|
self.transport.write(self.noise.write_message()) |
|
|
|
|
|
|
|
if self.noise.handshake_finished: |
|
|
|
self.handshakeFinished() |
|
|
|
else: |
|
|
@@ -154,25 +164,37 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): |
|
|
|
epdef = ep.connect(ServerPTProxyFactory(self)) |
|
|
|
epdef.addCallback(self.plaintextConnected) |
|
|
|
|
|
|
|
class TwistedNoiseClientProtocol(TwistedNoiseProtocol): |
|
|
|
mode = 'init' |
|
|
|
class TwistedNoiseServerFactory(Factory): |
|
|
|
protocol = TwistedNoiseServerProtocol |
|
|
|
|
|
|
|
def __init__(self, priv_key, endpoint): |
|
|
|
self.priv_key = priv_key |
|
|
|
self.endpoint = endpoint |
|
|
|
|
|
|
|
# Supporting classes for TwistedNoiseServer |
|
|
|
class PTProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
'''Simple protocol then when data is received, encrypts the data |
|
|
|
w/ the connected noise protocol.''' |
|
|
|
|
|
|
|
class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
def dataReceived(self, data): |
|
|
|
self.factory.noiseproto.encData(data) |
|
|
|
|
|
|
|
class ServerPTProxyFactory(Factory): |
|
|
|
protocol = ServerPTProxyProtocol |
|
|
|
protocol = PTProxyProtocol |
|
|
|
|
|
|
|
def __init__(self, noiseproto): |
|
|
|
self.noiseproto = noiseproto |
|
|
|
|
|
|
|
class TwistedNoiseServerFactory(Factory): |
|
|
|
protocol = TwistedNoiseServerProtocol |
|
|
|
class TwistedNoiseClientProtocol(TwistedNoiseProtocol): |
|
|
|
mode = 'init' |
|
|
|
|
|
|
|
def __init__(self, priv_key, endpoint): |
|
|
|
class ClientPTFactory(Factory): |
|
|
|
protocol = TwistedNoiseClientProtocol |
|
|
|
|
|
|
|
def __init__(self, priv_key, servpub, sockstr): |
|
|
|
self.priv_key = priv_key |
|
|
|
self.endpoint = endpoint |
|
|
|
self.pub_key = servpub |
|
|
|
self.sockstr = sockstr |
|
|
|
|
|
|
|
class TNServerTest(unittest.TestCase): |
|
|
|
@defer.inlineCallbacks |
|
|
@@ -205,6 +227,8 @@ class TNServerTest(unittest.TestCase): |
|
|
|
self.__tc.protos.append(r) |
|
|
|
return r |
|
|
|
|
|
|
|
self.AccProtFactory = AccProtFactory |
|
|
|
|
|
|
|
# Setup PT client endpoint |
|
|
|
sockpath = os.path.join(self.tempdir, 'servptsock') |
|
|
|
ep = endpoints.UNIXServerEndpoint(reactor, sockpath) |
|
|
@@ -231,7 +255,7 @@ class TNServerTest(unittest.TestCase): |
|
|
|
# |
|
|
|
# proto (NoiseConnection) -> self.tr (StringTransport) -> |
|
|
|
# self.proto (TwistedNoiseServerProtocol) -> |
|
|
|
# self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> |
|
|
|
# self.proto.endpoint (PTProxyProtocol) -> unix sock -> |
|
|
|
# self.protos[0] (AccumulatingProtocol) |
|
|
|
# |
|
|
|
|
|
|
@@ -249,7 +273,6 @@ class TNServerTest(unittest.TestCase): |
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) |
|
|
|
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) |
|
|
|
|
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) |
|
|
|
proto.start_handshake() |
|
|
|
|
|
|
|
# Send first message |
|
|
@@ -309,18 +332,46 @@ class TNServerTest(unittest.TestCase): |
|
|
|
resp = self.tr.value() |
|
|
|
self.assertEqual(proto.decrypt(resp), rptmsg) |
|
|
|
|
|
|
|
if False: |
|
|
|
import time |
|
|
|
s = time.time() |
|
|
|
cnt = 40000 |
|
|
|
blksz = 1024 |
|
|
|
rnd = os.urandom(blksz) |
|
|
|
for i in range(0, cnt): |
|
|
|
proto.encrypt(rnd) |
|
|
|
e = time.time() |
|
|
|
|
|
|
|
print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024)) |
|
|
|
|
|
|
|
# clean up connection |
|
|
|
clientend.transport.loseConnection() |
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
def test_clientserver(self): |
|
|
|
# Path that the client "listener" sits on. |
|
|
|
# Path that the client "listener" will sit on. |
|
|
|
cptsockpath = os.path.join(self.tempdir, 'clientptsock') |
|
|
|
|
|
|
|
# Path that the server sits on |
|
|
|
# Path that the server will sit on |
|
|
|
servsockpath = os.path.join(self.tempdir, 'servsock') |
|
|
|
|
|
|
|
# Start up the server |
|
|
|
servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) |
|
|
|
servlpobj = yield servep.listen(self.servfactory) |
|
|
|
|
|
|
|
# Start up the client half |
|
|
|
clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath)) |
|
|
|
clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath))) |
|
|
|
|
|
|
|
# Conenct to the client |
|
|
|
clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath)) |
|
|
|
clptconobj = yield clptep.connect(self.AccProtFactory(self)) |
|
|
|
|
|
|
|
# The client plain text connection |
|
|
|
clptproto = self.protos[-1] |
|
|
|
|
|
|
|
clptproto.transport.write('this is a test') |
|
|
|
|
|
|
|
# Clean up |
|
|
|
d = yield servlpobj.stopListening() |
|
|
|
d = yield clientlpobj.stopListening() |