@@ -7,7 +7,7 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair
#import tracemalloc; tracemalloc.start()
#import tracemalloc; tracemalloc.start(100 )
import argparse
import asyncio
@@ -81,15 +81,59 @@ def _makeunix(path):
return 'unix:%s' % path
_allowedparameters = {
'unix': {
'path': str,
},
'tcp': {
'host': str,
'port': int,
},
}
def _parsesockstr(sockstr):
'''Parse a socket string to its parts. If there are no
kwargs (no = after the colon), a dictionary w/ a single
key of default will pass the string after the colon.
default is a reserved keyword and MUST NOT be used.'''
proto, rem = sockstr.split(':', 1)
return proto, rem
if '=' not in rem:
if proto == 'unix' and rem[0] != '/':
raise ValueError('bare path MUST start w/ a slash (/).')
if proto == 'unix':
args = { 'path': rem }
else:
args = dict(i.split('=', 1) for i in rem.split(','))
try:
allowed = _allowedparameters[proto]
except KeyError:
raise ValueError('unsupported proto: %s' % repr(proto))
extrakeys = args.keys() - allowed.keys()
if extrakeys:
raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))
for i in args:
args[i] = allowed[i](args[i])
return proto, args
async def connectsockstr(sockstr):
proto, rem = _parsesockstr(sockstr)
'''Wrapper for asyncio.open_*_connection.'''
reader, writer = await asyncio.open_unix_connection(rem)
proto, args = _parsesockstr(sockstr)
if proto == 'unix':
fun = asyncio.open_unix_connection
elif proto == 'tcp':
fun = asyncio.open_connection
reader, writer = await fun(**args)
return reader, writer
@@ -101,9 +145,15 @@ async def listensockstr(sockstr, cb):
directly, like: 'proto:value'. This is only allowed when the
value can unambiguously be determined not to be a param.
The cb parameter is passed to asyncio's start_server or related
calls. Per those docs, the cb parameter is calls or scheduled
as a task when a client establishes a connection. It is called
with two arguments, the reader and writer streams. For more
information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
The characters that define 'param' must be all lower case ascii
characters and may contain an underscore. The first character
must not be and underscore.
must not be an underscore.
Supported protocols:
unix:
@@ -113,11 +163,14 @@ async def listensockstr(sockstr, cb):
slash if it is used as a default parameter.
'''
proto, rem = _parsesockstr(sockstr)
proto, args = _parsesockstr(sockstr)
server = await asyncio.start_unix_server(cb, path=rem)
if proto == 'unix':
fun = asyncio.start_unix_server
elif proto == 'tcp':
fun = asyncio.start_server
return server
return await fun(cb, **args)
# !!python makemessagelengths.py
_handshakelens = \
@@ -212,8 +265,12 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake')
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
try:
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
except:
wrr.close()
raise
# generate the keys for lengths
# XXX - get_handshake_hash is probably not the best option, but
@@ -273,7 +330,15 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
finally:
wrr.write_eof()
return await asyncio.gather(decses(), encses())
res = await asyncio.gather(decses(), encses())
await wrr.drain() # not sure if needed
wrr.close()
await writer.drain() # not sure if needed
writer.close()
return res
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
@@ -296,9 +361,107 @@ def async_test(f):
return wrapper
class Tests_misc(unittest.TestCase):
def test_listensockstr(self):
# XXX write test
pass
def setUp(self):
# setup temporary directory
d = os.path.realpath(tempfile.mkdtemp())
self.basetempdir = d
self.tempdir = os.path.join(d, 'subdir')
os.mkdir(self.tempdir)
os.chdir(self.tempdir)
def tearDown(self):
#print('td:', time.time())
shutil.rmtree(self.basetempdir)
self.tempdir = None
def test_parsesockstr_bad(self):
badstrs = [
'unix:ff',
'randomnocolon',
'unix:somethingelse=bogus',
'tcp:port=bogus',
]
for i in badstrs:
with self.assertRaises(ValueError,
msg='Should have failed processing: %s' % repr(i)):
_parsesockstr(i)
def test_parsesockstr(self):
results = {
# Not all of these are valid when passed to a *sockstr
# function
'unix:/apath': ('unix', { 'path': '/apath' }),
'unix:path=apath': ('unix', { 'path': 'apath' }),
'tcp:host=apath': ('tcp', { 'host': 'apath' }),
'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
'port': 5 }),
}
for s, r in results.items():
self.assertEqual(_parsesockstr(s), r)
@async_test
async def test_listensockstr_bad(self):
with self.assertRaises(ValueError):
ls = await listensockstr('bogus:some=arg', None)
with self.assertRaises(ValueError):
ls = await connectsockstr('bogus:some=arg')
@async_test
async def test_listenconnectsockstr(self):
msgsent = b'this is a test message'
msgrcv = b'testing message for receive'
# That when a connection is received and receives and sends
async def servconfhandle(rdr, wrr):
msg = await rdr.readexactly(len(msgsent))
self.assertEqual(msg, msgsent)
#print(repr(wrr.get_extra_info('sockname')))
wrr.write(msgrcv)
await wrr.drain()
wrr.close()
return True
# Test listensockstr
for sstr, confun in [
('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
]:
# that listensockstr will bind to the correct path, can call cb
ls = await listensockstr(sstr, servconfhandle)
# that we open a connection to the path
rdr, wrr = await confun()
# and send a message
wrr.write(msgsent)
# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)
wrr.close()
# Now test that connectsockstr works similarly.
rdr, wrr = await connectsockstr(sstr)
# and send a message
wrr.write(msgsent)
# and receive the message
rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
self.assertEqual(rcv, msgrcv)
wrr.close()
ls.close()
await ls.wait_closed()
def test_genciphfun(self):
enc, dec = _genciphfun(b'0' * 32, b'foobar')
@@ -531,14 +694,10 @@ class TestMain(unittest.TestCase):
servtargstr = _makeunix(servtargpath)
# Setup server target listener
ptsock = []
ptsockevent = asyncio.Event()
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))
ptsockevent.set()
# Bind to pt listener
lsock = await listensockstr(servtargstr, ptsockaccept )
lsock = await listensockstr(servtargstr, None)
# Startup the server
server = await self.run_with_args('server',
@@ -566,6 +725,8 @@ class TestMain(unittest.TestCase):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)
writer.close()
# Make sure that when the server is terminated
server.terminate()
@@ -577,6 +738,9 @@ class TestMain(unittest.TestCase):
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)
lsock.close()
await lsock.wait_closed()
@async_test
async def test_end2end(self):
# Generate necessar keys
@@ -654,6 +818,12 @@ class TestMain(unittest.TestCase):
endwrr.write(msg)
self.assertEqual(msg, await reader.readexactly(len(msg)))
writer.close()
endwrr.close()
lsock.close()
await lsock.wait_closed()
@async_test
async def test_genkey(self):
# that it can generate a key
@@ -757,6 +927,8 @@ class TestNoiseFowarder(unittest.TestCase):
with self.assertRaises(ValueError):
await servnf
writer.close()
@async_test
async def test_server(self):
# Test is plumbed:
@@ -893,6 +1065,14 @@ class TestNoiseFowarder(unittest.TestCase):
self.assertEqual(nfs[0], [ 'dec', 'enc' ])
writer.close()
ptwriter.close()
lsock.close()
ssock.close()
await lsock.wait_closed()
await ssock.wait_closed()
@async_test
async def test_serverclient(self):
# plumbing:
@@ -910,8 +1090,7 @@ class TestNoiseFowarder(unittest.TestCase):
ptsbreader, ptsbwriter = await ptssockbpair
async def validateclientkey(pubkey):
if pubkey != self.client_key_pair[0]:
raise ValueError('invalid key')
self.assertEqual(pubkey, self.client_key_pair[0])
return await ptssockapair
@@ -963,3 +1142,9 @@ class TestNoiseFowarder(unittest.TestCase):
self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)
await ptsbwriter.drain()
await ptcawriter.drain()
ptsbwriter.close()
ptcawriter.close()