| @@ -1,6 +1,6 @@ | |||
| from noise.connection import NoiseConnection, Keypair | |||
| from twisted.trial import unittest | |||
| from twisted.test import proto_helpers | |||
| from noise.connection import NoiseConnection, Keypair | |||
| from twisted.internet.protocol import Factory | |||
| from twisted.internet import endpoints, reactor, defer, task | |||
| # XXX - shouldn't need to access the underlying primitives, but that's what | |||
| @@ -46,6 +46,11 @@ __license__ = '2-clause BSD license' | |||
| # Using XK, so that the connecting party's identity is hidden and that the | |||
| # server's party's key is known. | |||
| # | |||
| # The client and server class names are used to refer to the initiator and | |||
| # responder sides. Even though both client and server each have a server | |||
| # component (listen in on a socket to start comms), and a client component | |||
| # (create connection). | |||
| # | |||
| # Noise packets are 16 bytes + length of data | |||
| # | |||
| # Proposed method to hide message lengths: | |||
| @@ -101,8 +106,12 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||
| self.noise = noise | |||
| if self.mode == 'resp': | |||
| noise.set_as_responder() | |||
| elif self.mode == 'init': | |||
| noise.set_as_initiator() | |||
| noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) | |||
| if hasattr(self.factory, 'pub_key'): | |||
| noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key) | |||
| # Start Handshake | |||
| noise.start_handshake() | |||
| @@ -121,6 +130,7 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||
| self.noise.read_message(data) | |||
| if not self.noise.handshake_finished: | |||
| self.transport.write(self.noise.write_message()) | |||
| if self.noise.handshake_finished: | |||
| self.handshakeFinished() | |||
| else: | |||
| @@ -154,25 +164,37 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||
| epdef = ep.connect(ServerPTProxyFactory(self)) | |||
| epdef.addCallback(self.plaintextConnected) | |||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | |||
| mode = 'init' | |||
| class TwistedNoiseServerFactory(Factory): | |||
| protocol = TwistedNoiseServerProtocol | |||
| def __init__(self, priv_key, endpoint): | |||
| self.priv_key = priv_key | |||
| self.endpoint = endpoint | |||
| # Supporting classes for TwistedNoiseServer | |||
| class PTProxyProtocol(twisted.internet.protocol.Protocol): | |||
| '''Simple protocol then when data is received, encrypts the data | |||
| w/ the connected noise protocol.''' | |||
| class ServerPTProxyProtocol(twisted.internet.protocol.Protocol): | |||
| def dataReceived(self, data): | |||
| self.factory.noiseproto.encData(data) | |||
| class ServerPTProxyFactory(Factory): | |||
| protocol = ServerPTProxyProtocol | |||
| protocol = PTProxyProtocol | |||
| def __init__(self, noiseproto): | |||
| self.noiseproto = noiseproto | |||
| class TwistedNoiseServerFactory(Factory): | |||
| protocol = TwistedNoiseServerProtocol | |||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | |||
| mode = 'init' | |||
| def __init__(self, priv_key, endpoint): | |||
| class ClientPTFactory(Factory): | |||
| protocol = TwistedNoiseClientProtocol | |||
| def __init__(self, priv_key, servpub, sockstr): | |||
| self.priv_key = priv_key | |||
| self.endpoint = endpoint | |||
| self.pub_key = servpub | |||
| self.sockstr = sockstr | |||
| class TNServerTest(unittest.TestCase): | |||
| @defer.inlineCallbacks | |||
| @@ -205,6 +227,8 @@ class TNServerTest(unittest.TestCase): | |||
| self.__tc.protos.append(r) | |||
| return r | |||
| self.AccProtFactory = AccProtFactory | |||
| # Setup PT client endpoint | |||
| sockpath = os.path.join(self.tempdir, 'servptsock') | |||
| ep = endpoints.UNIXServerEndpoint(reactor, sockpath) | |||
| @@ -231,7 +255,7 @@ class TNServerTest(unittest.TestCase): | |||
| # | |||
| # proto (NoiseConnection) -> self.tr (StringTransport) -> | |||
| # self.proto (TwistedNoiseServerProtocol) -> | |||
| # self.proto.endpoint (ServerPTProxyProtocol) -> unix sock -> | |||
| # self.proto.endpoint (PTProxyProtocol) -> unix sock -> | |||
| # self.protos[0] (AccumulatingProtocol) | |||
| # | |||
| @@ -249,7 +273,6 @@ class TNServerTest(unittest.TestCase): | |||
| proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) | |||
| proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) | |||
| proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) | |||
| proto.start_handshake() | |||
| # Send first message | |||
| @@ -309,18 +332,46 @@ class TNServerTest(unittest.TestCase): | |||
| resp = self.tr.value() | |||
| self.assertEqual(proto.decrypt(resp), rptmsg) | |||
| if False: | |||
| import time | |||
| s = time.time() | |||
| cnt = 40000 | |||
| blksz = 1024 | |||
| rnd = os.urandom(blksz) | |||
| for i in range(0, cnt): | |||
| proto.encrypt(rnd) | |||
| e = time.time() | |||
| print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024)) | |||
| # clean up connection | |||
| clientend.transport.loseConnection() | |||
| @defer.inlineCallbacks | |||
| def test_clientserver(self): | |||
| # Path that the client "listener" sits on. | |||
| # Path that the client "listener" will sit on. | |||
| cptsockpath = os.path.join(self.tempdir, 'clientptsock') | |||
| # Path that the server sits on | |||
| # Path that the server will sit on | |||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||
| # Start up the server | |||
| servep = endpoints.serverFromString(reactor, _makeunix(servsockpath)) | |||
| servlpobj = yield servep.listen(self.servfactory) | |||
| # Start up the client half | |||
| clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath)) | |||
| clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath))) | |||
| # Conenct to the client | |||
| clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath)) | |||
| clptconobj = yield clptep.connect(self.AccProtFactory(self)) | |||
| # The client plain text connection | |||
| clptproto = self.protos[-1] | |||
| clptproto.transport.write('this is a test') | |||
| # Clean up | |||
| d = yield servlpobj.stopListening() | |||
| d = yield clientlpobj.stopListening() | |||