| @@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||||
| # Initialize Noise | # Initialize Noise | ||||
| noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | ||||
| self.noise = noise | self.noise = noise | ||||
| noise.set_as_responder() | |||||
| noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key) | |||||
| if self.mode == 'resp': | |||||
| noise.set_as_responder() | |||||
| noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) | |||||
| # Start Handshake | # Start Handshake | ||||
| noise.start_handshake() | noise.start_handshake() | ||||
| def encData(self, data): | def encData(self, data): | ||||
| '''Receive plain text data, encrypt it, and send it down the | |||||
| wire.''' | |||||
| self.transport.write(self.noise.encrypt(data)) | self.transport.write(self.noise.encrypt(data)) | ||||
| def dataReceived(self, data): | def dataReceived(self, data): | ||||
| '''Receive encrypted data, and write it to the endpoint that | |||||
| was connected via the plaintextConnected method.''' | |||||
| if not self.noise.handshake_finished: | if not self.noise.handshake_finished: | ||||
| self.noise.read_message(data) | self.noise.read_message(data) | ||||
| if not self.noise.handshake_finished: | if not self.noise.handshake_finished: | ||||
| @@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): | |||||
| self.endpoint.transport.write(r) | self.endpoint.transport.write(r) | ||||
| def handshakeFinished(self): | |||||
| def handshakeFinished(self): # pragma: no cover | |||||
| '''This function is called when the handshake has been | |||||
| completed. This is used to start data flowing, and to | |||||
| do any necessary connection work.''' | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def plaintextConnected(self, endpoint): | def plaintextConnected(self, endpoint): | ||||
| '''Connect the plain text endpoint to the factory. All the | |||||
| decrypted data will be written to this protocol, | |||||
| (specifically, it's transport).''' | |||||
| self.endpoint = endpoint | self.endpoint = endpoint | ||||
| self.transport.resumeProducing() | self.transport.resumeProducing() | ||||
| class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | ||||
| mode = 'resp' | |||||
| def handshakeFinished(self): | def handshakeFinished(self): | ||||
| self.transport.pauseProducing() | self.transport.pauseProducing() | ||||
| @@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): | |||||
| epdef = ep.connect(ClientProxyFactory(self)) | epdef = ep.connect(ClientProxyFactory(self)) | ||||
| epdef.addCallback(self.plaintextConnected) | epdef.addCallback(self.plaintextConnected) | ||||
| class TwistedNoiseClientProtocol(TwistedNoiseProtocol): | |||||
| mode = 'init' | |||||
| class ClientProxyProtocol(twisted.internet.protocol.Protocol): | class ClientProxyProtocol(twisted.internet.protocol.Protocol): | ||||
| def dataReceived(self, data): | def dataReceived(self, data): | ||||
| self.factory.noiseproto.encData(data) | self.factory.noiseproto.encData(data) | ||||
| @@ -144,8 +165,8 @@ class ClientProxyFactory(Factory): | |||||
| class TwistedNoiseServerFactory(Factory): | class TwistedNoiseServerFactory(Factory): | ||||
| protocol = TwistedNoiseServerProtocol | protocol = TwistedNoiseServerProtocol | ||||
| def __init__(self, server_key, endpoint): | |||||
| self.server_key = server_key | |||||
| def __init__(self, priv_key, endpoint): | |||||
| self.priv_key = priv_key | |||||
| self.endpoint = endpoint | self.endpoint = endpoint | ||||
| class TNServerTest(unittest.TestCase): | class TNServerTest(unittest.TestCase): | ||||
| @@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase): | |||||
| self.listenportobj = lpobj | self.listenportobj = lpobj | ||||
| self.endpoint = 'unix:path=%s' % sockpath | self.endpoint = 'unix:path=%s' % sockpath | ||||
| factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) | |||||
| factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) | |||||
| self.proto = factory.buildProtocol(None) | self.proto = factory.buildProtocol(None) | ||||
| self.tr = proto_helpers.StringTransport() | self.tr = proto_helpers.StringTransport() | ||||
| self.proto.makeConnection(self.tr) | self.proto.makeConnection(self.tr) | ||||