|
|
@@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): |
|
|
|
# Initialize Noise |
|
|
|
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
self.noise = noise |
|
|
|
noise.set_as_responder() |
|
|
|
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key) |
|
|
|
if self.mode == 'resp': |
|
|
|
noise.set_as_responder() |
|
|
|
|
|
|
|
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key) |
|
|
|
|
|
|
|
# Start Handshake |
|
|
|
noise.start_handshake() |
|
|
|
|
|
|
|
def encData(self, data): |
|
|
|
'''Receive plain text data, encrypt it, and send it down the |
|
|
|
wire.''' |
|
|
|
|
|
|
|
self.transport.write(self.noise.encrypt(data)) |
|
|
|
|
|
|
|
def dataReceived(self, data): |
|
|
|
'''Receive encrypted data, and write it to the endpoint that |
|
|
|
was connected via the plaintextConnected method.''' |
|
|
|
|
|
|
|
if not self.noise.handshake_finished: |
|
|
|
self.noise.read_message(data) |
|
|
|
if not self.noise.handshake_finished: |
|
|
@@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol): |
|
|
|
|
|
|
|
self.endpoint.transport.write(r) |
|
|
|
|
|
|
|
def handshakeFinished(self): |
|
|
|
def handshakeFinished(self): # pragma: no cover |
|
|
|
'''This function is called when the handshake has been |
|
|
|
completed. This is used to start data flowing, and to |
|
|
|
do any necessary connection work.''' |
|
|
|
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def plaintextConnected(self, endpoint): |
|
|
|
'''Connect the plain text endpoint to the factory. All the |
|
|
|
decrypted data will be written to this protocol, |
|
|
|
(specifically, it's transport).''' |
|
|
|
|
|
|
|
self.endpoint = endpoint |
|
|
|
self.transport.resumeProducing() |
|
|
|
|
|
|
|
class TwistedNoiseServerProtocol(TwistedNoiseProtocol): |
|
|
|
mode = 'resp' |
|
|
|
|
|
|
|
def handshakeFinished(self): |
|
|
|
self.transport.pauseProducing() |
|
|
|
|
|
|
@@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol): |
|
|
|
epdef = ep.connect(ClientProxyFactory(self)) |
|
|
|
epdef.addCallback(self.plaintextConnected) |
|
|
|
|
|
|
|
class TwistedNoiseClientProtocol(TwistedNoiseProtocol): |
|
|
|
mode = 'init' |
|
|
|
|
|
|
|
class ClientProxyProtocol(twisted.internet.protocol.Protocol): |
|
|
|
def dataReceived(self, data): |
|
|
|
self.factory.noiseproto.encData(data) |
|
|
@@ -144,8 +165,8 @@ class ClientProxyFactory(Factory): |
|
|
|
class TwistedNoiseServerFactory(Factory): |
|
|
|
protocol = TwistedNoiseServerProtocol |
|
|
|
|
|
|
|
def __init__(self, server_key, endpoint): |
|
|
|
self.server_key = server_key |
|
|
|
def __init__(self, priv_key, endpoint): |
|
|
|
self.priv_key = priv_key |
|
|
|
self.endpoint = endpoint |
|
|
|
|
|
|
|
class TNServerTest(unittest.TestCase): |
|
|
@@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase): |
|
|
|
self.listenportobj = lpobj |
|
|
|
self.endpoint = 'unix:path=%s' % sockpath |
|
|
|
|
|
|
|
factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) |
|
|
|
factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint) |
|
|
|
self.proto = factory.buildProtocol(None) |
|
|
|
self.tr = proto_helpers.StringTransport() |
|
|
|
self.proto.makeConnection(self.tr) |
|
|
|