from noise.connection import NoiseConnection, Keypair
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import hashes
from twistednoise import genkeypair
from cryptography.hazmat.backends import default_backend

import asyncio
import functools
import os.path
import shutil
import socket
import tempfile
import threading
import unittest

_backend = default_backend()


def _makeunix(path):
    '''Make a properly formed unix path socket string.'''

    return 'unix:%s' % path


def _parsesockstr(sockstr):
    '''Split a socket string into (proto, remainder).

    Only the first colon separates the two, so the remainder may itself
    contain colons, e.g. 'unix:/tmp/a:b' -> ('unix', '/tmp/a:b').'''

    proto, rem = sockstr.split(':', 1)

    return proto, rem


def _checkproto(proto):
    '''Raise ValueError unless proto is a supported socket protocol.

    Only 'unix' is implemented by connectsockstr/listensockstr.'''

    if proto != 'unix':
        raise ValueError('unsupported socket protocol: %r' % (proto,))


async def connectsockstr(sockstr):
    '''Open a connection to the socket described by sockstr.

    sockstr uses the same format as listensockstr (currently only
    'unix:/path').  Returns the (StreamReader, StreamWriter) pair.
    Raises ValueError if the protocol is not supported.'''

    proto, rem = _parsesockstr(sockstr)
    _checkproto(proto)

    reader, writer = await asyncio.open_unix_connection(rem)

    return reader, writer


async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a
            slash if it is used as a default parameter.

    NOTE: only the default-parameter form ('unix:/path') is currently
    implemented; everything after the first colon is used verbatim as
    the socket path.

    cb is the client-connected callback passed to the asyncio server.
    Returns the asyncio Server object.  Raises ValueError if the
    protocol is not supported.
    '''

    proto, rem = _parsesockstr(sockstr)
    _checkproto(proto)

    server = await asyncio.start_unix_server(cb, path=rem)

    return server


# !!python makemessagelengths.py
# Byte lengths of the three Noise_XK handshake messages, in order:
# initiator -> responder, responder -> initiator, initiator -> responder.
_handshakelens = \
    [72, 72, 88]


def _genciphfun(hash, ad):
    '''Generate the pair of functions used to obscure message lengths.

    A 32 byte AES key is derived from hash (the callers pass the Noise
    handshake hash) with HKDF-SHA256, using ad as the HKDF info so that
    each direction ('toresp'/'toinit') gets an independent key.

    Returns a tuple (encfun, decfun):

    encfun(data) -- data is an encrypted message (at least 16 bytes);
        returns the two byte big-endian length prefix to send before it.
    decfun(data) -- data is the two byte prefix followed by at least the
        first 16 bytes of the encrypted message; returns the total length
        of the encrypted message.

    The mask XORed into the length is taken from the AES-ECB encryption
    of the first 16 bytes of the encrypted message, so the receiver can
    recompute it after reading just the prefix and one block.

    NOTE(review): the mask is reduced with '& 0xff', so only the low
    byte of the length is obscured and the high byte is sent in the
    clear.  Changing this alters the wire format, so any peer
    implementation must be updated at the same time.'''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    # ECB is stateless per block, so one encryptor can be shared by
    # encfun and decfun as long as each call feeds whole 16 byte blocks.
    enctor = cipher.encryptor()

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)
        encbytes = enctor.update(data[:16])
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

        return (val ^ mask).to_bytes(length=2, byteorder='big')

    def decfun(data):
        # data is the two byte length prefix followed by the first
        # 16 bytes of the encrypted message; returns the total
        # length of the encrypted message.
        val = int.from_bytes(data[:2], byteorder='big')
        encbytes = enctor.update(data[2:2 + 16])
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

        return val ^ mask

    return encfun, decfun


async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
    '''Run a Noise_XK_448_ChaChaPoly_SHA256 session over rdrwrr and
    forward plain text between it and the connection from ptpair.

    mode -- 'init' to act as the handshake initiator, 'resp' to act as
        the responder.
    rdrwrr -- (StreamReader, StreamWriter) pair carrying the encrypted
        side of the session.
    ptpair -- awaitable resolving to the (reader, writer) pair of the
        plain text side; it is awaited only after the handshake is done.
    priv_key -- our static private key bytes.
    pub_key -- the peer's static public key bytes, if known (in the XK
        pattern the initiator must know the responder's key).

    On the encrypted side every message is sent as a two byte masked
    length (see _genciphfun) followed by the Noise ciphertext.

    Returns ['dec', 'enc'] once both directions have reached EOF.
    Raises ValueError for an unknown mode and RuntimeError if the
    handshake does not complete.'''

    if mode not in ('init', 'resp'):
        raise ValueError('mode must be "init" or "resp", not %r' % (mode,))

    rdr, wrr = rdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
    if pub_key is not None:
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            pub_key)

    if mode == 'resp':
        proto.set_as_responder()
    else:
        proto.set_as_initiator()

    proto.start_handshake()

    # Three message XK handshake; message sizes are fixed, so read exactly.
    if mode == 'resp':
        proto.read_message(await rdr.readexactly(_handshakelens[0]))
        wrr.write(proto.write_message())
        proto.read_message(await rdr.readexactly(_handshakelens[2]))
    else:
        wrr.write(proto.write_message())
        proto.read_message(await rdr.readexactly(_handshakelens[1]))
        wrr.write(proto.write_message())

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    # generate the keys for lengths; the info string names the
    # direction of travel so each side encodes with the key the
    # other side decodes with.
    hshash = proto.get_handshake_hash()
    if mode == 'resp':
        _, declenfun = _genciphfun(hshash, b'toresp')
        enclenfun, _ = _genciphfun(hshash, b'toinit')
    else:
        enclenfun, _ = _genciphfun(hshash, b'toresp')
        _, declenfun = _genciphfun(hshash, b'toinit')

    reader, writer = await ptpair

    async def decses():
        # encrypted side -> plain text side
        try:
            while True:
                # The length mask depends on the first 16 ciphertext
                # bytes, so read the prefix plus one block up front.
                try:
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError:
                    if rdr.at_eof():
                        return 'dec'
                    # Not a clean EOF; don't fall through with an
                    # unbound or stale msg.
                    raise

                tlen = declenfun(msg)
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg

                writer.write(proto.decrypt(tmsg))
                await writer.drain()
        finally:
            writer.write_eof()

    async def encses():
        # plain text side -> encrypted side
        try:
            while True:
                # largest message: the ciphertext (plain text plus the
                # 16 byte tag) must fit the two byte length field.
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)

                await wrr.drain()
        finally:
            wrr.write_eof()

    return await asyncio.gather(decses(), encses())


# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout
def async_test(f):
    '''Decorate the coroutine function f so it can be used as a plain
    unittest test method.

    Each call runs f in a fresh event loop via asyncio.run and fails
    with asyncio.TimeoutError if it takes longer than 2 seconds.'''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def runner():
            # timeout after 2 seconds
            await asyncio.wait_for(f(*args, **kwargs), 2)

        asyncio.run(runner())

    return wrapper


class Tests_misc(unittest.TestCase):
    @async_test
    async def test_listensockstr(self):
        tmpdir = os.path.realpath(tempfile.mkdtemp())
        try:
            sockstr = _makeunix(os.path.join(tmpdir, 'sock'))

            async def echo(reader, writer):
                writer.write(await reader.readexactly(5))
                await writer.drain()
                writer.close()

            server = await listensockstr(sockstr, echo)
            try:
                reader, writer = await connectsockstr(sockstr)
                writer.write(b'hello')
                self.assertEqual(await reader.readexactly(5), b'hello')
                writer.close()
            finally:
                server.close()

            # only unix sockets are supported
            with self.assertRaises(ValueError):
                await listensockstr('tcp:1234', echo)
            with self.assertRaises(ValueError):
                await connectsockstr('tcp:1234')
        finally:
            shutil.rmtree(tmpdir)

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))


def _asyncsockpair():
    '''Create a pair of sockets that are bound to each other.
    The function will return a tuple of two coroutines that, when
    awaited, each return a (reader, writer) pair for one end.'''

    socka, sockb = socket.socketpair()

    return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)


class Tests(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs; indexed below as [0] public, [1] private
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_server(self):
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #     (rdr, wrr) NoiseForward (reader, writer) ->
        #     servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpair = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                (rdr, wrr), ptpair,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
            b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(),
            b'toinit')

        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        ptreader, ptwriter = ptsock[0]
        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # XXX - how to sync?
        await asyncio.sleep(.1)

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(rptmsg, ptmsg)

        # shut down sending
        writer.write_eof()

        # so pt reader should be shut down
        await ptreader.read(1)
        self.assertTrue(ptreader.at_eof())

        # shut down pt
        ptwriter.write_eof()

        # make sure the enc reader is eof
        await reader.read(1)
        self.assertTrue(reader.at_eof())

        await event.wait()

        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

        # release the listening sockets
        ssock.close()
        lsock.close()

    @async_test
    async def test_serverclient(self):
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        clsapair = await clssockapair
        clsbpair = await clssockbpair
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clsapair, ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clsbpair, ptssockapair,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)