'''Forward a plaintext stream over a Noise_XK_448_ChaChaPoly_SHA256
session, with obfuscated frame lengths, plus a genkey command line tool
and the unit tests.'''

import argparse
import asyncio
import base64
import functools
import os.path
import shutil
import socket
import sys
import tempfile
import threading
import tracemalloc
import unittest

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair

# NOTE(review): presumably enabled so that ResourceWarning / unawaited
# coroutine warnings include the allocation traceback; confirm it is still
# wanted outside of debugging.
tracemalloc.start()

_backend = default_backend()


def genkeypair():
    '''Generate an X448 keypair, and return a tuple of (public, private).

    Both keys are encoded as raw bytes, suitable for use with Noise
    (Keypair.STATIC / Keypair.REMOTE_STATIC).'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

    return pub, priv


def _makefut(obj):
    '''Return a future, on the running loop, already resolved to obj.'''

    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    fut.set_result(obj)

    return fut


def _makeunix(path):
    '''Make a properly formed unix path socket string.'''

    return 'unix:%s' % path


def _parsesockstr(sockstr):
    '''Split sockstr at the first colon into (proto, remainder).'''

    proto, rem = sockstr.split(':', 1)

    return proto, rem


def _unixpath(sockstr):
    '''Return the unix domain socket path named by sockstr.

    Raises ValueError if the proto of sockstr is not unix, the only
    protocol currently supported.'''

    proto, rem = _parsesockstr(sockstr)
    if proto != 'unix':
        raise ValueError('unsupported protocol %r in %r' % (proto, sockstr))

    return rem


async def connectsockstr(sockstr):
    '''Connect to sockstr (see listensockstr for the format), and
    return the (reader, writer) pair.

    Raises ValueError for an unsupported protocol.'''

    path = _unixpath(sockstr)
    reader, writer = await asyncio.open_unix_connection(path)

    return reader, writer


async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a slash
            if it is used as a default parameter.

    NOTE: only the default parameter form ('unix:/path') is currently
    implemented; everything after 'unix:' is used verbatim as the path.
    Any other proto raises ValueError.
    '''

    path = _unixpath(sockstr)
    server = await asyncio.start_unix_server(cb, path=path)

    return server

# !!python makemessagelengths.py
_handshakelens = \
    [72, 72, 88]


def _genciphfun(hash, ad):
    '''Build the (encfun, decfun) pair used to obfuscate frame lengths.

    A 32 byte AES key is derived from hash (the Noise handshake hash)
    with HKDF-SHA256, using ad as the HKDF info, so each direction
    (b'toresp' / b'toinit') gets its own key.  A frame's length is
    XOR'd with a mask computed by AES-ECB encrypting the first 16 bytes
    of the frame's ciphertext.

    encfun(data) returns the 2 byte, big endian, obfuscated length of
    data.

    decfun(data) takes the 2 byte obfuscated length followed by (at
    least) the first 16 bytes of the ciphertext, and returns the
    ciphertext length.

    Both raise ValueError when given fewer than 16 bytes of ciphertext:
    the ECB encryptor buffers partial blocks, so a short input would
    silently corrupt the mask of every later call.'''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=_backend)
    enctor = cipher.encryptor()

    def _mask(block):
        # Must be exactly one AES block, see the docstring.
        if len(block) != 16:
            raise ValueError('need at least 16 bytes of ciphertext, got %d'
                % len(block))

        encbytes = enctor.update(block)

        # NOTE(review): only the low byte of the 16 bit length is masked,
        # so the high byte goes over the wire in the clear.  Widening this
        # to 0xffff changes the wire format; confirm intent first.
        return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)

        return (val ^ _mask(data[:16])).to_bytes(length=2, byteorder='big')

    def decfun(data):
        # takes off the data and returns the total
        # length
        val = int.from_bytes(data[:2], byteorder='big')

        return val ^ _mask(data[2:2 + 16])

    return encfun, decfun

# Largest plaintext read per frame: adding the 16 byte AEAD tag makes the
# ciphertext at most 65535 bytes, the most the 2 byte length can encode.
_MAXPTREAD = 65535 - 16


async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
    '''Run a Noise_XK_448_ChaChaPoly_SHA256 session over rdrwrr and
    forward data between it and the plaintext stream ptpair.

    mode is 'init' (initiator) or 'resp' (responder).
    rdrwrr is an awaitable returning the (reader, writer) pair of the
    encrypted side.  ptpair is an awaitable returning the (reader,
    writer) pair of the plaintext side; it is only awaited once the
    handshake has completed.
    priv_key is our raw static private key; pub_key is the peer's raw
    static public key (the responder's key, needed when mode is 'init').

    Each encrypted frame is the 2 byte obfuscated length (see
    _genciphfun) followed by the ciphertext.

    Returns ['dec', 'enc'] once both directions have seen EOF; each
    direction writes EOF on its output stream when it finishes.

    Raises ValueError for an unknown mode, and asyncio.IncompleteReadError
    if the encrypted stream ends in the middle of a frame.'''

    if mode not in ('resp', 'init'):
        raise ValueError('mode must be "init" or "resp", not %r' % (mode,))

    rdr, wrr = await rdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
    if pub_key is not None:
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            pub_key)

    if mode == 'resp':
        proto.set_as_responder()

        proto.start_handshake()

        proto.read_message(await rdr.readexactly(_handshakelens[0]))

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[2]))
    else:
        proto.set_as_initiator()

        proto.start_handshake()

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[1]))

        wrr.write(proto.write_message())

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    # generate the keys for lengths: frames sent by the initiator use
    # 'toresp', frames sent by the responder use 'toinit'
    hshash = proto.get_handshake_hash()
    if mode == 'resp':
        _, declenfun = _genciphfun(hshash, b'toresp')
        enclenfun, _ = _genciphfun(hshash, b'toinit')
    else:
        enclenfun, _ = _genciphfun(hshash, b'toresp')
        _, declenfun = _genciphfun(hshash, b'toinit')

    reader, writer = await ptpair

    async def decses():
        # encrypted frames from rdr -> plaintext to writer
        try:
            while True:
                try:
                    # length header plus the first AES block of
                    # ciphertext, which decfun needs for the mask
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError as e:
                    if e.partial:
                        # stream ended mid-frame; report the
                        # truncation instead of treating it as EOF
                        raise

                    return 'dec'

                tlen = declenfun(msg)
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg

                writer.write(proto.decrypt(tmsg))

                await writer.drain()
        finally:
            writer.write_eof()

    async def encses():
        # plaintext from reader -> encrypted frames to wrr
        try:
            while True:
                # largest message
                ptmsg = await reader.read(_MAXPTREAD)

                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)

                await wrr.drain()
        finally:
            wrr.write_eof()

    return await asyncio.gather(decses(), encses())

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout, and to use asyncio.run, as
# asyncio.coroutine was removed in Python 3.11.
def async_test(f):
    '''Decorator that runs the coroutine function f to completion on a
    fresh event loop, raising asyncio.TimeoutError after 2 seconds.'''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # timeout after 2 seconds
        asyncio.run(asyncio.wait_for(f(*args, **kwargs), 2))

    return wrapper


class Tests_misc(unittest.TestCase):
    def test_listensockstr(self):
        # XXX write test
        pass

    @async_test
    async def test_sockstr_badproto(self):
        # only the unix protocol is supported
        with self.assertRaises(ValueError):
            await connectsockstr('tcp:127.0.0.1:1234')

        with self.assertRaises(ValueError):
            await listensockstr('tcp:127.0.0.1:1234', None)

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))

        # short input is rejected, and does not disturb later calls
        with self.assertRaises(ValueError):
            enc(b'short')

        with self.assertRaises(ValueError):
            dec(b'\x00\x00short')

        msg = os.urandom(100)
        self.assertEqual(len(msg), dec(enc(msg) + msg))


def cmd_genkey(args):
    '''Write a new X448 keypair: the private key (PKCS8 PEM, readable
    only by the owner) to args.fname, and the public key, as the line
    "ntun-x448 <urlsafe base64 raw key>", to args.fname + '.pub'.

    Exits with status 1, without leaving any new file behind, if either
    file already exists.'''

    key = x448.X448PrivateKey.generate()

    # public key part
    pub = key.public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw)
    pubdata = 'ntun-x448 %s\n' % \
        base64.urlsafe_b64encode(pub).decode('ascii')

    # private key part
    privdata = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption()).decode('ascii')

    # (file name, contents, permission bits before umask)
    files = [
        (args.fname + '.pub', pubdata, 0o666),
        (args.fname, privdata, 0o600),
    ]

    created = []
    try:
        for fname, data, perms in files:
            # O_EXCL: never overwrite an existing key file
            fd = os.open(fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL,
                perms)
            created.append(fname)
            with open(fd, 'w', encoding='ascii') as fp:
                fp.write(data)
    except FileExistsError as e:
        # don't leave half of a keypair behind
        for fname in created:
            os.unlink(fname)

        print('failed to create %s, file exists.' % e.filename,
            file=sys.stderr)
        sys.exit(1)


def main():
    '''Command line entry point.  Exits with status 5 when no
    subcommand is given.'''

    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(title='subcommands',
        description='valid subcommands', help='additional help')

    parser_gk = subparsers.add_parser('genkey', help='generate keys')

    parser_gk.add_argument('fname', type=str, help='file name for the key')

    parser_gk.set_defaults(func=cmd_genkey)

    args = parser.parse_args()

    try:
        fun = args.func
    except AttributeError:
        parser.print_usage()
        sys.exit(5)

    fun(args)


def _asyncsockpair():
    '''Create a pair of sockets that are bound to each other.
    The function will return a tuple of two coroutines, that each,
    when await'ed upon, will return the reader/writer pair.'''

    socka, sockb = socket.socketpair()

    return asyncio.open_connection(sock=socka), \
        asyncio.open_connection(sock=sockb)


class TestMain(unittest.TestCase):
    def setUp(self):
        # save the process state the tests change, for tearDown
        self.oldcwd = os.getcwd()
        self.oldargv = sys.argv

        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        os.chdir(self.tempdir)

    def tearDown(self):
        # leave the directory before removing it, otherwise later tests
        # run with a deleted working directory
        os.chdir(self.oldcwd)
        sys.argv = self.oldargv

        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    def test_noargs(self):
        sys.argv = [ 'prog' ]

        with self.assertRaises(SystemExit) as cm:
            main()

        # XXX - not checking error message

        # And that it exited w/ the correct code
        self.assertEqual(5, cm.exception.code)

    def test_genkey(self):
        # that it can generate a key
        sys.argv = [ 'prog', 'genkey', 'somefile' ]

        main()

        with open('somefile.pub', encoding='ascii') as fp:
            lines = fp.readlines()
            self.assertEqual(len(lines), 1)

            keytype, keyvalue = lines[0].split()

            self.assertEqual(keytype, 'ntun-x448')
            pubkey = x448.X448PublicKey.from_public_bytes(
                base64.urlsafe_b64decode(keyvalue))

        with open('somefile', encoding='ascii') as fp:
            data = fp.read().encode('ascii')
            key = load_pem_private_key(data, password=None,
                backend=default_backend())

            self.assertIsInstance(key, x448.X448PrivateKey)

        # that the private key is not readable by others
        if os.name == 'posix':
            self.assertEqual(os.stat('somefile').st_mode & 0o077, 0)

        # that a second call fails
        with self.assertRaises(SystemExit) as cm:
            main()

        # XXX - not checking error message

        # And that it exited w/ the correct code
        self.assertEqual(1, cm.exception.code)


class TestNoiseFowarder(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_badmode(self):
        with self.assertRaises(ValueError):
            await NoiseForwarder('bogus', None, None,
                priv_key=self.server_key_pair[1])

    @async_test
    async def test_server(self):
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #   (rdr, wrr) NoiseForward (reader, writer) ->
        #   servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpair = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), ptpair,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        ptreader, ptwriter = ptsock[0]
        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(rptmsg, ptmsg)

        # shut down sending
        writer.write_eof()

        # so pt reader should be shut down
        self.assertEqual(b'', await ptreader.read(1))
        self.assertTrue(ptreader.at_eof())

        # shut down pt
        ptwriter.write_eof()

        # make sure the enc reader is eof
        self.assertEqual(b'', await reader.read(1))
        self.assertTrue(reader.at_eof())

        await event.wait()

        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

        # stop accepting new connections
        ssock.close()
        lsock.close()

    @async_test
    async def test_serverclient(self):
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, ptssockapair,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())

        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)


if __name__ == '__main__': # pragma: no cover
    main()