Browse Source

minor rework, and add some documentation...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
29ae6fec65
1 changed files with 27 additions and 6 deletions
  1. +27
    -6
      twistednoise.py

+ 27
- 6
twistednoise.py View File

@@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
# Initialize Noise # Initialize Noise
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
self.noise = noise self.noise = noise
noise.set_as_responder()
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
if self.mode == 'resp':
noise.set_as_responder()

noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)


# Start Handshake # Start Handshake
noise.start_handshake() noise.start_handshake()


def encData(self, data): def encData(self, data):
'''Receive plain text data, encrypt it, and send it down the
wire.'''

self.transport.write(self.noise.encrypt(data)) self.transport.write(self.noise.encrypt(data))


def dataReceived(self, data): def dataReceived(self, data):
'''Receive encrypted data, and write it to the endpoint that
was connected via the plaintextConnected method.'''

if not self.noise.handshake_finished: if not self.noise.handshake_finished:
self.noise.read_message(data) self.noise.read_message(data)
if not self.noise.handshake_finished: if not self.noise.handshake_finished:
@@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):


self.endpoint.transport.write(r) self.endpoint.transport.write(r)


def handshakeFinished(self):
def handshakeFinished(self): # pragma: no cover
'''This function is called when the handshake has been
completed. This is used to start data flowing, and to
do any necessary connection work.'''

raise NotImplementedError raise NotImplementedError


def plaintextConnected(self, endpoint): def plaintextConnected(self, endpoint):
'''Connect the plain text endpoint to the factory. All the
decrypted data will be written to this protocol,
(specifically, it's transport).'''

self.endpoint = endpoint self.endpoint = endpoint
self.transport.resumeProducing() self.transport.resumeProducing()


class TwistedNoiseServerProtocol(TwistedNoiseProtocol): class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
mode = 'resp'

def handshakeFinished(self): def handshakeFinished(self):
self.transport.pauseProducing() self.transport.pauseProducing()


@@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
epdef = ep.connect(ClientProxyFactory(self)) epdef = ep.connect(ClientProxyFactory(self))
epdef.addCallback(self.plaintextConnected) epdef.addCallback(self.plaintextConnected)


class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
mode = 'init'

class ClientProxyProtocol(twisted.internet.protocol.Protocol): class ClientProxyProtocol(twisted.internet.protocol.Protocol):
def dataReceived(self, data): def dataReceived(self, data):
self.factory.noiseproto.encData(data) self.factory.noiseproto.encData(data)
@@ -144,8 +165,8 @@ class ClientProxyFactory(Factory):
class TwistedNoiseServerFactory(Factory): class TwistedNoiseServerFactory(Factory):
protocol = TwistedNoiseServerProtocol protocol = TwistedNoiseServerProtocol


def __init__(self, server_key, endpoint):
self.server_key = server_key
def __init__(self, priv_key, endpoint):
self.priv_key = priv_key
self.endpoint = endpoint self.endpoint = endpoint


class TNServerTest(unittest.TestCase): class TNServerTest(unittest.TestCase):
@@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase):
self.listenportobj = lpobj self.listenportobj = lpobj
self.endpoint = 'unix:path=%s' % sockpath self.endpoint = 'unix:path=%s' % sockpath


factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)
self.proto = factory.buildProtocol(None) self.proto = factory.buildProtocol(None)
self.tr = proto_helpers.StringTransport() self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr) self.proto.makeConnection(self.tr)


Loading…
Cancel
Save