Browse Source

last bit of work on the twisted version before I stopped...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
1112dfb6dc
1 changed files with 64 additions and 13 deletions
  1. +64
    -13
      twistednoise.py

+ 64
- 13
twistednoise.py View File

@@ -1,6 +1,6 @@
from noise.connection import NoiseConnection, Keypair
from twisted.trial import unittest from twisted.trial import unittest
from twisted.test import proto_helpers from twisted.test import proto_helpers
from noise.connection import NoiseConnection, Keypair
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor, defer, task from twisted.internet import endpoints, reactor, defer, task
# XXX - shouldn't need to access the underlying primitives, but that's what # XXX - shouldn't need to access the underlying primitives, but that's what
@@ -46,6 +46,11 @@ __license__ = '2-clause BSD license'
# Using XK, so that the connecting party's identity is hidden and that the # Using XK, so that the connecting party's identity is hidden and that the
# server's party's key is known. # server's party's key is known.
# #
# The client and server class names are used to refer to the initiator and
# responder sides. Even though both client and server each have a server
# component (listen in on a socket to start comms), and a client component
# (create connection).
#
# Noise packets are 16 bytes + length of data # Noise packets are 16 bytes + length of data
# #
# Proposed method to hide message lengths: # Proposed method to hide message lengths:
@@ -101,8 +106,12 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
self.noise = noise self.noise = noise
if self.mode == 'resp': if self.mode == 'resp':
noise.set_as_responder() noise.set_as_responder()
elif self.mode == 'init':
noise.set_as_initiator()


noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)
if hasattr(self.factory, 'pub_key'):
noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key)


# Start Handshake # Start Handshake
noise.start_handshake() noise.start_handshake()
@@ -121,6 +130,7 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
self.noise.read_message(data) self.noise.read_message(data)
if not self.noise.handshake_finished: if not self.noise.handshake_finished:
self.transport.write(self.noise.write_message()) self.transport.write(self.noise.write_message())

if self.noise.handshake_finished: if self.noise.handshake_finished:
self.handshakeFinished() self.handshakeFinished()
else: else:
@@ -154,25 +164,37 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
epdef = ep.connect(ServerPTProxyFactory(self)) epdef = ep.connect(ServerPTProxyFactory(self))
epdef.addCallback(self.plaintextConnected) epdef.addCallback(self.plaintextConnected)


class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
mode = 'init'
class TwistedNoiseServerFactory(Factory):
protocol = TwistedNoiseServerProtocol

def __init__(self, priv_key, endpoint):
self.priv_key = priv_key
self.endpoint = endpoint

# Supporting classes for TwistedNoiseServer
class PTProxyProtocol(twisted.internet.protocol.Protocol):
'''Simple protocol then when data is received, encrypts the data
w/ the connected noise protocol.'''


class ServerPTProxyProtocol(twisted.internet.protocol.Protocol):
def dataReceived(self, data): def dataReceived(self, data):
self.factory.noiseproto.encData(data) self.factory.noiseproto.encData(data)


class ServerPTProxyFactory(Factory): class ServerPTProxyFactory(Factory):
protocol = ServerPTProxyProtocol
protocol = PTProxyProtocol


def __init__(self, noiseproto): def __init__(self, noiseproto):
self.noiseproto = noiseproto self.noiseproto = noiseproto


class TwistedNoiseServerFactory(Factory):
protocol = TwistedNoiseServerProtocol
class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
mode = 'init'


def __init__(self, priv_key, endpoint):
class ClientPTFactory(Factory):
protocol = TwistedNoiseClientProtocol

def __init__(self, priv_key, servpub, sockstr):
self.priv_key = priv_key self.priv_key = priv_key
self.endpoint = endpoint
self.pub_key = servpub
self.sockstr = sockstr


class TNServerTest(unittest.TestCase): class TNServerTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -205,6 +227,8 @@ class TNServerTest(unittest.TestCase):
self.__tc.protos.append(r) self.__tc.protos.append(r)
return r return r


self.AccProtFactory = AccProtFactory

# Setup PT client endpoint # Setup PT client endpoint
sockpath = os.path.join(self.tempdir, 'servptsock') sockpath = os.path.join(self.tempdir, 'servptsock')
ep = endpoints.UNIXServerEndpoint(reactor, sockpath) ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
@@ -231,7 +255,7 @@ class TNServerTest(unittest.TestCase):
# #
# proto (NoiseConnection) -> self.tr (StringTransport) -> # proto (NoiseConnection) -> self.tr (StringTransport) ->
# self.proto (TwistedNoiseServerProtocol) -> # self.proto (TwistedNoiseServerProtocol) ->
# self.proto.endpoint (ServerPTProxyProtocol) -> unix sock ->
# self.proto.endpoint (PTProxyProtocol) -> unix sock ->
# self.protos[0] (AccumulatingProtocol) # self.protos[0] (AccumulatingProtocol)
# #


@@ -249,7 +273,6 @@ class TNServerTest(unittest.TestCase):
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])


proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
proto.start_handshake() proto.start_handshake()


# Send first message # Send first message
@@ -309,18 +332,46 @@ class TNServerTest(unittest.TestCase):
resp = self.tr.value() resp = self.tr.value()
self.assertEqual(proto.decrypt(resp), rptmsg) self.assertEqual(proto.decrypt(resp), rptmsg)


if False:
import time
s = time.time()
cnt = 40000
blksz = 1024
rnd = os.urandom(blksz)
for i in range(0, cnt):
proto.encrypt(rnd)
e = time.time()

print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024))

# clean up connection # clean up connection
clientend.transport.loseConnection() clientend.transport.loseConnection()


@defer.inlineCallbacks @defer.inlineCallbacks
def test_clientserver(self): def test_clientserver(self):
# Path that the client "listener" sits on.
# Path that the client "listener" will sit on.
cptsockpath = os.path.join(self.tempdir, 'clientptsock') cptsockpath = os.path.join(self.tempdir, 'clientptsock')


# Path that the server sits on
# Path that the server will sit on
servsockpath = os.path.join(self.tempdir, 'servsock') servsockpath = os.path.join(self.tempdir, 'servsock')


# Start up the server
servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
servlpobj = yield servep.listen(self.servfactory) servlpobj = yield servep.listen(self.servfactory)


# Start up the client half
clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath))
clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath)))

# Conenct to the client
clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath))
clptconobj = yield clptep.connect(self.AccProtFactory(self))

# The client plain text connection
clptproto = self.protos[-1]

clptproto.transport.write('this is a test')

# Clean up
d = yield servlpobj.stopListening() d = yield servlpobj.stopListening()
d = yield clientlpobj.stopListening()

Loading…
Cancel
Save