|
|
@@ -60,6 +60,11 @@ __license__ = '2-clause BSD license' |
|
|
|
# the connection aborts because of decryption failure. |
|
|
|
# |
|
|
|
|
|
|
|
def _makeunix(path): |
|
|
|
'''Make a properly formed unix path socket string.''' |
|
|
|
|
|
|
|
return 'unix:%s' % path |
|
|
|
|
|
|
|
def genkeypair(): |
|
|
|
'''Generates a keypair, and returns a tuple of (public, private). |
|
|
|
They are encoded as raw bytes, and sutible for use w/ Noise.''' |
|
|
@@ -146,18 +151,18 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): |
|
|
|
|
|
|
|
# start the connection to the endpoint |
|
|
|
ep = endpoints.clientFromString(reactor, self.factory.endpoint) |
|
|
|
epdef = ep.connect(ClientProxyFactory(self)) |
|
|
|
epdef = ep.connect(ServerPTProxyFactory(self)) |
|
|
|
epdef.addCallback(self.plaintextConnected) |
|
|
|
|
|
|
|
class TwistedNoiseClientProtocol(TwistedNoiseProtocol): |
|
|
|
mode = 'init' |
|
|
|
|
|
|
|
class ClientProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
def dataReceived(self, data): |
|
|
|
self.factory.noiseproto.encData(data) |
|
|
|
|
|
|
|
class ClientProxyFactory(Factory): |
|
|
|
protocol = ClientProxyProtocol |
|
|
|
class ServerPTProxyFactory(Factory): |
|
|
|
protocol = ServerPTProxyProtocol |
|
|
|
|
|
|
|
def __init__(self, noiseproto): |
|
|
|
self.noiseproto = noiseproto |
|
|
@@ -172,12 +177,17 @@ class TwistedNoiseServerFactory(Factory): |
|
|
|
class TNServerTest(unittest.TestCase): |
|
|
|
@defer.inlineCallbacks |
|
|
|
def setUp(self): |
|
|
|
# setup temporary directory |
|
|
|
d = os.path.realpath(tempfile.mkdtemp()) |
|
|
|
self.basetempdir = d |
|
|
|
self.tempdir = os.path.join(d, 'subdir') |
|
|
|
os.mkdir(self.tempdir) |
|
|
|
|
|
|
|
# Generate key pairs |
|
|
|
self.server_key_pair = genkeypair() |
|
|
|
self.client_key_pair = genkeypair() |
|
|
|
|
|
|
|
# Server's PT client will be here |
|
|
|
self.protos = [] |
|
|
|
self.connectionmade = defer.Deferred() |
|
|
|
|
|
|
@@ -195,23 +205,21 @@ class TNServerTest(unittest.TestCase): |
|
|
|
self.__tc.protos.append(r) |
|
|
|
return r |
|
|
|
|
|
|
|
sockpath = os.path.join(self.tempdir, 'clientsock') |
|
|
|
# Setup PT client endpoint |
|
|
|
sockpath = os.path.join(self.tempdir, 'servptsock') |
|
|
|
ep = endpoints.UNIXServerEndpoint(reactor, sockpath) |
|
|
|
lpobj = yield ep.listen(AccProtFactory(self)) |
|
|
|
|
|
|
|
self.testserv = ep |
|
|
|
self.listenportobj = lpobj |
|
|
|
self.endpoint = 'unix:path=%s' % sockpath |
|
|
|
self.endpoint = _makeunix(sockpath) |
|
|
|
|
|
|
|
factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) |
|
|
|
self.proto = factory.buildProtocol(None) |
|
|
|
self.tr = proto_helpers.StringTransport() |
|
|
|
self.proto.makeConnection(self.tr) |
|
|
|
|
|
|
|
self.client_key_pair = genkeypair() |
|
|
|
# Setup server, and configure where to connect to. |
|
|
|
self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) |
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
def tearDown(self): |
|
|
|
self.listenportobj.stopListening() |
|
|
|
d = yield self.listenportobj.stopListening() |
|
|
|
|
|
|
|
shutil.rmtree(self.basetempdir) |
|
|
|
self.tempdir = None |
|
|
@@ -223,9 +231,16 @@ class TNServerTest(unittest.TestCase): |
|
|
|
# |
|
|
|
# proto (NoiseConnection) -> self.tr (StringTransport) -> |
|
|
|
# self.proto (TwistedNoiseServerProtocol) -> |
|
|
|
# self.proto.endpoint (ClientProxyProtocol) -> unix sock -> |
|
|
|
# self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> |
|
|
|
# self.protos[0] (AccumulatingProtocol) |
|
|
|
# |
|
|
|
|
|
|
|
# Generate a server protocol, and bind it to a string |
|
|
|
# transport for testing |
|
|
|
self.proto = self.servfactory.buildProtocol(None) |
|
|
|
self.tr = proto_helpers.StringTransport() |
|
|
|
self.proto.makeConnection(self.tr) |
|
|
|
|
|
|
|
# Create client |
|
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
proto.set_as_initiator() |
|
|
@@ -274,6 +289,7 @@ class TNServerTest(unittest.TestCase): |
|
|
|
# Feed it into the protocol |
|
|
|
self.proto.dataReceived(encmsg) |
|
|
|
|
|
|
|
# XXX - fix |
|
|
|
# wait to pass it through |
|
|
|
d = yield task.deferLater(reactor, .1, bool, 1) |
|
|
|
|
|
|
@@ -285,6 +301,7 @@ class TNServerTest(unittest.TestCase): |
|
|
|
rptmsg = b'this is a different test message going the other way' |
|
|
|
clientend.transport.write(rptmsg) |
|
|
|
|
|
|
|
# XXX - fix |
|
|
|
# wait to pass it through |
|
|
|
d = yield task.deferLater(reactor, .1, bool, 1) |
|
|
|
|
|
|
@@ -294,3 +311,16 @@ class TNServerTest(unittest.TestCase): |
|
|
|
|
|
|
|
# clean up connection |
|
|
|
clientend.transport.loseConnection() |
|
|
|
|
|
|
|
@defer.inlineCallbacks |
|
|
|
def test_clientserver(self): |
|
|
|
# Path that the client "listener" sits on. |
|
|
|
cptsockpath = os.path.join(self.tempdir, 'clientptsock') |
|
|
|
|
|
|
|
# Path that the server sits on |
|
|
|
servsockpath = os.path.join(self.tempdir, 'servsock') |
|
|
|
|
|
|
|
servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) |
|
|
|
servlpobj = yield servep.listen(self.servfactory) |
|
|
|
|
|
|
|
d = yield servlpobj.stopListening() |