Browse Source

more refactoring and prep for client work...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
7c51fbd77b
1 changed files with 44 additions and 14 deletions
  1. +44
    -14
      twistednoise.py

+ 44
- 14
twistednoise.py View File

@@ -60,6 +60,11 @@ __license__ = '2-clause BSD license'
# the connection aborts because of decryption failure. # the connection aborts because of decryption failure.
# #


def _makeunix(path):
'''Make a properly formed unix path socket string.'''

return 'unix:%s' % path

def genkeypair(): def genkeypair():
'''Generates a keypair, and returns a tuple of (public, private). '''Generates a keypair, and returns a tuple of (public, private).
They are encoded as raw bytes, and sutible for use w/ Noise.''' They are encoded as raw bytes, and sutible for use w/ Noise.'''
@@ -146,18 +151,18 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol):


# start the connection to the endpoint # start the connection to the endpoint
ep = endpoints.clientFromString(reactor, self.factory.endpoint) ep = endpoints.clientFromString(reactor, self.factory.endpoint)
epdef = ep.connect(ClientProxyFactory(self))
epdef = ep.connect(ServerPTProxyFactory(self))
epdef.addCallback(self.plaintextConnected) epdef.addCallback(self.plaintextConnected)


class TwistedNoiseClientProtocol(TwistedNoiseProtocol): class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
mode = 'init' mode = 'init'


class ClientProxyProtocol(twisted.internet.protocol.Protocol):
class ServerPTProxyProtocol(twisted.internet.protocol.Protocol):
def dataReceived(self, data): def dataReceived(self, data):
self.factory.noiseproto.encData(data) self.factory.noiseproto.encData(data)


class ClientProxyFactory(Factory):
protocol = ClientProxyProtocol
class ServerPTProxyFactory(Factory):
protocol = ServerPTProxyProtocol


def __init__(self, noiseproto): def __init__(self, noiseproto):
self.noiseproto = noiseproto self.noiseproto = noiseproto
@@ -172,12 +177,17 @@ class TwistedNoiseServerFactory(Factory):
class TNServerTest(unittest.TestCase): class TNServerTest(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
# setup temporary directory
d = os.path.realpath(tempfile.mkdtemp()) d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d self.basetempdir = d
self.tempdir = os.path.join(d, 'subdir') self.tempdir = os.path.join(d, 'subdir')
os.mkdir(self.tempdir) os.mkdir(self.tempdir)


# Generate key pairs
self.server_key_pair = genkeypair() self.server_key_pair = genkeypair()
self.client_key_pair = genkeypair()

# Server's PT client will be here
self.protos = [] self.protos = []
self.connectionmade = defer.Deferred() self.connectionmade = defer.Deferred()


@@ -195,23 +205,21 @@ class TNServerTest(unittest.TestCase):
self.__tc.protos.append(r) self.__tc.protos.append(r)
return r return r


sockpath = os.path.join(self.tempdir, 'clientsock')
# Setup PT client endpoint
sockpath = os.path.join(self.tempdir, 'servptsock')
ep = endpoints.UNIXServerEndpoint(reactor, sockpath) ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
lpobj = yield ep.listen(AccProtFactory(self)) lpobj = yield ep.listen(AccProtFactory(self))


self.testserv = ep self.testserv = ep
self.listenportobj = lpobj self.listenportobj = lpobj
self.endpoint = 'unix:path=%s' % sockpath
self.endpoint = _makeunix(sockpath)


factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)
self.proto = factory.buildProtocol(None)
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)

self.client_key_pair = genkeypair()
# Setup server, and configure where to connect to.
self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)


@defer.inlineCallbacks
def tearDown(self): def tearDown(self):
self.listenportobj.stopListening()
d = yield self.listenportobj.stopListening()


shutil.rmtree(self.basetempdir) shutil.rmtree(self.basetempdir)
self.tempdir = None self.tempdir = None
@@ -223,9 +231,16 @@ class TNServerTest(unittest.TestCase):
# #
# proto (NoiseConnection) -> self.tr (StringTransport) -> # proto (NoiseConnection) -> self.tr (StringTransport) ->
# self.proto (TwistedNoiseServerProtocol) -> # self.proto (TwistedNoiseServerProtocol) ->
# self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
# self.proto.endpoint (ServerPTProxyProtocol) -> unix sock ->
# self.protos[0] (AccumulatingProtocol) # self.protos[0] (AccumulatingProtocol)
# #

# Generate a server protocol, and bind it to a string
# transport for testing
self.proto = self.servfactory.buildProtocol(None)
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)

# Create client # Create client
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_as_initiator() proto.set_as_initiator()
@@ -274,6 +289,7 @@ class TNServerTest(unittest.TestCase):
# Feed it into the protocol # Feed it into the protocol
self.proto.dataReceived(encmsg) self.proto.dataReceived(encmsg)


# XXX - fix
# wait to pass it through # wait to pass it through
d = yield task.deferLater(reactor, .1, bool, 1) d = yield task.deferLater(reactor, .1, bool, 1)


@@ -285,6 +301,7 @@ class TNServerTest(unittest.TestCase):
rptmsg = b'this is a different test message going the other way' rptmsg = b'this is a different test message going the other way'
clientend.transport.write(rptmsg) clientend.transport.write(rptmsg)


# XXX - fix
# wait to pass it through # wait to pass it through
d = yield task.deferLater(reactor, .1, bool, 1) d = yield task.deferLater(reactor, .1, bool, 1)


@@ -294,3 +311,16 @@ class TNServerTest(unittest.TestCase):


# clean up connection # clean up connection
clientend.transport.loseConnection() clientend.transport.loseConnection()

@defer.inlineCallbacks
def test_clientserver(self):
# Path that the client "listener" sits on.
cptsockpath = os.path.join(self.tempdir, 'clientptsock')

# Path that the server sits on
servsockpath = os.path.join(self.tempdir, 'servsock')

servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
servlpobj = yield servep.listen(self.servfactory)

d = yield servlpobj.stopListening()

Loading…
Cancel
Save