diff --git a/twistednoise.py b/twistednoise.py index 9f4e5b2..6873479 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): # Initialize Noise noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') self.noise = noise - noise.set_as_responder() - noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key) + if self.mode == 'resp': + noise.set_as_responder() + + noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) # Start Handshake noise.start_handshake() def encData(self, data): + '''Receive plain text data, encrypt it, and send it down the + wire.''' + self.transport.write(self.noise.encrypt(data)) def dataReceived(self, data): + '''Receive encrypted data, and write it to the endpoint that + was connected via the plaintextConnected method.''' + if not self.noise.handshake_finished: self.noise.read_message(data) if not self.noise.handshake_finished: @@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): self.endpoint.transport.write(r) - def handshakeFinished(self): + def handshakeFinished(self): # pragma: no cover + '''This function is called when the handshake has been + completed. This is used to start data flowing, and to + do any necessary connection work.''' + raise NotImplementedError def plaintextConnected(self, endpoint): + '''Connect the plain text endpoint to the factory. All the + decrypted data will be written to this protocol, + (specifically, it's transport).''' + self.endpoint = endpoint self.transport.resumeProducing() class TwistedNoiseServerProtocol(TwistedNoiseProtocol): + mode = 'resp' + def handshakeFinished(self): self.transport.pauseProducing() @@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): epdef = ep.connect(ClientProxyFactory(self)) epdef.addCallback(self.plaintextConnected) +class TwistedNoiseClientProtocol(TwistedNoiseProtocol): + mode = 'init' + class ClientProxyProtocol(twisted.internet.protocol.Protocol): def dataReceived(self, data): self.factory.noiseproto.encData(data) @@ -144,8 +165,8 @@ class ClientProxyFactory(Factory): class TwistedNoiseServerFactory(Factory): protocol = TwistedNoiseServerProtocol - def __init__(self, server_key, endpoint): - self.server_key = server_key + def __init__(self, priv_key, endpoint): + self.priv_key = priv_key self.endpoint = endpoint class TNServerTest(unittest.TestCase): @@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase): self.listenportobj = lpobj self.endpoint = 'unix:path=%s' % sockpath - factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) + factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) self.proto = factory.buildProtocol(None) self.tr = proto_helpers.StringTransport() self.proto.makeConnection(self.tr)