From 01b664048eabbbc4c63ed910f5be8255ed39d664 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Tue, 22 Oct 2019 11:43:50 -0700 Subject: [PATCH] minor rearchitecting to support client.. --- twistednoise.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/twistednoise.py b/twistednoise.py index e41ba29..9f4e5b2 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -76,7 +76,7 @@ def genkeypair(): return pub, priv -class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): +class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): '''This class acts as a Noise Protocol responder. The factory that creates this Protocol is required to have the properties server_key and endpoint. @@ -109,21 +109,28 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): if not self.noise.handshake_finished: self.transport.write(self.noise.write_message()) if self.noise.handshake_finished: - self.transport.pauseProducing() - - # start the connection to the endpoint - ep = endpoints.clientFromString(reactor, self.factory.endpoint) - epdef = ep.connect(ClientProxyFactory(self)) - epdef.addCallback(self.proxyConnected) + self.handshakeFinished() else: r = self.noise.decrypt(data) self.endpoint.transport.write(r) - def proxyConnected(self, endpoint): + def handshakeFinished(self): + raise NotImplementedError + + def plaintextConnected(self, endpoint): self.endpoint = endpoint self.transport.resumeProducing() +class TwistedNoiseServerProtocol(TwistedNoiseProtocol): + def handshakeFinished(self): + self.transport.pauseProducing() + + # start the connection to the endpoint + ep = endpoints.clientFromString(reactor, self.factory.endpoint) + epdef = ep.connect(ClientProxyFactory(self)) + epdef.addCallback(self.plaintextConnected) + class ClientProxyProtocol(twisted.internet.protocol.Protocol): def dataReceived(self, data): self.factory.noiseproto.encData(data) @@ -189,7 +196,7 @@ class TNServerTest(unittest.TestCase): self.tempdir = None @defer.inlineCallbacks - def test_testprotocol(self): + def test_testserver(self): # # How this test is plumbed: #