diff --git a/twistednoise.py b/twistednoise.py index 6873479..cc3398a 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -60,6 +60,11 @@ __license__ = '2-clause BSD license' # the connection aborts because of decryption failure. # +def _makeunix(path): + '''Make a properly formed unix path socket string.''' + + return 'unix:%s' % path + def genkeypair(): '''Generates a keypair, and returns a tuple of (public, private). They are encoded as raw bytes, and sutible for use w/ Noise.''' @@ -146,18 +151,18 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): # start the connection to the endpoint ep = endpoints.clientFromString(reactor, self.factory.endpoint) - epdef = ep.connect(ClientProxyFactory(self)) + epdef = ep.connect(ServerPTProxyFactory(self)) epdef.addCallback(self.plaintextConnected) class TwistedNoiseClientProtocol(TwistedNoiseProtocol): mode = 'init' -class ClientProxyProtocol(twisted.internet.protocol.Protocol): +class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): def dataReceived(self, data): self.factory.noiseproto.encData(data) -class ClientProxyFactory(Factory): - protocol = ClientProxyProtocol +class ServerPTProxyFactory(Factory): + protocol = ServerPTProxyProtocol def __init__(self, noiseproto): self.noiseproto = noiseproto @@ -172,12 +177,17 @@ class TwistedNoiseServerFactory(Factory): class TNServerTest(unittest.TestCase): @defer.inlineCallbacks def setUp(self): + # setup temporary directory d = os.path.realpath(tempfile.mkdtemp()) self.basetempdir = d self.tempdir = os.path.join(d, 'subdir') os.mkdir(self.tempdir) + # Generate key pairs self.server_key_pair = genkeypair() + self.client_key_pair = genkeypair() + + # Server's PT client will be here self.protos = [] self.connectionmade = defer.Deferred() @@ -195,23 +205,21 @@ class TNServerTest(unittest.TestCase): self.__tc.protos.append(r) return r - sockpath = os.path.join(self.tempdir, 'clientsock') + # Setup PT client endpoint + sockpath = os.path.join(self.tempdir, 'servptsock') ep = endpoints.UNIXServerEndpoint(reactor, sockpath) lpobj = yield ep.listen(AccProtFactory(self)) self.testserv = ep self.listenportobj = lpobj - self.endpoint = 'unix:path=%s' % sockpath + self.endpoint = _makeunix(sockpath) - factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) - self.proto = factory.buildProtocol(None) - self.tr = proto_helpers.StringTransport() - self.proto.makeConnection(self.tr) - - self.client_key_pair = genkeypair() + # Setup server, and configure where to connect to. + self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) + @defer.inlineCallbacks def tearDown(self): - self.listenportobj.stopListening() + d = yield self.listenportobj.stopListening() shutil.rmtree(self.basetempdir) self.tempdir = None @@ -223,9 +231,16 @@ class TNServerTest(unittest.TestCase): # # proto (NoiseConnection) -> self.tr (StringTransport) -> # self.proto (TwistedNoiseServerProtocol) -> - # self.proto.endpoint (ClientProxyProtocol) -> unix sock -> + # self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> # self.protos[0] (AccumulatingProtocol) # + + # Generate a server protocol, and bind it to a string + # transport for testing + self.proto = self.servfactory.buildProtocol(None) + self.tr = proto_helpers.StringTransport() + self.proto.makeConnection(self.tr) + # Create client proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_as_initiator() @@ -274,6 +289,7 @@ class TNServerTest(unittest.TestCase): # Feed it into the protocol self.proto.dataReceived(encmsg) + # XXX - fix # wait to pass it through d = yield task.deferLater(reactor, .1, bool, 1) @@ -285,6 +301,7 @@ class TNServerTest(unittest.TestCase): rptmsg = b'this is a different test message going the other way' clientend.transport.write(rptmsg) + # XXX - fix # wait to pass it through d = yield task.deferLater(reactor, .1, bool, 1) @@ -294,3 +311,16 @@ class TNServerTest(unittest.TestCase): # clean up connection clientend.transport.loseConnection() + + @defer.inlineCallbacks + def test_clientserver(self): + # Path that the client "listener" sits on. + cptsockpath = os.path.join(self.tempdir, 'clientptsock') + + # Path that the server sits on + servsockpath = os.path.join(self.tempdir, 'servsock') + + servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) + servlpobj = yield servep.listen(self.servfactory) + + d = yield servlpobj.stopListening()