From fa9a3d20dca4336edb2561aeccd7d5747840c8a0 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Thu, 24 Oct 2019 23:54:12 -0700 Subject: [PATCH] switch to using asyncio. It's so much cleaner than twisted. --- ntunnel.py | 230 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 202 insertions(+), 28 deletions(-) diff --git a/ntunnel.py b/ntunnel.py index 47d16f4..6a53135 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -1,6 +1,11 @@ from noise.connection import NoiseConnection, Keypair +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives import hashes from twistednoise import genkeypair +from cryptography.hazmat.backends import default_backend +import asyncio import os.path import shutil import socket @@ -8,23 +13,27 @@ import tempfile import threading import unittest +_backend = default_backend() + def _makeunix(path): '''Make a properly formed unix path socket string.''' return 'unix:%s' % path -def _acceptfun(s, fun): - while True: - sock = s.accept() +def _parsesockstr(sockstr): + proto, rem = sockstr.split(':', 1) + + return proto, rem + +async def connectsockstr(sockstr): + proto, rem = _parsesockstr(sockstr) - fun(*sock) + reader, writer = await asyncio.open_unix_connection(rem) -def listensocket(sockstr, fun): - '''Listen for connections on sockstr. When ever a connection - is accepted, the parameter fun is called with the socket and - the from address. The return will be a Thread object. Note - that fun MUST NOT block, as if it does, it will stop accepting - other connections. + return reader, writer + +async def listensockstr(sockstr, cb): + '''Wrapper for asyncio.start_x_server. The format of sockstr is: 'proto:param=value[,param2=value2]'. If the proto has a default parameter, the value can be used @@ -43,29 +52,123 @@ def listensocket(sockstr, fun): slash if it is used as a default parameter. ''' - proto, rem = sockstr.split(':', 1) + proto, rem = _parsesockstr(sockstr) - s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(rem) - s.listen(-1) + server = await asyncio.start_unix_server(cb, path=rem) - thr = threading.Thread(target=_acceptfun, name='accept thread: %s' % repr(sockstr), args=(s, fun)) - thr.setDaemon(True) + return server - thr.start() +# !!python makemessagelengths.py +_handshakelens = \ +[72, 72, 88] - return thr +def _genciphfun(hash, ad): + hkdf = HKDF(algorithm=hashes.SHA256(), length=32, + salt=b'asdoifjsldkjdsf', info=ad, backend=_backend) -class NoiseForwarder(object): - def __init__(self, mode, sock, ): - nf = NoiseForwarder('resp', self.server_key_pair[1], ssock, pttarg) - pass + key = hkdf.derive(hash) + cipher = Cipher(algorithms.AES(key), modes.ECB(), + backend=_backend) + enctor = cipher.encryptor() + + def encfun(data): + # Returns the two bytes for length + val = len(data) + encbytes = enctor.update(data[:16]) + mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff + + return (val ^ mask).to_bytes(length=2, byteorder='big') + + def decfun(data): + # takes off the data and returns the total + # length + val = int.from_bytes(data[:2], byteorder='big') + encbytes = enctor.update(data[2:2 + 16]) + mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff + + return val ^ mask + + return encfun, decfun + +async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): + rdr, wrr = rdrwrr + + proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') + + proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) + + proto.set_as_responder() + + proto.start_handshake() + + proto.read_message(await rdr.readexactly(_handshakelens[0])) + + wrr.write(proto.write_message()) + + proto.read_message(await rdr.readexactly(_handshakelens[2])) + + if not proto.handshake_finished: # pragma: no cover + raise RuntimeError('failed to finish handshake') + + # generate the keys for lengths + _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') + + reader, writer = await connectsockstr(ptsockstr) + + async def decses(): + while True: + msg = await rdr.readexactly(2 + 16) + tlen = declenfun(msg) + rmsg = await rdr.readexactly(tlen - 16) + tmsg = msg[2:] + rmsg + writer.write(proto.decrypt(tmsg)) + await writer.drain() + + async def encses(): + while True: + ptmsg = await reader.read(65535 - 16) # largest message + encmsg = proto.encrypt(ptmsg) + wrr.write(enclenfun(encmsg)) + wrr.write(encmsg) + await wrr.drain() + + r = await asyncio.gather(decses(), encses(), return_exceptions=True) + + print(repr(r)) + + return r class TestListenSocket(unittest.TestCase): - def test_listensocket(self): + def test_listensockstr(self): # XXX write test pass +# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code +def async_test(f): + def wrapper(*args, **kwargs): + coro = asyncio.coroutine(f) + future = coro(*args, **kwargs) + loop = asyncio.get_event_loop() + + # timeout after 2 seconds + loop.run_until_complete(asyncio.wait_for(future, 2)) + return wrapper + +class Tests_misc(unittest.TestCase): + def test_genciphfun(self): + enc, dec = _genciphfun(b'0' * 32, b'foobar') + + msg = b'this is a bunch of data' + + tb = enc(msg) + + self.assertEqual(len(msg), dec(tb + msg)) + + for i in [ 20, 1384, 64000, 23839, 65535 ]: + msg = os.urandom(i) + self.assertEqual(len(msg), dec(enc(msg) + msg)) + class Tests(unittest.TestCase): def setUp(self): # setup temporary directory @@ -82,7 +185,8 @@ class Tests(unittest.TestCase): shutil.rmtree(self.basetempdir) self.tempdir = None - def test_server(self): + @async_test + async def test_server(self): # Path that the server will sit on servsockpath = os.path.join(self.tempdir, 'servsock') servarg = _makeunix(servsockpath) @@ -93,14 +197,17 @@ class Tests(unittest.TestCase): # Setup pt target listener pttarg = _makeunix(servsockpath) ptsock = [] - def ptsockaccept(sock, frm, ptsock=ptsock): - ptsock.append(sock) + def ptsockaccept(reader, writer, ptsock=ptsock): + ptsock.append((reader, writer)) # Bind to pt listener - lsock = listensocket(pttarg, ptsockaccept) + lsock = await listensockstr(pttarg, ptsockaccept) # Setup server listener - ssock = listensocket(servarg, lambda x, y: NoiseForwarder('resp', self.server_key_pair[1], x, pttarg)) + ssock = await listensockstr(servarg, lambda rdr, wrr: NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg)) + + # Connect to server + reader, writer = await connectsockstr(servarg) # Create client proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') @@ -114,3 +221,70 @@ class Tests(unittest.TestCase): # Send first message message = proto.write_message() + self.assertEqual(len(message), _handshakelens[0]) + writer.write(message) + + # Get response + respmsg = await reader.readexactly(_handshakelens[1]) + proto.read_message(respmsg) + + # Send final reply + message = proto.write_message() + writer.write(message) + + # Make sure handshake has completed + self.assertTrue(proto.handshake_finished) + + # generate the keys for lengths + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') + _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') + + # write a test message + ptmsg = b'this is a test message that should be a little in length' + encmsg = proto.encrypt(ptmsg) + writer.write(enclenfun(encmsg)) + writer.write(encmsg) + + # XXX - how to sync? + await asyncio.sleep(.1) + + # read the test message + rptmsg = await ptsock[0][0].readexactly(len(ptmsg)) + + self.assertEqual(rptmsg, ptmsg) + + # write a different message + ptmsg = os.urandom(2843) + encmsg = proto.encrypt(ptmsg) + writer.write(enclenfun(encmsg)) + writer.write(encmsg) + + # XXX - how to sync? + await asyncio.sleep(.1) + + # read the test message + rptmsg = await ptsock[0][0].readexactly(len(ptmsg)) + + self.assertEqual(rptmsg, ptmsg) + + # now try the other way + ptmsg = os.urandom(912) + ptsock[0][1].write(ptmsg) + + # find out how much we need to read + encmsg = await reader.readexactly(2 + 16) + tlen = declenfun(encmsg) + + # read the rest of the message + rencmsg = await reader.readexactly(tlen - 16) + tmsg = encmsg[2:] + rencmsg + rptmsg = proto.decrypt(tmsg) + + self.assertEqual(rptmsg, ptmsg) + + # shut everything down + ptsock[0][1].write_eof() + writer.write_eof() + + # XXX - how to sync? + await asyncio.sleep(1)