Browse Source

minor rework, and add some documentation...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
29ae6fec65
1 changed files with 27 additions and 6 deletions
  1. +27
    -6
      twistednoise.py

+ 27
- 6
twistednoise.py View File

@@ -94,16 +94,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
# Initialize Noise
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
self.noise = noise
noise.set_as_responder()
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
if self.mode == 'resp':
noise.set_as_responder()

noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)

# Start Handshake
noise.start_handshake()

def encData(self, data):
'''Receive plain text data, encrypt it, and send it down the
wire.'''

self.transport.write(self.noise.encrypt(data))

def dataReceived(self, data):
'''Receive encrypted data, and write it to the endpoint that
was connected via the plaintextConnected method.'''

if not self.noise.handshake_finished:
self.noise.read_message(data)
if not self.noise.handshake_finished:
@@ -115,14 +123,24 @@ class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):

self.endpoint.transport.write(r)

def handshakeFinished(self):
def handshakeFinished(self): # pragma: no cover
'''This function is called when the handshake has been
completed. This is used to start data flowing, and to
do any necessary connection work.'''

raise NotImplementedError

def plaintextConnected(self, endpoint):
'''Connect the plain text endpoint to the factory. All the
decrypted data will be written to this protocol,
(specifically, it's transport).'''

self.endpoint = endpoint
self.transport.resumeProducing()

class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
mode = 'resp'

def handshakeFinished(self):
self.transport.pauseProducing()

@@ -131,6 +149,9 @@ class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
epdef = ep.connect(ClientProxyFactory(self))
epdef.addCallback(self.plaintextConnected)

class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
mode = 'init'

class ClientProxyProtocol(twisted.internet.protocol.Protocol):
def dataReceived(self, data):
self.factory.noiseproto.encData(data)
@@ -144,8 +165,8 @@ class ClientProxyFactory(Factory):
class TwistedNoiseServerFactory(Factory):
protocol = TwistedNoiseServerProtocol

def __init__(self, server_key, endpoint):
self.server_key = server_key
def __init__(self, priv_key, endpoint):
self.priv_key = priv_key
self.endpoint = endpoint

class TNServerTest(unittest.TestCase):
@@ -182,7 +203,7 @@ class TNServerTest(unittest.TestCase):
self.listenportobj = lpobj
self.endpoint = 'unix:path=%s' % sockpath

factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
factory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)
self.proto = factory.buildProtocol(None)
self.tr = proto_helpers.StringTransport()
self.proto.makeConnection(self.tr)


Loading…
Cancel
Save