|
- from noise.connection import NoiseConnection, Keypair
- from cryptography.hazmat.primitives.kdf.hkdf import HKDF
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives import hashes
- from twistednoise import genkeypair
- from cryptography.hazmat.backends import default_backend
-
- import asyncio
- import os.path
- import shutil
- import socket
- import tempfile
- import threading
- import unittest
-
- _backend = default_backend()
-
def _makefut(obj):
    '''Return a future that is already completed with obj as its result.

    The future is created on the currently running event loop, so
    this must be called from code running inside that loop.'''

    done = asyncio.get_running_loop().create_future()
    done.set_result(obj)
    return done
-
def _makeunix(path):
    '''Build the socket string that names the unix domain socket
    at path, i.e. path prefixed with the 'unix:' protocol tag.'''

    return f'unix:{path}'
-
def _parsesockstr(sockstr):
    '''Split a socket string at its first colon.

    Returns the tuple (proto, rem), where proto is the text before
    the first ':' and rem is everything after it (later colons are
    left in rem).  Raises ValueError when sockstr has no colon.'''

    # Unpacking into exactly two names is what rejects a string
    # that lacks the separator.
    pieces = sockstr.split(':', maxsplit=1)
    proto, rem = pieces

    return (proto, rem)
-
async def connectsockstr(sockstr):
    '''Open a stream connection to the endpoint named by sockstr.

    sockstr has the form 'proto:rest'.  Only the 'unix' protocol is
    supported; the remainder is used as the path of the unix domain
    socket to connect to.

    Returns the (reader, writer) pair produced by
    asyncio.open_unix_connection.

    Raises ValueError if sockstr has no ':' separator or names a
    protocol other than 'unix'.'''

    proto, rem = _parsesockstr(sockstr)

    # Previously the protocol was parsed and then ignored, so any
    # 'foo:...' string was silently treated as a unix socket path.
    if proto != 'unix':
        raise ValueError('unsupported protocol: %r' % (proto,))

    return await asyncio.open_unix_connection(rem)
-
async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a
            slash if it is used as a default parameter.

    NOTE: only the default-parameter form is handled at present;
    everything after the first ':' is passed verbatim as the path
    and no 'param=value' parsing is performed.

    cb is the client-connected callback handed to
    asyncio.start_unix_server; it is called with (reader, writer)
    for each accepted connection.

    Returns the server object from asyncio.start_unix_server.

    Raises ValueError if sockstr has no ':' separator or names a
    protocol other than 'unix'.
    '''

    proto, rem = _parsesockstr(sockstr)

    # Previously the protocol was parsed and then ignored, so any
    # 'foo:...' string was silently bound as a unix socket path.
    if proto != 'unix':
        raise ValueError('unsupported protocol: %r' % (proto,))

    return await asyncio.start_unix_server(cb, path=rem)
-
# Byte lengths of the three handshake messages, in the order they
# cross the wire: initiator -> responder, responder -> initiator,
# initiator -> responder.
# !!python makemessagelengths.py
_handshakelens = [72, 72, 88]
-
def _genciphfun(hash, ad):
    '''Create the functions that obscure and recover message lengths.

    A 32 byte AES key is derived from hash with HKDF-SHA256, using
    ad as the HKDF info value, so different ad values give
    independent keys.  The first 16 bytes of each encrypted message
    are run through AES-ECB with that key and part of the output is
    XORed into the message length.

    Returns the tuple (encfun, decfun):

    encfun(data) takes the encrypted message and returns the two
    byte, big endian, obscured length to send ahead of it.  data
    must be at least 16 bytes long; a length that does not fit in
    two bytes raises OverflowError.

    decfun(data) takes the two length bytes followed by at least
    the first 16 bytes of the encrypted message and returns the
    length of the encrypted message.

    Both functions raise ValueError when fewer than 16 message
    bytes are available.

    NOTE: the mask is reduced with 0xff, so only the low order byte
    of the length is actually obscured.  This is kept as is because
    both peers must compute the same mask; changing it would change
    the wire format.
    '''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    enctor = cipher.encryptor()

    def lenmask(block):
        # The encryptor context is shared and buffers partial
        # blocks: a short block would produce no output (a zero
        # mask) and then be prepended to the next call's input,
        # corrupting every later mask.  Insist on a full block.
        if len(block) != 16:
            raise ValueError(
                'need 16 bytes of message to derive the length mask')

        encbytes = enctor.update(block)

        return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)

        return (val ^ lenmask(data[:16])).to_bytes(length=2,
            byteorder='big')

    def decfun(data):
        # takes off the data and returns the total
        # length
        val = int.from_bytes(data[:2], byteorder='big')

        return val ^ lenmask(data[2:2 + 16])

    return encfun, decfun
-
async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
    '''Run a Noise_XK handshake and then forward data in both directions.

    mode is 'resp' to act as the handshake responder or 'init' to
    act as the initiator.

    rdrwrr is an awaitable yielding the (reader, writer) pair of the
    encrypted connection.  It is awaited first and the handshake is
    run over it.

    ptpair is an awaitable yielding the (reader, writer) pair of the
    plaintext connection.  It is awaited only after the handshake
    has completed.

    priv_key is the private key bytes for the local static keypair.
    pub_key, when not None, is the public key bytes of the remote
    static key.

    After the handshake, data read from the plaintext reader is
    encrypted and written to the encrypted writer, each message
    preceded by its two byte obscured length (see _genciphfun).
    Messages read from the encrypted reader are decrypted and
    written to the plaintext writer.  When one direction sees EOF,
    EOF is written to that direction's writer.

    Returns the result of gathering both directions, which is
    [ 'dec', 'enc' ] when both ended on EOF.

    Raises RuntimeError if the handshake did not finish, which
    includes the case of mode being neither 'resp' nor 'init'.
    '''

    rdr, wrr = await rdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
    if pub_key is not None:
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            pub_key)

    if mode == 'resp':
        proto.set_as_responder()

        proto.start_handshake()

        proto.read_message(await rdr.readexactly(_handshakelens[0]))

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[2]))
    elif mode == 'init':
        proto.set_as_initiator()

        proto.start_handshake()

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[1]))

        wrr.write(proto.write_message())

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    # generate the keys for lengths, one key per direction so
    # that each side decodes with the key its peer encodes with
    if mode == 'resp':
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
    elif mode == 'init':
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

    reader, writer = await ptpair

    async def decses():
        # encrypted side -> plaintext side
        try:
            while True:
                try:
                    # length bytes plus the block used to unmask them
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError:
                    if rdr.at_eof():
                        return 'dec'

                    # Not EOF: do not carry on with an unset or
                    # stale msg, report the failed read instead.
                    raise

                tlen = declenfun(msg)
                # 16 bytes of the message were read along with
                # the length, fetch the remainder
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg
                writer.write(proto.decrypt(tmsg))
                await writer.drain()
        finally:
            writer.write_eof()

    async def encses():
        # plaintext side -> encrypted side
        try:
            while True:
                # largest message
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)
                await wrr.drain()
        finally:
            wrr.write_eof()

    return await asyncio.gather(decses(), encses())
-
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout
def async_test(f):
    '''Decorator that lets a coroutine function be used as a
    regular, synchronous test method.

    Calling the returned wrapper runs f(*args, **kwargs) to
    completion on a fresh event loop and returns None.  If f does
    not finish within 2 seconds, asyncio.TimeoutError is raised.
    Any exception raised by f propagates to the caller.

    If f returns something that is not awaitable, it is simply
    called and its result is discarded.
    '''

    import functools
    import inspect

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def runner():
            # asyncio.coroutine() used to do this adaptation, but
            # it no longer exists in current Python versions.
            result = f(*args, **kwargs)
            if inspect.isawaitable(result):
                await result

        # timeout after 2 seconds
        asyncio.run(asyncio.wait_for(runner(), 2))

    return wrapper
-
- class Tests_misc(unittest.TestCase):
- def test_listensockstr(self):
- # XXX write test
- pass
-
- def test_genciphfun(self):
- enc, dec = _genciphfun(b'0' * 32, b'foobar')
-
- msg = b'this is a bunch of data'
-
- tb = enc(msg)
-
- self.assertEqual(len(msg), dec(tb + msg))
-
- for i in [ 20, 1384, 64000, 23839, 65535 ]:
- msg = os.urandom(i)
- self.assertEqual(len(msg), dec(enc(msg) + msg))
-
def _asyncsockpair():
    '''Make two sockets that are connected to each other.

    Returns a tuple of two coroutines, one per socket.  Awaiting
    a coroutine gives the reader/writer pair for its socket.'''

    ends = socket.socketpair()

    return tuple(asyncio.open_connection(sock=end) for end in ends)
-
class Tests(unittest.TestCase):
    '''Tests that run NoiseForwarder over real sockets.

    Each test gets a fresh temporary directory for its unix
    sockets and freshly generated server and client key pairs.
    The key pairs are indexed as [0] for the public key and [1]
    for the private key.'''

    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_server(self):
        '''Drive a responder NoiseForwarder with a hand written
        initiator and check data and EOF pass in both directions.'''

        # Test is plumbed:
        #  (reader, writer) -> servsock ->
        #  (rdr, wrr) NoiseForward (reader, writer) ->
        #  servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        # set once the forwarder has connected to the pt listener
        ptaccepted = asyncio.Event()

        def ptsockaccept(reader, writer):
            ptsock.append((reader, writer))
            ptaccepted.set()

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpair = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), ptpair,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        lsock = ssock = None
        try:
            # Bind to pt listener
            lsock = await listensockstr(pttarg, ptsockaccept)

            # Setup server listener
            ssock = await listensockstr(servarg, runnf)

            # Connect to server
            reader, writer = await connectsockstr(servarg)

            # Create client
            proto = NoiseConnection.from_name(
                b'Noise_XK_448_ChaChaPoly_SHA256')
            proto.set_as_initiator()

            # Setup required keys
            proto.set_keypair_from_private_bytes(Keypair.STATIC,
                self.client_key_pair[1])
            proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                self.server_key_pair[0])

            proto.start_handshake()

            # Send first message
            message = proto.write_message()
            self.assertEqual(len(message), _handshakelens[0])
            writer.write(message)

            # Get response
            respmsg = await reader.readexactly(_handshakelens[1])
            proto.read_message(respmsg)

            # Send final reply
            message = proto.write_message()
            writer.write(message)

            # Make sure handshake has completed
            self.assertTrue(proto.handshake_finished)

            # generate the keys for lengths
            enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
                b'toresp')
            _, declenfun = _genciphfun(proto.get_handshake_hash(),
                b'toinit')

            # write a test message
            ptmsg = b'this is a test message that should be a little in length'
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # wait until the forwarder has connected to the pt
            # listener, instead of sleeping and hoping it has
            await ptaccepted.wait()

            ptreader, ptwriter = ptsock[0]
            # read the test message
            rptmsg = await ptreader.readexactly(len(ptmsg))

            self.assertEqual(rptmsg, ptmsg)

            # write a different message
            ptmsg = os.urandom(2843)
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # read the test message, readexactly waits for the
            # data so no additional syncing is needed
            rptmsg = await ptreader.readexactly(len(ptmsg))

            self.assertEqual(rptmsg, ptmsg)

            # now try the other way
            ptmsg = os.urandom(912)
            ptwriter.write(ptmsg)

            # find out how much we need to read
            encmsg = await reader.readexactly(2 + 16)
            tlen = declenfun(encmsg)

            # read the rest of the message
            rencmsg = await reader.readexactly(tlen - 16)
            tmsg = encmsg[2:] + rencmsg
            rptmsg = proto.decrypt(tmsg)

            self.assertEqual(rptmsg, ptmsg)

            # shut down sending
            writer.write_eof()

            # so pt reader should be shut down
            self.assertEqual(b'', await ptreader.read(1))
            self.assertTrue(ptreader.at_eof())

            # shut down pt
            ptwriter.write_eof()

            # make sure the enc reader is eof
            self.assertEqual(b'', await reader.read(1))
            self.assertTrue(reader.at_eof())

            await event.wait()

            self.assertEqual(nfs[0], [ 'dec', 'enc' ])
        finally:
            # stop listening so the sockets do not outlive the test
            for srv in (ssock, lsock):
                if srv is not None:
                    srv.close()

    @async_test
    async def test_serverclient(self):
        '''Connect an initiator and a responder NoiseForwarder to
        each other and check data and EOF pass end to end.'''

        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, ptssockapair,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())
        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)