| @@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||
| # Initialize Noise | |||
| noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||
| self.noise = noise | |||
| noise.set_as_responder() | |||
| noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key) | |||
| if self.mode == 'resp': | |||
| noise.set_as_responder() | |||
| noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) | |||
| # Start Handshake | |||
| noise.start_handshake() | |||
| def encData(self, data): | |||
| '''Receive plain text data, encrypt it, and send it down the | |||
| wire.''' | |||
| self.transport.write(self.noise.encrypt(data)) | |||
| def dataReceived(self, data): | |||
| '''Receive encrypted data, and write it to the endpoint that | |||
| was connected via the plaintextConnected method.''' | |||
| if not self.noise.handshake_finished: | |||
| self.noise.read_message(data) | |||
| if not self.noise.handshake_finished: | |||
| @@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||
| self.endpoint.transport.write(r) | |||
| def handshakeFinished(self): | |||
| def handshakeFinished(self): # pragma: no cover | |||
| '''This function is called when the handshake has been | |||
| completed. This is used to start data flowing, and to | |||
| do any necessary connection work.''' | |||
| raise NotImplementedError | |||
| def plaintextConnected(self, endpoint): | |||
| '''Connect the plain text endpoint to the factory. All the | |||
| decrypted data will be written to this protocol, | |||
| (specifically, it's transport).''' | |||
| self.endpoint = endpoint | |||
| self.transport.resumeProducing() | |||
| class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||
| mode = 'resp' | |||
| def handshakeFinished(self): | |||
| self.transport.pauseProducing() | |||
| @@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||
| epdef = ep.connect(ClientProxyFactory(self)) | |||
| epdef.addCallback(self.plaintextConnected) | |||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | |||
| mode = 'init' | |||
| class ClientProxyProtocol(twisted.internet.protocol.Protocol): | |||
| def dataReceived(self, data): | |||
| self.factory.noiseproto.encData(data) | |||
| @@ -144,8 +165,8 @@ class ClientProxyFactory(Factory): | |||
| class TwistedNoiseServerFactory(Factory): | |||
| protocol = TwistedNoiseServerProtocol | |||
| def __init__(self, server_key, endpoint): | |||
| self.server_key = server_key | |||
| def __init__(self, priv_key, endpoint): | |||
| self.priv_key = priv_key | |||
| self.endpoint = endpoint | |||
| class TNServerTest(unittest.TestCase): | |||
| @@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase): | |||
| self.listenportobj = lpobj | |||
| self.endpoint = 'unix:path=%s' % sockpath | |||
| factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) | |||
| factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||
| self.proto = factory.buildProtocol(None) | |||
| self.tr = proto_helpers.StringTransport() | |||
| self.proto.makeConnection(self.tr) | |||