|
|
@@ -0,0 +1,111 @@ |
|
|
|
from twisted.trial import unittest |
|
|
|
from twisted.test import proto_helpers |
|
|
|
from noise.connection import NoiseConnection, Keypair |
|
|
|
from twisted.internet.protocol import Factory |
|
|
|
# XXX - shouldn't need to access the underlying primitives, but that's what |
|
|
|
# noiseprotocol module requires. |
|
|
|
from cryptography.hazmat.primitives.asymmetric import x448 |
|
|
|
from cryptography.hazmat.primitives import serialization |
|
|
|
|
|
|
|
# Notes: |
|
|
|
# Using XK, so that the connecting party's identity is hidden and that the |
|
|
|
# server's party's key is known. |
|
|
|
|
|
|
|
import twisted.internet.protocol |
|
|
|
|
|
|
|
def genkeypair(): |
|
|
|
'''Generates a keypair, and returns a tuple of (public, private). |
|
|
|
They are encoded as raw bytes, and sutible for use w/ Noise.''' |
|
|
|
|
|
|
|
key = x448.X448PrivateKey.generate() |
|
|
|
|
|
|
|
enc = serialization.Encoding.Raw |
|
|
|
pubformat = serialization.PublicFormat.Raw |
|
|
|
privformat = serialization.PrivateFormat.Raw |
|
|
|
encalgo = serialization.NoEncryption() |
|
|
|
|
|
|
|
pub = key.public_key().public_bytes(encoding=enc, format=pubformat) |
|
|
|
priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) |
|
|
|
|
|
|
|
return pub, priv |
|
|
|
|
|
|
|
class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol): |
|
|
|
def connectionMade(self): |
|
|
|
# Initialize Noise |
|
|
|
noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
self.noise = noise |
|
|
|
noise.set_as_responder() |
|
|
|
noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key) |
|
|
|
|
|
|
|
# Start Handshake |
|
|
|
noise.start_handshake() |
|
|
|
|
|
|
|
def dataReceived(self, data): |
|
|
|
if not self.noise.handshake_finished: |
|
|
|
self.noise.read_message(data) |
|
|
|
if not self.noise.handshake_finished: |
|
|
|
self.transport.write(self.noise.write_message()) |
|
|
|
else: |
|
|
|
r = self.noise.decrypt(data) |
|
|
|
# echo it |
|
|
|
self.transport.write(self.noise.encrypt(r)) |
|
|
|
|
|
|
|
class TwistedNoiseServerFactory(Factory): |
|
|
|
protocol = TwistedNoiseServerProtocol |
|
|
|
|
|
|
|
def __init__(self, server_key): |
|
|
|
self.server_key = server_key |
|
|
|
|
|
|
|
class TNServerTest(unittest.TestCase): |
|
|
|
def setUp(self): |
|
|
|
self.server_key_pair = genkeypair() |
|
|
|
|
|
|
|
factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1]) |
|
|
|
self.proto = factory.buildProtocol(('127.0.0.1', 0)) |
|
|
|
self.tr = proto_helpers.StringTransport() |
|
|
|
self.proto.makeConnection(self.tr) |
|
|
|
|
|
|
|
self.client_key_pair = genkeypair() |
|
|
|
|
|
|
|
def test_testprotocol(self): |
|
|
|
# Create client |
|
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
proto.set_as_initiator() |
|
|
|
|
|
|
|
# Setup required keys |
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) |
|
|
|
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) |
|
|
|
|
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) |
|
|
|
proto.start_handshake() |
|
|
|
|
|
|
|
# Send first message |
|
|
|
message = proto.write_message() |
|
|
|
self.proto.dataReceived(message) |
|
|
|
|
|
|
|
# Get response |
|
|
|
resp = self.tr.value() |
|
|
|
self.tr.clear() |
|
|
|
|
|
|
|
# And process it |
|
|
|
proto.read_message(resp) |
|
|
|
|
|
|
|
# Send second message |
|
|
|
message = proto.write_message() |
|
|
|
self.proto.dataReceived(message) |
|
|
|
|
|
|
|
# Finish handshake |
|
|
|
self.assertTrue(proto.handshake_finished) |
|
|
|
|
|
|
|
ptmsg = b'this is a test message' |
|
|
|
encmsg = proto.encrypt(ptmsg) |
|
|
|
|
|
|
|
self.proto.dataReceived(encmsg) |
|
|
|
|
|
|
|
# Get echo |
|
|
|
resp = self.tr.value() |
|
|
|
self.tr.clear() |
|
|
|
|
|
|
|
ptresp = proto.decrypt(resp) |
|
|
|
|
|
|
|
self.assertEqual(ptresp, ptmsg) |