From 1a77fc8067abfabc92fb2d2cfe640bb17996c184 Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Tue, 29 Oct 2019 18:25:53 -0700 Subject: [PATCH] add tests for _parsesockstr/connectsockstr/listensockstr, add support for TCP for them... clean up a bunch of the test sockets, need to document how to clean up the various services... Stream line some of the test code so that it is properly all covered. --- ntunnel.py | 227 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 206 insertions(+), 21 deletions(-) diff --git a/ntunnel.py b/ntunnel.py index 583ea19..25ae403 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF from cryptography.hazmat.primitives.serialization import load_pem_private_key from noise.connection import NoiseConnection, Keypair -#import tracemalloc; tracemalloc.start() +#import tracemalloc; tracemalloc.start(100) import argparse import asyncio @@ -81,15 +81,59 @@ def _makeunix(path): return 'unix:%s' % path +_allowedparameters = { + 'unix': { + 'path': str, + }, + 'tcp': { + 'host': str, + 'port': int, + }, +} + def _parsesockstr(sockstr): + '''Parse a socket string to its parts. If there are no + kwargs (no = after the colon), a dictionary w/ a single + key of default will pass the string after the colon. + + default is a reserved keyword and MUST NOT be used.''' + proto, rem = sockstr.split(':', 1) - return proto, rem + if '=' not in rem: + if proto == 'unix' and rem[0] != '/': + raise ValueError('bare path MUST start w/ a slash (/).') + + if proto == 'unix': + args = { 'path': rem } + else: + args = dict(i.split('=', 1) for i in rem.split(',')) + + try: + allowed = _allowedparameters[proto] + except KeyError: + raise ValueError('unsupported proto: %s' % repr(proto)) + + extrakeys = args.keys() - allowed.keys() + if extrakeys: + raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys)) + + for i in args: + args[i] = allowed[i](args[i]) + + return proto, args async def connectsockstr(sockstr): - proto, rem = _parsesockstr(sockstr) + '''Wrapper for asyncio.open_*_connection.''' - reader, writer = await asyncio.open_unix_connection(rem) + proto, args = _parsesockstr(sockstr) + + if proto == 'unix': + fun = asyncio.open_unix_connection + elif proto == 'tcp': + fun = asyncio.open_connection + + reader, writer = await fun(**args) return reader, writer @@ -101,9 +145,15 @@ async def listensockstr(sockstr, cb): directly, like: 'proto:value'. This is only allowed when the value can unambiguously be determined not to be a param. + The cb parameter is passed to asyncio's start_server or related + calls. Per those docs, the cb parameter is calls or scheduled + as a task when a client establishes a connection. It is called + with two arguments, the reader and writer streams. For more + information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server + The characters that define 'param' must be all lower case ascii characters and may contain an underscore. The first character - must not be and underscore. + must not be an underscore. Supported protocols: unix: @@ -113,11 +163,14 @@ async def listensockstr(sockstr, cb): slash if it is used as a default parameter. ''' - proto, rem = _parsesockstr(sockstr) + proto, args = _parsesockstr(sockstr) - server = await asyncio.start_unix_server(cb, path=rem) + if proto == 'unix': + fun = asyncio.start_unix_server + elif proto == 'tcp': + fun = asyncio.start_server - return server + return await fun(cb, **args) # !!python makemessagelengths.py _handshakelens = \ @@ -212,8 +265,12 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): if not proto.handshake_finished: # pragma: no cover raise RuntimeError('failed to finish handshake') - reader, writer = await ptpairfun(getattr(proto.get_keypair( - Keypair.REMOTE_STATIC), 'public_bytes', None)) + try: + reader, writer = await ptpairfun(getattr(proto.get_keypair( + Keypair.REMOTE_STATIC), 'public_bytes', None)) + except: + wrr.close() + raise # generate the keys for lengths # XXX - get_handshake_hash is probably not the best option, but @@ -273,7 +330,15 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): finally: wrr.write_eof() - return await asyncio.gather(decses(), encses()) + res = await asyncio.gather(decses(), encses()) + + await wrr.drain() # not sure if needed + wrr.close() + + await writer.drain() # not sure if needed + writer.close() + + return res # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code # Slightly modified to timeout and to print trace back when canceled. @@ -296,9 +361,107 @@ def async_test(f): return wrapper class Tests_misc(unittest.TestCase): - def test_listensockstr(self): - # XXX write test - pass + def setUp(self): + # setup temporary directory + d = os.path.realpath(tempfile.mkdtemp()) + self.basetempdir = d + self.tempdir = os.path.join(d, 'subdir') + os.mkdir(self.tempdir) + + os.chdir(self.tempdir) + + def tearDown(self): + #print('td:', time.time()) + shutil.rmtree(self.basetempdir) + self.tempdir = None + + def test_parsesockstr_bad(self): + badstrs = [ + 'unix:ff', + 'randomnocolon', + 'unix:somethingelse=bogus', + 'tcp:port=bogus', + ] + + for i in badstrs: + with self.assertRaises(ValueError, + msg='Should have failed processing: %s' % repr(i)): + _parsesockstr(i) + + def test_parsesockstr(self): + results = { + # Not all of these are valid when passed to a *sockstr + # function + 'unix:/apath': ('unix', { 'path': '/apath' }), + 'unix:path=apath': ('unix', { 'path': 'apath' }), + 'tcp:host=apath': ('tcp', { 'host': 'apath' }), + 'tcp:host=apath,port=5': ('tcp', { 'host': 'apath', + 'port': 5 }), + } + + for s, r in results.items(): + self.assertEqual(_parsesockstr(s), r) + + @async_test + async def test_listensockstr_bad(self): + with self.assertRaises(ValueError): + ls = await listensockstr('bogus:some=arg', None) + + with self.assertRaises(ValueError): + ls = await connectsockstr('bogus:some=arg') + + @async_test + async def test_listenconnectsockstr(self): + msgsent = b'this is a test message' + msgrcv = b'testing message for receive' + + # That when a connection is received and receives and sends + async def servconfhandle(rdr, wrr): + msg = await rdr.readexactly(len(msgsent)) + self.assertEqual(msg, msgsent) + + #print(repr(wrr.get_extra_info('sockname'))) + wrr.write(msgrcv) + await wrr.drain() + + wrr.close() + + return True + + # Test listensockstr + for sstr, confun in [ + ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')), + ('tcp:port=9384', lambda: asyncio.open_connection(port=9384)) + ]: + # that listensockstr will bind to the correct path, can call cb + ls = await listensockstr(sstr, servconfhandle) + + # that we open a connection to the path + rdr, wrr = await confun() + + # and send a message + wrr.write(msgsent) + + # and receive the message + rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5) + self.assertEqual(rcv, msgrcv) + + wrr.close() + + # Now test that connectsockstr works similarly. + rdr, wrr = await connectsockstr(sstr) + + # and send a message + wrr.write(msgsent) + + # and receive the message + rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5) + self.assertEqual(rcv, msgrcv) + + wrr.close() + + ls.close() + await ls.wait_closed() def test_genciphfun(self): enc, dec = _genciphfun(b'0' * 32, b'foobar') @@ -531,14 +694,10 @@ class TestMain(unittest.TestCase): servtargstr = _makeunix(servtargpath) # Setup server target listener - ptsock = [] ptsockevent = asyncio.Event() - def ptsockaccept(reader, writer, ptsock=ptsock): - ptsock.append((reader, writer)) - ptsockevent.set() # Bind to pt listener - lsock = await listensockstr(servtargstr, ptsockaccept) + lsock = await listensockstr(servtargstr, None) # Startup the server server = await self.run_with_args('server', @@ -566,6 +725,8 @@ class TestMain(unittest.TestCase): # make sure that we don't get the conenction await asyncio.wait_for(ptsockevent.wait(), .5) + writer.close() + # Make sure that when the server is terminated server.terminate() @@ -577,6 +738,9 @@ class TestMain(unittest.TestCase): # even the example echo server has this same leak #self.assertNotIn(b'Task exception was never retrieved', stderr) + lsock.close() + await lsock.wait_closed() + @async_test async def test_end2end(self): # Generate necessar keys @@ -654,6 +818,12 @@ class TestMain(unittest.TestCase): endwrr.write(msg) self.assertEqual(msg, await reader.readexactly(len(msg))) + writer.close() + endwrr.close() + + lsock.close() + await lsock.wait_closed() + @async_test async def test_genkey(self): # that it can generate a key @@ -757,6 +927,8 @@ class TestNoiseFowarder(unittest.TestCase): with self.assertRaises(ValueError): await servnf + writer.close() + @async_test async def test_server(self): # Test is plumbed: @@ -893,6 +1065,14 @@ class TestNoiseFowarder(unittest.TestCase): self.assertEqual(nfs[0], [ 'dec', 'enc' ]) + writer.close() + ptwriter.close() + + lsock.close() + ssock.close() + await lsock.wait_closed() + await ssock.wait_closed() + @async_test async def test_serverclient(self): # plumbing: @@ -910,8 +1090,7 @@ class TestNoiseFowarder(unittest.TestCase): ptsbreader, ptsbwriter = await ptssockbpair async def validateclientkey(pubkey): - if pubkey != self.client_key_pair[0]: - raise ValueError('invalid key') + self.assertEqual(pubkey, self.client_key_pair[0]) return await ptssockapair @@ -963,3 +1142,9 @@ class TestNoiseFowarder(unittest.TestCase): self.assertEqual([ 'dec', 'enc' ], await clientnf) self.assertEqual([ 'dec', 'enc' ], await servnf) + + await ptsbwriter.drain() + await ptcawriter.drain() + + ptsbwriter.close() + ptcawriter.close()