|
- from noise.connection import NoiseConnection, Keypair
- from cryptography.hazmat.primitives.kdf.hkdf import HKDF
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives import hashes
- from twistednoise import genkeypair
- from cryptography.hazmat.backends import default_backend
-
- import asyncio
- import os.path
- import shutil
- import socket
- import tempfile
- import threading
- import unittest
-
# Module-wide cryptography backend, passed as backend= to the HKDF and
# Cipher objects built in _genciphfun.
_backend = default_backend()
-
- def _makeunix(path):
- '''Make a properly formed unix path socket string.'''
-
- return 'unix:%s' % path
-
- def _parsesockstr(sockstr):
- proto, rem = sockstr.split(':', 1)
-
- return proto, rem
-
async def connectsockstr(sockstr):
    '''Open a client connection to the endpoint described by ``sockstr``.

    Only the 'unix' protocol is supported; everything after 'unix:' is
    used verbatim as the socket path.  Returns the
    ``(StreamReader, StreamWriter)`` pair produced by
    asyncio.open_unix_connection.

    Raises ValueError if the protocol is not 'unix' (previously any
    protocol was silently treated as a unix socket path).
    '''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in socket string %r'
            % (proto, sockstr))

    return await asyncio.open_unix_connection(rem)
-
async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'. This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore. The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket. The path MUST start w/ a
            slash if it is used as a default parameter.

    NOTE: param=value parsing is not implemented yet; everything after
    'unix:' is currently handed to asyncio.start_unix_server as the path.

    ``cb`` is passed through as the client-connected callback of
    asyncio.start_unix_server, and the resulting server object is
    returned.  Raises ValueError for any protocol other than 'unix'.
    '''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in socket string %r'
            % (proto, sockstr))

    return await asyncio.start_unix_server(cb, path=rem)
-
# Expected byte lengths of the three handshake messages, indexed by message
# number: [0] and [2] are read by NoiseForwarder, [1] by the test client in
# Tests.test_server.  The marker line below suggests these values are
# generated by makemessagelengths.py (presumably via an editor filter) —
# regenerate rather than hand-edit; TODO confirm the workflow.
# !!python makemessagelengths.py
_handshakelens = \
[72, 72, 88]
-
def _genciphfun(hash, ad):
    '''Return an ``(encfun, decfun)`` pair that masks 2-byte frame lengths.

    A 32-byte AES key is derived with HKDF-SHA256 from ``hash`` (the Noise
    handshake hash at the call sites in this file), a fixed salt, and
    ``ad`` as the HKDF info, so different ``ad`` values (b'toresp',
    b'toinit') yield independent keys.  The mask for a frame comes from
    the AES-ECB encryption of the first 16 bytes of the frame body, so a
    receiver can recompute it after reading only 2 + 16 bytes.

    encfun(data) -> bytes
        ``data`` is the frame body; returns its masked length as 2
        big-endian bytes.  Raises OverflowError if len(data) > 65535.
    decfun(data) -> int
        ``data`` is the 2 masked length bytes followed by (at least) the
        first 16 body bytes; returns the unmasked body length.

    Both raise ValueError when fewer than 16 body bytes are supplied.
    '''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    # One ECB encryptor is shared by both closures.  ECB has no chaining
    # state, but update() buffers any partial block and prepends it to the
    # next call, so it must only ever be fed whole 16-byte blocks.
    enctor = cipher.encryptor()

    def _mask(block):
        # Compute the length mask from one 16-byte block of frame body.
        if len(block) < 16:
            # A short block would give an empty (zero) mask AND leave
            # bytes queued in enctor, silently corrupting later masks.
            raise ValueError('need 16 bytes of frame body to compute the '
                'length mask, got %d' % len(block))
        encbytes = enctor.update(block)
        # NOTE(review): only the low 8 bits of the length are masked; the
        # high byte travels in the clear.  Widening this to 0xffff would
        # change the wire format, so it is deliberately left as-is.
        return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

    def encfun(data):
        # Returns the two bytes for length
        return (len(data) ^ _mask(data[:16])).to_bytes(length=2,
            byteorder='big')

    def decfun(data):
        # Takes the 2 + 16 byte frame header and returns the body length
        val = int.from_bytes(data[:2], byteorder='big')
        return val ^ _mask(data[2:2 + 16])

    return encfun, decfun
-
async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
    '''Terminate a Noise tunnel on ``rdrwrr`` and forward its plaintext.

    Performs a Noise_XK_448_ChaChaPoly_SHA256 handshake as the responder
    on the ``(reader, writer)`` pair ``rdrwrr``, using the static private
    key ``priv_key``.  It then connects to ``ptsockstr`` (via
    connectsockstr) and copies data in both directions:

    * encrypted -> plaintext: each frame is a 2-byte masked length
      (_genciphfun, ad b'toresp') followed by the ChaChaPoly ciphertext;
      the decrypted payload is written to the plaintext connection.
    * plaintext -> encrypted: each read of up to 65535 - 16 bytes is
      encrypted and framed with the b'toinit' length mask.

    EOF on either side is passed on to the other side as a half-close,
    and the function returns once both directions have finished.

    ``mode`` is accepted but not consulted: this always acts as the
    responder.

    Returns the set of forwarding tasks that completed in the final wait
    (their results are 'dec' / 'enc').  If either direction raises (e.g.
    a truncated frame or a failed decrypt), the other direction is
    cancelled and the exception is re-raised.  The plaintext connection
    is closed before returning.
    '''

    rdr, wrr = rdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
    proto.set_as_responder()
    proto.start_handshake()

    # Three-message handshake: read from initiator, reply, read final.
    proto.read_message(await rdr.readexactly(_handshakelens[0]))
    wrr.write(proto.write_message())
    await wrr.drain()
    proto.read_message(await rdr.readexactly(_handshakelens[2]))

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    # generate the keys for lengths
    _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
    enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')

    reader, writer = await connectsockstr(ptsockstr)

    async def decses():
        # encrypted side -> plaintext side
        while True:
            try:
                msg = await rdr.readexactly(2 + 16)
            except asyncio.IncompleteReadError:
                # readexactly only fails this way at EOF (possibly with a
                # partial header); stop rather than reuse a stale `msg`.
                if writer.can_write_eof():
                    writer.write_eof()
                return 'dec'

            tlen = declenfun(msg)
            if tlen < 16:
                # Every ciphertext carries at least the 16-byte tag.
                raise ValueError('invalid frame length %d' % tlen)
            rmsg = await rdr.readexactly(tlen - 16)
            writer.write(proto.decrypt(msg[2:] + rmsg))
            await writer.drain()

    async def encses():
        # plaintext side -> encrypted side
        while True:
            # largest message, leaving room for the 16-byte tag
            ptmsg = await reader.read(65535 - 16)
            if not ptmsg:
                # read() returns b'' only at EOF: stop instead of spinning
                # on empty messages, and pass the half-close on.
                if wrr.can_write_eof():
                    wrr.write_eof()
                return 'enc'
            encmsg = proto.encrypt(ptmsg)
            wrr.write(enclenfun(encmsg))
            wrr.write(encmsg)
            await wrr.drain()

    # asyncio.wait needs tasks/futures (bare coroutines were removed in 3.11).
    tasks = [asyncio.ensure_future(decses()), asyncio.ensure_future(encses())]
    try:
        done, pending = await asyncio.wait(tasks,
            return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            task.result()   # re-raise a failure from that direction

        # Both directions may already be done; asyncio.wait rejects an
        # empty set, so only wait when something is still running.
        if pending:
            done, pending = await asyncio.wait(pending)
            for task in done:
                task.result()
    finally:
        for task in tasks:
            task.cancel()   # no-op for tasks that already finished
        writer.close()

    return done
-
class TestListenSocket(unittest.TestCase):
    def test_listensockstr(self):
        '''A listensockstr server must accept a connectsockstr client.'''

        async def roundtrip(sockstr):
            received = []

            async def accept(rdr, wrr):
                received.append(await rdr.readexactly(5))
                wrr.write(b'world')
                await wrr.drain()
                wrr.close()

            server = await listensockstr(sockstr, accept)
            try:
                rdr, wrr = await connectsockstr(sockstr)
                try:
                    wrr.write(b'hello')
                    reply = await rdr.readexactly(5)
                finally:
                    wrr.close()
            finally:
                server.close()

            return received, reply

        tmpdir = tempfile.mkdtemp()
        # Private loop so this test neither depends on nor disturbs the
        # process-wide default event loop.
        loop = asyncio.new_event_loop()
        try:
            sockstr = _makeunix(os.path.join(tmpdir, 'sock'))
            received, reply = loop.run_until_complete(
                asyncio.wait_for(roundtrip(sockstr), 2))
        finally:
            loop.close()
            shutil.rmtree(tmpdir)

        self.assertEqual(received, [b'hello'])
        self.assertEqual(reply, b'world')
-
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
def async_test(f):
    '''Decorator that runs the coroutine function ``f`` synchronously.

    Lets an ``async def`` test method be driven by unittest's synchronous
    runner.  Each call runs on a freshly created event loop that is closed
    afterwards, so tests do not share a loop.  The coroutine is cancelled
    with asyncio.TimeoutError if it runs longer than 2 seconds.  The
    wrapper returns None, as unittest expects of test methods.
    '''

    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # f is already an async def; asyncio.coroutine was unnecessary
        # (and is removed in Python 3.11).
        loop = asyncio.new_event_loop()
        try:
            # timeout after 2 seconds
            loop.run_until_complete(asyncio.wait_for(f(*args, **kwargs), 2))
        finally:
            loop.close()

    return wrapper
-
class Tests_misc(unittest.TestCase):
    def test_genciphfun(self):
        '''A matching enc/dec pair must round-trip the body length.'''

        encode, decode = _genciphfun(b'0' * 32, b'foobar')

        sample = b'this is a bunch of data'
        header = encode(sample)
        self.assertEqual(len(sample), decode(header + sample))

        # assorted sizes, up to the 2-byte maximum
        for size in (20, 1384, 64000, 23839, 65535):
            payload = os.urandom(size)
            self.assertEqual(len(payload), decode(encode(payload) + payload))
-
class Tests(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    async def _clienthandshake(self, reader, writer):
        '''Run the Noise_XK initiator handshake over (reader, writer).

        Returns ``(proto, enclenfun, declenfun)``: the NoiseConnection with
        the handshake finished, the length-mask encoder for client->server
        frames, and the decoder for server->client frames.
        '''

        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        return proto, enclenfun, declenfun

    @async_test
    async def test_server(self):
        # Path that the server will sit on
        servarg = _makeunix(os.path.join(self.tempdir, 'servsock'))

        # Path that the server will send pt data to
        pttarg = _makeunix(os.path.join(self.tempdir, 'servptsock'))

        # Setup pt target listener.  The forwarder only connects here after
        # its handshake, so signal the connection instead of sleeping.
        ptsock = []
        ptconnected = asyncio.Event()

        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptconnected.set()

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            a = await NoiseForwarder('resp', self.server_key_pair[1],
                (rdr, wrr), pttarg)
            nfs.append(a)
            event.set()

        servers = []
        try:
            # Bind to pt listener
            servers.append(await listensockstr(pttarg, ptsockaccept))

            # Setup server listener
            servers.append(await listensockstr(servarg, runnf))

            # Connect to server and complete the handshake
            reader, writer = await connectsockstr(servarg)
            proto, enclenfun, declenfun = await self._clienthandshake(reader,
                writer)

            # write a test message
            ptmsg = b'this is a test message that should be a little in length'
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # wait for the forwarder's plaintext connection to exist
            await ptconnected.wait()
            ptreader, ptwriter = ptsock[0]

            # read the test message
            rptmsg = await ptreader.readexactly(len(ptmsg))
            self.assertEqual(rptmsg, ptmsg)

            # write a different message
            ptmsg = os.urandom(2843)
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # read the test message (readexactly waits for the data)
            rptmsg = await ptreader.readexactly(len(ptmsg))
            self.assertEqual(rptmsg, ptmsg)

            # now try the other way
            ptmsg = os.urandom(912)
            ptwriter.write(ptmsg)

            # find out how much we need to read
            encmsg = await reader.readexactly(2 + 16)
            tlen = declenfun(encmsg)

            # read the rest of the message
            rencmsg = await reader.readexactly(tlen - 16)
            tmsg = encmsg[2:] + rencmsg
            rptmsg = proto.decrypt(tmsg)
            self.assertEqual(rptmsg, ptmsg)

            # shut everything down: half-close both ends so that both
            # forwarding directions can see EOF and NoiseForwarder returns
            writer.write_eof()
            ptwriter.write_eof()

            await event.wait()
            self.assertEqual(len(nfs), 1)
        finally:
            for server in servers:
                server.close()
|