Browse Source

add tests for _parsesockstr/connectsockstr/listensockstr, add support

for TCP for them...

clean up a bunch of the test sockets, need to document how to clean up
the various services...

Stream line some of the test code so that it is properly all covered.
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
1a77fc8067
1 changed files with 206 additions and 21 deletions
  1. +206
    -21
      ntunnel.py

+ 206
- 21
ntunnel.py View File

@@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair

#import tracemalloc; tracemalloc.start()
#import tracemalloc; tracemalloc.start(100)

import argparse
import asyncio
@@ -81,15 +81,59 @@ def _makeunix(path):

return 'unix:%s' % path

_allowedparameters = {
'unix': {
'path': str,
},
'tcp': {
'host': str,
'port': int,
},
}

def _parsesockstr(sockstr):
'''Parse a socket string to its parts. If there are no
kwargs (no = after the colon), a dictionary w/ a single
key of default will pass the string after the colon.

default is a reserved keyword and MUST NOT be used.'''

proto, rem = sockstr.split(':', 1)

return proto, rem
if '=' not in rem:
if proto == 'unix' and rem[0] != '/':
raise ValueError('bare path MUST start w/ a slash (/).')

if proto == 'unix':
args = { 'path': rem }
else:
args = dict(i.split('=', 1) for i in rem.split(','))

try:
allowed = _allowedparameters[proto]
except KeyError:
raise ValueError('unsupported proto: %s' % repr(proto))

extrakeys = args.keys() - allowed.keys()
if extrakeys:
raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))

for i in args:
args[i] = allowed[i](args[i])

return proto, args

async def connectsockstr(sockstr):
proto, rem = _parsesockstr(sockstr)
'''Wrapper for asyncio.open_*_connection.'''

reader, writer = await asyncio.open_unix_connection(rem)
proto, args = _parsesockstr(sockstr)

if proto == 'unix':
fun = asyncio.open_unix_connection
elif proto == 'tcp':
fun = asyncio.open_connection

reader, writer = await fun(**args)

return reader, writer

@@ -101,9 +145,15 @@ async def listensockstr(sockstr, cb):
directly, like: 'proto:value'. This is only allowed when the
value can unambiguously be determined not to be a param.

The cb parameter is passed to asyncio's start_server or related
calls. Per those docs, the cb parameter is calls or scheduled
as a task when a client establishes a connection. It is called
with two arguments, the reader and writer streams. For more
information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server

The characters that define 'param' must be all lower case ascii
characters and may contain an underscore. The first character
must not be and underscore.
must not be an underscore.

Supported protocols:
unix:
@@ -113,11 +163,14 @@ async def listensockstr(sockstr, cb):
slash if it is used as a default parameter.
'''

proto, rem = _parsesockstr(sockstr)
proto, args = _parsesockstr(sockstr)

server = await asyncio.start_unix_server(cb, path=rem)
if proto == 'unix':
fun = asyncio.start_unix_server
elif proto == 'tcp':
fun = asyncio.start_server

return server
return await fun(cb, **args)

# !!python makemessagelengths.py
_handshakelens = \
@@ -212,8 +265,12 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake')

reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
try:
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
except:
wrr.close()
raise

# generate the keys for lengths
# XXX - get_handshake_hash is probably not the best option, but
@@ -273,7 +330,15 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
finally:
wrr.write_eof()

return await asyncio.gather(decses(), encses())
res = await asyncio.gather(decses(), encses())

await wrr.drain() # not sure if needed
wrr.close()

await writer.drain() # not sure if needed
writer.close()

return res

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
@@ -296,9 +361,107 @@ def async_test(f):
return wrapper

class Tests_misc(unittest.TestCase):
def test_listensockstr(self):
# XXX write test
pass
def setUp(self):
# setup temporary directory
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d
self.tempdir = os.path.join(d, 'subdir')
os.mkdir(self.tempdir)

os.chdir(self.tempdir)

def tearDown(self):
#print('td:', time.time())
shutil.rmtree(self.basetempdir)
self.tempdir = None

def test_parsesockstr_bad(self):
badstrs = [
'unix:ff',
'randomnocolon',
'unix:somethingelse=bogus',
'tcp:port=bogus',
]

for i in badstrs:
with self.assertRaises(ValueError,
msg='Should have failed processing: %s' % repr(i)):
_parsesockstr(i)

def test_parsesockstr(self):
results = {
# Not all of these are valid when passed to a *sockstr
# function
'unix:/apath': ('unix', { 'path': '/apath' }),
'unix:path=apath': ('unix', { 'path': 'apath' }),
'tcp:host=apath': ('tcp', { 'host': 'apath' }),
'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
'port': 5 }),
}

for s, r in results.items():
self.assertEqual(_parsesockstr(s), r)

@async_test
async def test_listensockstr_bad(self):
with self.assertRaises(ValueError):
ls = await listensockstr('bogus:some=arg', None)

with self.assertRaises(ValueError):
ls = await connectsockstr('bogus:some=arg')

@async_test
async def test_listenconnectsockstr(self):
msgsent = b'this is a test message'
msgrcv = b'testing message for receive'

# That when a connection is received and receives and sends
async def servconfhandle(rdr, wrr):
msg = await rdr.readexactly(len(msgsent))
self.assertEqual(msg, msgsent)

#print(repr(wrr.get_extra_info('sockname')))
wrr.write(msgrcv)
await wrr.drain()

wrr.close()

return True

# Test listensockstr
for sstr, confun in [
('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
]:
# that listensockstr will bind to the correct path, can call cb
ls = await listensockstr(sstr, servconfhandle)

# that we open a connection to the path
rdr, wrr = await confun()

# and send a message
wrr.write(msgsent)

# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)

wrr.close()

# Now test that connectsockstr works similarly.
rdr, wrr = await connectsockstr(sstr)

# and send a message
wrr.write(msgsent)

# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)

wrr.close()

ls.close()
await ls.wait_closed()

def test_genciphfun(self):
enc, dec = _genciphfun(b'0' * 32, b'foobar')
@@ -531,14 +694,10 @@ class TestMain(unittest.TestCase):
servtargstr = _makeunix(servtargpath)

# Setup server target listener
ptsock = []
ptsockevent = asyncio.Event()
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))
ptsockevent.set()

# Bind to pt listener
lsock = await listensockstr(servtargstr, ptsockaccept)
lsock = await listensockstr(servtargstr, None)

# Startup the server
server = await self.run_with_args('server',
@@ -566,6 +725,8 @@ class TestMain(unittest.TestCase):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)

writer.close()

# Make sure that when the server is terminated
server.terminate()

@@ -577,6 +738,9 @@ class TestMain(unittest.TestCase):
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)

lsock.close()
await lsock.wait_closed()

@async_test
async def test_end2end(self):
# Generate necessar keys
@@ -654,6 +818,12 @@ class TestMain(unittest.TestCase):
endwrr.write(msg)
self.assertEqual(msg, await reader.readexactly(len(msg)))

writer.close()
endwrr.close()

lsock.close()
await lsock.wait_closed()

@async_test
async def test_genkey(self):
# that it can generate a key
@@ -757,6 +927,8 @@ class TestNoiseFowarder(unittest.TestCase):
with self.assertRaises(ValueError):
await servnf

writer.close()

@async_test
async def test_server(self):
# Test is plumbed:
@@ -893,6 +1065,14 @@ class TestNoiseFowarder(unittest.TestCase):

self.assertEqual(nfs[0], [ 'dec', 'enc' ])

writer.close()
ptwriter.close()

lsock.close()
ssock.close()
await lsock.wait_closed()
await ssock.wait_closed()

@async_test
async def test_serverclient(self):
# plumbing:
@@ -910,8 +1090,7 @@ class TestNoiseFowarder(unittest.TestCase):
ptsbreader, ptsbwriter = await ptssockbpair

async def validateclientkey(pubkey):
if pubkey != self.client_key_pair[0]:
raise ValueError('invalid key')
self.assertEqual(pubkey, self.client_key_pair[0])

return await ptssockapair

@@ -963,3 +1142,9 @@ class TestNoiseFowarder(unittest.TestCase):

self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)

await ptsbwriter.drain()
await ptcawriter.drain()

ptsbwriter.close()
ptcawriter.close()

Loading…
Cancel
Save