from twisted.trial import unittest
from twisted.test import proto_helpers
from noise.connection import NoiseConnection, Keypair
from twisted.internet.protocol import Factory
import twisted.internet.protocol

# XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives import serialization

# Notes:
# Using XK, so that the connecting party's identity is hidden and that the
# server party's key is known.


def genkeypair():
    '''Generates a keypair, and returns a tuple of (public, private).
    They are encoded as raw bytes, and suitable for use w/ Noise.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
                             encryption_algorithm=encalgo)

    return pub, priv


class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
    '''Responder side of a Noise_XK_448_ChaChaPoly_SHA256 session that
    echoes back, re-encrypted, whatever it decrypts once the handshake has
    completed.

    The static private key is read from self.factory.server_key.'''

    def connectionMade(self):
        '''Create the Noise state for this connection and start the
        handshake as the responder.'''

        # Initialize Noise
        noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise = noise
        noise.set_as_responder()

        noise.set_keypair_from_private_bytes(Keypair.STATIC,
                                             self.factory.server_key)

        # Start Handshake
        noise.start_handshake()

    def dataReceived(self, data):
        '''Feed data to the handshake while it is in progress; afterwards
        decrypt data and write the re-encrypted plaintext back.'''

        # NOTE(review): every call is treated as exactly one complete Noise
        # message.  A stream transport may split or coalesce writes, so
        # message framing probably needs to be added before this is used
        # on a real connection -- confirm.
        if not self.noise.handshake_finished:
            self.noise.read_message(data)

            # Reading the message may have completed the handshake; only
            # reply when there is still a handshake message left to send.
            if not self.noise.handshake_finished:
                self.transport.write(self.noise.write_message())
        else:
            r = self.noise.decrypt(data)

            # echo it
            self.transport.write(self.noise.encrypt(r))


class TwistedNoiseServerFactory(Factory):
    '''Factory for TwistedNoiseServerProtocol that holds the server's
    static private key (raw bytes) as self.server_key.'''

    protocol = TwistedNoiseServerProtocol

    def __init__(self, server_key):
        self.server_key = server_key


class TNServerTest(unittest.TestCase):
    '''Drives TwistedNoiseServerProtocol over a StringTransport using a
    NoiseConnection initiator as the client.'''

    def setUp(self):
        self.server_key_pair = genkeypair()

        factory = TwistedNoiseServerFactory(
            server_key=self.server_key_pair[1])
        self.proto = factory.buildProtocol(('127.0.0.1', 0))
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)

        self.client_key_pair = genkeypair()

    def test_testprotocol(self):
        '''Complete the handshake, then check that an encrypted message
        is echoed back and decrypts to the original plaintext.'''

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
                                             self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                                            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Finish handshake, on both ends
        self.assertTrue(proto.handshake_finished)
        self.assertTrue(self.proto.noise.handshake_finished)

        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)
        self.proto.dataReceived(encmsg)

        # Get echo
        resp = self.tr.value()
        self.tr.clear()

        ptresp = proto.decrypt(resp)

        self.assertEqual(ptresp, ptmsg)