- from twisted.trial import unittest
- from twisted.test import proto_helpers
- from noise.connection import NoiseConnection, Keypair
- from twisted.internet.protocol import Factory
- # XXX - shouldn't need to access the underlying primitives, but that's what
- # noiseprotocol module requires.
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives import serialization
# Notes:
# The Noise XK pattern is used so that the connecting party's identity is
# hidden, and the server's static public key is known to the client in advance.
- import twisted.internet.protocol
def genkeypair():
    """Generate a fresh X448 keypair for use with Noise.

    Returns a ``(public, private)`` tuple. Both halves are raw, unencrypted
    bytes, suitable for the noiseprotocol ``set_keypair_from_*_bytes`` calls.
    """
    private_key = x448.X448PrivateKey.generate()
    raw = serialization.Encoding.Raw
    public_bytes = private_key.public_key().public_bytes(
        encoding=raw,
        format=serialization.PublicFormat.Raw,
    )
    private_bytes = private_key.private_bytes(
        encoding=raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
    return public_bytes, private_bytes
class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
    """Noise XK responder that echoes back every post-handshake message.

    NOTE(review): each ``dataReceived`` chunk is passed straight to
    ``read_message``/``decrypt``, i.e. it is assumed to be exactly one
    complete Noise message. No length framing is done, so this relies on the
    transport delivering messages unsplit and unmerged (as TNServerTest does
    by calling ``dataReceived`` with whole messages).
    """

    def connectionMade(self):
        # Configure this side as the XK responder using the factory's static
        # private key, then begin the handshake.
        self.noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        self.noise.set_as_responder()
        self.noise.set_keypair_from_private_bytes(
            Keypair.STATIC, self.factory.server_key)
        self.noise.start_handshake()

    def dataReceived(self, data):
        session = self.noise
        if session.handshake_finished:
            # Transport phase: decrypt, then send the same plaintext back
            # re-encrypted.
            plaintext = session.decrypt(data)
            self.transport.write(session.encrypt(plaintext))
            return
        # Handshake phase: consume the peer's message and, if the handshake
        # is still incomplete, answer with our next handshake message.
        session.read_message(data)
        if not session.handshake_finished:
            self.transport.write(session.write_message())
class TwistedNoiseServerFactory(Factory):
    """Builds TwistedNoiseServerProtocol instances that share one static key."""

    protocol = TwistedNoiseServerProtocol

    def __init__(self, server_key):
        # Raw private bytes of the server's static keypair. Each protocol
        # reads this via self.factory.server_key in connectionMade.
        self.server_key = server_key
class TNServerTest(unittest.TestCase):
    """End-to-end test of TwistedNoiseServerProtocol over an in-memory transport.

    A noiseprotocol initiator plays the client. The server protocol is driven
    directly through ``dataReceived``, and its output is read back from a
    StringTransport.
    """

    def setUp(self):
        # Static keypairs are (public, private) tuples of raw bytes.
        self.server_key_pair = genkeypair()
        factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1])
        self.proto = factory.buildProtocol(('', 0))
        self.tr = proto_helpers.StringTransport()
        # makeConnection triggers connectionMade, which sets up the responder.
        self.proto.makeConnection(self.tr)
        self.client_key_pair = genkeypair()

    def test_testprotocol(self):
        # Client side: XK initiator that already knows the server's static
        # public key. (The STATIC key was previously set twice; once is enough.)
        client = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        client.set_as_initiator()
        client.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        client.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
        client.start_handshake()

        # Handshake message 1: client -> server.
        self.proto.dataReceived(client.write_message())

        # Handshake message 2: server -> client.
        resp = self.tr.value()
        self.tr.clear()
        client.read_message(resp)

        # Handshake message 3: client -> server; this completes the handshake
        # on both sides.
        self.proto.dataReceived(client.write_message())
        self.assertTrue(client.handshake_finished)
        self.assertTrue(self.proto.noise.handshake_finished)
        # Once its handshake is finished, the server must not reply to the
        # final handshake message.
        self.assertEqual(self.tr.value(), b'')

        # Transport phase: the server echoes the message back, re-encrypted.
        ptmsg = b'this is a test message'
        self.proto.dataReceived(client.encrypt(ptmsg))
        resp = self.tr.value()
        self.tr.clear()
        ptresp = client.decrypt(resp)
        self.assertEqual(ptresp, ptmsg)