from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair

#import tracemalloc; tracemalloc.start()

import argparse
import asyncio
import base64
import errno
import os.path
import shutil
import socket
import sys
import tempfile
import time
import threading
import unittest

_backend = default_backend()


def loadprivkey(fname):
    '''Load an unencrypted PEM encoded private key from the file fname.

    Returns the key object produced by cryptography's
    load_pem_private_key.'''

    with open(fname, encoding='ascii') as fp:
        data = fp.read().encode('ascii')

    key = load_pem_private_key(data, password=None,
        backend=default_backend())

    return key


def loadprivkeyraw(fname):
    '''Load the private key stored in fname (see loadprivkey) and
    return it as raw, unencrypted bytes, the form passed to Noise.'''

    key = loadprivkey(fname)

    enc = serialization.Encoding.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    return key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)


def loadpubkeyraw(fname):
    '''Load a public key file as written by the genkey command and
    return the raw public key bytes.

    The first line of the file must have the form:
        ntun-x448 <urlsafe base64 encoded key>

    Raises RuntimeError if the key type is not ntun-x448.'''

    with open(fname, encoding='ascii') as fp:
        lines = fp.readlines()

    # XXX - only the first line is examined, any others are ignored.
    keytype, keyvalue = lines[0].split()

    if keytype != 'ntun-x448':
        raise RuntimeError('unsupported key type %r in %s' %
            (keytype, fname))

    return base64.urlsafe_b64decode(keyvalue)


def genkeypair():
    '''Generates a keypair, and returns a tuple of (public, private).
    They are encoded as raw bytes, and suitable for use w/ Noise.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

    return pub, priv


def _makefut(obj):
    '''Return a future on the running loop that is already resolved
    to obj.  Must be called from within a running event loop.'''

    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    fut.set_result(obj)

    return fut


def _makeunix(path):
    '''Make a properly formed unix path socket string.'''

    return 'unix:%s' % path


def _parsesockstr(sockstr):
    '''Split a socket string into (proto, remainder) at the first
    colon.'''

    proto, rem = sockstr.split(':', 1)

    return proto, rem


async def connectsockstr(sockstr):
    '''Connect to the socket described by sockstr (see listensockstr
    for the format) and return the (reader, writer) stream pair.

    Raises ValueError if the protocol is not supported.'''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in %r' %
            (proto, sockstr))

    reader, writer = await asyncio.open_unix_connection(rem)

    return reader, writer


async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    NOTE: currently only the default parameter form is implemented;
    everything after the first colon is used as-is.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a
            slash if it is used as a default parameter.

    Raises ValueError if the protocol is not supported.
    '''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in %r' %
            (proto, sockstr))

    server = await asyncio.start_unix_server(cb, path=rem)

    return server


# !!python makemessagelengths.py
# Lengths of the three Noise XK handshake messages, in order.
_handshakelens = \
    [72, 72, 88]


def _genciphfun(hash, ad):
    '''Generate a pair of functions used to obscure the 2 byte length
    prefix that precedes each encrypted frame.

    A 32 byte key is derived from hash with HKDF-SHA256, using ad as
    the info parameter, and used as an AES-ECB key.  The mask for a
    frame comes from the AES encryption of the first 16 bytes of the
    encrypted message, so both sides can compute it.

    Returns a tuple (encfun, decfun):
        encfun(data) -> the 2 masked length bytes to send before data.
        decfun(data) -> the frame length, given data that starts with
            the 2 masked length bytes followed by at least 16 bytes
            of the encrypted message.
    '''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    enctor = cipher.encryptor()

    # NOTE(review): '& 0xff' keeps only the low byte of the mask, so
    # the high byte of the length is sent unobscured.  It is left as
    # is because changing it alters the wire format between peers.

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)
        encbytes = enctor.update(data[:16])
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

        return (val ^ mask).to_bytes(length=2, byteorder='big')

    def decfun(data):
        # takes off the data and returns the total
        # length
        val = int.from_bytes(data[:2], byteorder='big')
        encbytes = enctor.update(data[2:2 + 16])
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

        return val ^ mask

    return encfun, decfun


async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
    '''Forward data between a Noise encrypted stream and a plaintext
    stream, in both directions, until both sides reach EOF.

    mode -- 'init' to act as the Noise initiator, 'resp' to act as
        the responder.
    rdrwrr -- awaitable resolving to the (reader, writer) pair of the
        encrypted side.
    ptpair -- awaitable resolving to the (reader, writer) pair of the
        plaintext side; it is only awaited after the handshake.
    priv_key -- raw X448 private key bytes of this side.
    pub_key -- raw X448 public key of the remote side.  The initiator
        needs it, as the XK pattern pre-shares the responder's static
        key.

    Each encrypted frame is sent as 2 masked length bytes (see
    _genciphfun) followed by the Noise encrypted message.

    Returns ['dec', 'enc'] once both directions have finished.
    Raises RuntimeError if the handshake does not complete.'''

    rdr, wrr = await rdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)

    if pub_key is not None:
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            pub_key)

    if mode == 'resp':
        proto.set_as_responder()

        proto.start_handshake()

        proto.read_message(await rdr.readexactly(_handshakelens[0]))

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[2]))
    elif mode == 'init':
        proto.set_as_initiator()

        proto.start_handshake()

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[1]))

        wrr.write(proto.write_message())

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    # generate the keys for lengths
    # XXX - get_handshake_hash is probably not the best option, but
    # this is only to obscure lengths, it is not required to be secure
    # as the underlying NoiseProtocol securely validates everything.
    # It is marginally useful as writing patterns likely expose the
    # true length.  Adding padding could marginally help w/ this.
    if mode == 'resp':
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
    elif mode == 'init':
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

    reader, writer = await ptpair

    async def decses():
        # encrypted side -> plaintext side
        try:
            while True:
                try:
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError:
                    if rdr.at_eof():
                        return 'dec'

                    # Not a clean EOF; never continue with a stale
                    # or unbound msg.
                    raise

                tlen = declenfun(msg)
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg
                writer.write(proto.decrypt(tmsg))
                await writer.drain()
        finally:
            try:
                writer.write_eof()
            except OSError as e:
                # The peer may already have disconnected.
                if e.errno != errno.ENOTCONN:
                    raise

    async def encses():
        # plaintext side -> encrypted side
        try:
            while True:
                # largest message, leaving room for the 16 byte tag
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)
                await wrr.drain()
        finally:
            wrr.write_eof()

    return await asyncio.gather(decses(), encses())

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
# This makes it easier to figure out what "froze".
def async_test(f):
    '''Decorator that runs the coroutine test method f to completion
    on a fresh event loop, failing it with a timeout after 4 seconds.

    If the test is cancelled (e.g. by the timeout), the traceback of
    where it was waiting is printed.'''

    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def tbcapture():
            try:
                return await f(*args, **kwargs)
            except asyncio.CancelledError:
                # if we are going to be cancelled, print out a tb
                import traceback
                traceback.print_exc()
                raise

        # asyncio.run creates, and closes, a new loop for each test.
        # asyncio.get_event_loop() outside of a running loop is
        # deprecated.
        # timeout after 4 seconds
        asyncio.run(asyncio.wait_for(tbcapture(), 4))

    return wrapper


class Tests_misc(unittest.TestCase):
    def test_listensockstr(self):
        # XXX write test
        pass

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))


def cmd_client(args):
    '''Run the client side of the tunnel.

    Listens on args.clientlisten for plaintext connections.  For each
    one, it connects to args.clienttarget and forwards the data as the
    Noise initiator, using the private key in args.clientkey and the
    server public key in args.servkey.  Runs forever.'''

    privkey = loadprivkeyraw(args.clientkey)
    pubkey = loadpubkeyraw(args.servkey)

    async def runnf(rdr, wrr):
        encpair = asyncio.create_task(connectsockstr(args.clienttarget))

        await NoiseForwarder('init', encpair, _makefut((rdr, wrr)),
            priv_key=privkey, pub_key=pubkey)

    async def serve():
        # Setup client listener
        server = await listensockstr(args.clientlisten, runnf)
        await server.serve_forever()

    asyncio.run(serve())


def cmd_server(args):
    '''Run the server side of the tunnel.

    Listens on args.servlisten for Noise encrypted connections.  For
    each one, it connects to args.servtarget and forwards the data as
    the Noise responder, using the private key in args.servkey.  Runs
    forever.'''

    privkey = loadprivkeyraw(args.servkey)

    # NOTE(review): the authorized client keys given with -c (args.c)
    # are parsed by main() but never consulted.  NoiseForwarder is not
    # given any client key to check.  TODO: enforce args.c.

    async def runnf(rdr, wrr):
        ptpair = asyncio.create_task(connectsockstr(args.servtarget))

        await NoiseForwarder('resp', _makefut((rdr, wrr)), ptpair,
            priv_key=privkey)

    async def serve():
        # Setup server listener
        server = await listensockstr(args.servlisten, runnf)
        await server.serve_forever()

    asyncio.run(serve())


def cmd_genkey(args):
    '''Generate a new X448 keypair.

    The public key is written to args.fname + '.pub' as a single line
    'ntun-x448 <urlsafe base64 key>'.  The unencrypted PKCS8 PEM
    private key is written to args.fname.  Neither file is
    overwritten: if either already exists, a message is printed to
    stderr and the process exits with status 2.'''

    key = x448.X448PrivateKey.generate()

    # public key part
    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)

    fname = args.fname + '.pub'
    try:
        with open(fname, 'x', encoding='ascii') as fp:
            print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'),
                file=fp)
    except FileExistsError:
        print('failed to create %s, file exists.' % fname,
            file=sys.stderr)
        sys.exit(2)

    # private key part
    enc = serialization.Encoding.PEM
    privformat = serialization.PrivateFormat.PKCS8
    encalgo = serialization.NoEncryption()
    try:
        with open(args.fname, 'x', encoding='ascii') as fp:
            fp.write(key.private_bytes(encoding=enc, format=privformat,
                encryption_algorithm=encalgo).decode('ascii'))
    except FileExistsError:
        # Remove the .pub just created so it is not left next to a
        # private key it does not match.
        os.unlink(fname)
        print('failed to create %s, file exists.' % args.fname,
            file=sys.stderr)
        sys.exit(2)


def main():
    '''Command line entry point.  Exits with status 5 when no
    subcommand is given.'''

    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(title='subcommands',
        description='valid subcommands', help='additional help')

    parser_gk = subparsers.add_parser('genkey', help='generate keys')
    parser_gk.add_argument('fname', type=str, help='file name for the key')
    parser_gk.set_defaults(func=cmd_genkey)

    parser_serv = subparsers.add_parser('server', help='run a server')
    parser_serv.add_argument('-c', action='append', type=str,
        help='file of authorized client keys, or a .pub file')
    parser_serv.add_argument('servkey', type=str,
        help='file name for the server key')
    parser_serv.add_argument('servlisten', type=str,
        help='Connection that the server listens on')
    parser_serv.add_argument('servtarget', type=str,
        help='Connection that the server connects to')
    parser_serv.set_defaults(func=cmd_server)

    parser_client = subparsers.add_parser('client', help='run a client')
    parser_client.add_argument('clientkey', type=str,
        help='file name for the client private key')
    parser_client.add_argument('servkey', type=str,
        help='file name for the server public key')
    parser_client.add_argument('clientlisten', type=str,
        help='Connection that the client listens on')
    parser_client.add_argument('clienttarget', type=str,
        help='Connection that the client connects to')
    parser_client.set_defaults(func=cmd_client)

    args = parser.parse_args()

    try:
        fun = args.func
    except AttributeError:
        parser.print_usage()
        sys.exit(5)

    fun(args)


if __name__ == '__main__': # pragma: no cover
    main()


def _asyncsockpair():
    '''Create a pair of sockets that are bound to each other.
    The function will return a tuple of two coroutine's, that
    each, when await'ed upon, will return the reader/writer pair.'''

    socka, sockb = socket.socketpair()

    return asyncio.open_connection(sock=socka), \
        asyncio.open_connection(sock=sockb)


async def _awaitfile(fname):
    '''Poll every 10ms until fname exists, then return True.'''

    while not os.path.exists(fname):
        await asyncio.sleep(.01)

    return True


class TestMain(unittest.TestCase):
    def setUp(self):
        # Remember the cwd so tearDown can restore it.  Otherwise the
        # process is left inside a directory that tearDown deletes.
        self.origcwd = os.getcwd()

        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        os.chdir(self.tempdir)

    def tearDown(self):
        #print('td:', time.time())
        os.chdir(self.origcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_noargs(self):
        proc = await self.run_with_args()

        await proc.communicate()

        # XXX - not checking error message

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 5)

    def run_with_args(self, *args, pipes=True):
        '''Return a coroutine that starts this file as a subprocess
        with args.  stdout/stderr are piped unless pipes is False.'''

        kwargs = {}
        if pipes:
            kwargs.update(dict(
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE))

        return asyncio.create_subprocess_exec(sys.executable,
            # XXX - figure out how to add coverage data on these runs
            #'-m', 'coverage', 'run', '-p',
            __file__, *args, **kwargs)

    async def genkey(self, name):
        proc = await self.run_with_args('genkey', name, pipes=False)

        await proc.wait()

        self.assertEqual(proc.returncode, 0)

    @async_test
    async def test_loadpubkey(self):
        keypath = os.path.join(self.tempdir, 'loadpubkeytest')

        await self.genkey(keypath)

        privkey = loadprivkey(keypath)

        enc = serialization.Encoding.Raw
        pubformat = serialization.PublicFormat.Raw
        pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
            format=pubformat)

        pubkey = loadpubkeyraw(keypath + '.pub')

        self.assertEqual(pubkeybytes, pubkey)

        privrawkey = loadprivkeyraw(keypath)

        enc = serialization.Encoding.Raw
        privformat = serialization.PrivateFormat.Raw
        encalgo = serialization.NoEncryption()
        rprivrawkey = privkey.private_bytes(encoding=enc,
            format=privformat, encryption_algorithm=encalgo)

        self.assertEqual(rprivrawkey, privrawkey)

    @async_test
    async def test_end2end(self):
        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)

        await asyncio.sleep(.1)

        #import pdb; pdb.set_trace()

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket server listen to
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, ptsockaccept)

        # The server and client serve forever, so they must be
        # stopped explicitly or they outlive the test.
        procs = []
        try:
            # Startup the server
            server = await self.run_with_args('server',
                '-c', clientkeypath + '.pub',
                servkeypath, incservstr, servtargstr,
                pipes=False)
            procs.append(server)

            # Startup the client
            client = await self.run_with_args('client',
                clientkeypath, servkeypath + '.pub', ptclientstr,
                incservstr, pipes=False)
            procs.append(client)

            # wait for server target to be created
            await _awaitfile(servtargpath)

            # wait for server to start
            await _awaitfile(incservpath)

            # wait for client to start
            await _awaitfile(ptclientpath)

            await asyncio.sleep(.1)

            # Connect to the client
            reader, writer = await connectsockstr(ptclientstr)

            # send a message
            ptmsg = b'this is a message for testing'
            writer.write(ptmsg)

            # make sure that we got the connection
            await ptsockevent.wait()

            # get the connection
            endrdr, endwrr = ptsock[0]

            # make sure we can read back what we sent
            self.assertEqual(ptmsg,
                await endrdr.readexactly(len(ptmsg)))

            # test some additional messages
            for i in [ 129, 1287, 28792, 129872 ]:
                # in on direction
                msg = os.urandom(i)
                writer.write(msg)
                self.assertEqual(msg, await endrdr.readexactly(len(msg)))

                # and the other
                endwrr.write(msg)
                self.assertEqual(msg, await reader.readexactly(len(msg)))
        finally:
            for proc in procs:
                if proc.returncode is None:
                    try:
                        proc.terminate()
                    except ProcessLookupError:
                        pass
                    await proc.wait()

            lsock.close()

    @async_test
    async def test_genkey(self):
        # that it can generate a key
        proc = await self.run_with_args('genkey', 'somefile')

        await proc.communicate()

        self.assertEqual(proc.returncode, 0)

        with open('somefile.pub', encoding='ascii') as fp:
            lines = fp.readlines()
            self.assertEqual(len(lines), 1)

            keytype, keyvalue = lines[0].split()

            self.assertEqual(keytype, 'ntun-x448')
            key = x448.X448PublicKey.from_public_bytes(
                base64.urlsafe_b64decode(keyvalue))

        key = loadprivkey('somefile')
        self.assertIsInstance(key, x448.X448PrivateKey)

        # that a second call fails
        proc = await self.run_with_args('genkey', 'somefile')

        # communicate() also waits for the process to exit
        stdoutdata, stderrdata = await proc.communicate()

        self.assertFalse(stdoutdata)
        self.assertEqual(
            b'failed to create somefile.pub, file exists.\n', stderrdata)

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 2)


class TestNoiseFowarder(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_server(self):
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #   (rdr, wrr) NoiseForward (reader, writer) ->
        #   servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpair = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
                ptpair, priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(
            b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        ptreader, ptwriter = ptsock[0]
        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(rptmsg, ptmsg)

        # shut down sending
        writer.write_eof()

        # so pt reader should be shut down
        self.assertEqual(b'', await ptreader.read(1))
        self.assertTrue(ptreader.at_eof())

        # shut down pt
        ptwriter.write_eof()

        # make sure the enc reader is eof
        self.assertEqual(b'', await reader.read(1))
        self.assertTrue(reader.at_eof())

        await event.wait()

        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

        # stop listening
        ssock.close()
        lsock.close()

    @async_test
    async def test_serverclient(self):
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, ptssockapair,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())

        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)