|
@@ -90,31 +90,47 @@ def _genciphfun(hash, ad): |
|
|
|
|
|
|
|
|
return encfun, decfun |
|
|
return encfun, decfun |
|
|
|
|
|
|
|
|
async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): |
|
|
|
|
|
|
|
|
async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): |
|
|
rdr, wrr = rdrwrr |
|
|
rdr, wrr = rdrwrr |
|
|
|
|
|
|
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
|
|
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) |
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) |
|
|
|
|
|
if pub_key is not None: |
|
|
|
|
|
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key) |
|
|
|
|
|
|
|
|
proto.set_as_responder() |
|
|
|
|
|
|
|
|
if mode == 'resp': |
|
|
|
|
|
proto.set_as_responder() |
|
|
|
|
|
elif mode == 'init': |
|
|
|
|
|
proto.set_as_initiator() |
|
|
|
|
|
|
|
|
proto.start_handshake() |
|
|
proto.start_handshake() |
|
|
|
|
|
|
|
|
proto.read_message(await rdr.readexactly(_handshakelens[0])) |
|
|
|
|
|
|
|
|
if mode == 'resp': |
|
|
|
|
|
proto.read_message(await rdr.readexactly(_handshakelens[0])) |
|
|
|
|
|
|
|
|
|
|
|
wrr.write(proto.write_message()) |
|
|
|
|
|
|
|
|
|
|
|
proto.read_message(await rdr.readexactly(_handshakelens[2])) |
|
|
|
|
|
elif mode == 'init': |
|
|
|
|
|
wrr.write(proto.write_message()) |
|
|
|
|
|
|
|
|
wrr.write(proto.write_message()) |
|
|
|
|
|
|
|
|
proto.read_message(await rdr.readexactly(_handshakelens[1])) |
|
|
|
|
|
|
|
|
proto.read_message(await rdr.readexactly(_handshakelens[2])) |
|
|
|
|
|
|
|
|
wrr.write(proto.write_message()) |
|
|
|
|
|
|
|
|
if not proto.handshake_finished: # pragma: no cover |
|
|
if not proto.handshake_finished: # pragma: no cover |
|
|
raise RuntimeError('failed to finish handshake') |
|
|
raise RuntimeError('failed to finish handshake') |
|
|
|
|
|
|
|
|
# generate the keys for lengths |
|
|
# generate the keys for lengths |
|
|
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') |
|
|
|
|
|
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') |
|
|
|
|
|
|
|
|
if mode == 'resp': |
|
|
|
|
|
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') |
|
|
|
|
|
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') |
|
|
|
|
|
elif mode == 'init': |
|
|
|
|
|
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') |
|
|
|
|
|
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') |
|
|
|
|
|
|
|
|
reader, writer = await connectsockstr(ptsockstr) |
|
|
|
|
|
|
|
|
reader, writer = await ptpair |
|
|
|
|
|
|
|
|
async def decses(): |
|
|
async def decses(): |
|
|
try: |
|
|
try: |
|
@@ -130,6 +146,10 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): |
|
|
tmsg = msg[2:] + rmsg |
|
|
tmsg = msg[2:] + rmsg |
|
|
writer.write(proto.decrypt(tmsg)) |
|
|
writer.write(proto.decrypt(tmsg)) |
|
|
await writer.drain() |
|
|
await writer.drain() |
|
|
|
|
|
#except: |
|
|
|
|
|
# import traceback |
|
|
|
|
|
# traceback.print_exc() |
|
|
|
|
|
# raise |
|
|
finally: |
|
|
finally: |
|
|
writer.write_eof() |
|
|
writer.write_eof() |
|
|
|
|
|
|
|
@@ -145,17 +165,17 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): |
|
|
wrr.write(enclenfun(encmsg)) |
|
|
wrr.write(enclenfun(encmsg)) |
|
|
wrr.write(encmsg) |
|
|
wrr.write(encmsg) |
|
|
await wrr.drain() |
|
|
await wrr.drain() |
|
|
|
|
|
#except: |
|
|
|
|
|
# import traceback |
|
|
|
|
|
# traceback.print_exc() |
|
|
|
|
|
# raise |
|
|
finally: |
|
|
finally: |
|
|
wrr.write_eof() |
|
|
wrr.write_eof() |
|
|
|
|
|
|
|
|
return asyncio.gather(decses(), encses()) |
|
|
|
|
|
|
|
|
|
|
|
class TestListenSocket(unittest.TestCase): |
|
|
|
|
|
def test_listensockstr(self): |
|
|
|
|
|
# XXX write test |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
return await asyncio.gather(decses(), encses()) |
|
|
|
|
|
|
|
|
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code |
|
|
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code |
|
|
|
|
|
# Slightly modified to timeout |
|
|
def async_test(f): |
|
|
def async_test(f): |
|
|
def wrapper(*args, **kwargs): |
|
|
def wrapper(*args, **kwargs): |
|
|
coro = asyncio.coroutine(f) |
|
|
coro = asyncio.coroutine(f) |
|
@@ -167,6 +187,10 @@ def async_test(f): |
|
|
return wrapper |
|
|
return wrapper |
|
|
|
|
|
|
|
|
class Tests_misc(unittest.TestCase): |
|
|
class Tests_misc(unittest.TestCase): |
|
|
|
|
|
def test_listensockstr(self): |
|
|
|
|
|
# XXX write test |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
def test_genciphfun(self): |
|
|
def test_genciphfun(self): |
|
|
enc, dec = _genciphfun(b'0' * 32, b'foobar') |
|
|
enc, dec = _genciphfun(b'0' * 32, b'foobar') |
|
|
|
|
|
|
|
@@ -180,6 +204,15 @@ class Tests_misc(unittest.TestCase): |
|
|
msg = os.urandom(i) |
|
|
msg = os.urandom(i) |
|
|
self.assertEqual(len(msg), dec(enc(msg) + msg)) |
|
|
self.assertEqual(len(msg), dec(enc(msg) + msg)) |
|
|
|
|
|
|
|
|
|
|
|
def _asyncsockpair(): |
|
|
|
|
|
'''Create a pair of sockets that are bound to each other. |
|
|
|
|
|
The function will return a tuple of two coroutine's, that |
|
|
|
|
|
each, when await'ed upon, will return the reader/writer pair.''' |
|
|
|
|
|
|
|
|
|
|
|
socka, sockb = socket.socketpair() |
|
|
|
|
|
|
|
|
|
|
|
return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb) |
|
|
|
|
|
|
|
|
class Tests(unittest.TestCase): |
|
|
class Tests(unittest.TestCase): |
|
|
def setUp(self): |
|
|
def setUp(self): |
|
|
# setup temporary directory |
|
|
# setup temporary directory |
|
@@ -222,7 +255,8 @@ class Tests(unittest.TestCase): |
|
|
event = asyncio.Event() |
|
|
event = asyncio.Event() |
|
|
|
|
|
|
|
|
async def runnf(rdr, wrr): |
|
|
async def runnf(rdr, wrr): |
|
|
a = await NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg) |
|
|
|
|
|
|
|
|
ptpair = asyncio.create_task(connectsockstr(pttarg)) |
|
|
|
|
|
a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1]) |
|
|
|
|
|
|
|
|
nfs.append(a) |
|
|
nfs.append(a) |
|
|
event.set() |
|
|
event.set() |
|
@@ -323,4 +357,59 @@ class Tests(unittest.TestCase): |
|
|
|
|
|
|
|
|
await event.wait() |
|
|
await event.wait() |
|
|
|
|
|
|
|
|
self.assertEqual(await nfs[0], [ 'dec', 'enc' ]) |
|
|
|
|
|
|
|
|
self.assertEqual(nfs[0], [ 'dec', 'enc' ]) |
|
|
|
|
|
|
|
|
|
|
|
@async_test |
|
|
|
|
|
async def test_serverclient(self): |
|
|
|
|
|
# plumbing: |
|
|
|
|
|
# |
|
|
|
|
|
# ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb |
|
|
|
|
|
# |
|
|
|
|
|
|
|
|
|
|
|
ptcsockapair, ptcsockbpair = _asyncsockpair() |
|
|
|
|
|
ptcareader, ptcawriter = await ptcsockapair |
|
|
|
|
|
#ptcsockbpair passed directly |
|
|
|
|
|
clssockapair, clssockbpair = _asyncsockpair() |
|
|
|
|
|
clsapair = await clssockapair |
|
|
|
|
|
clsbpair = await clssockbpair |
|
|
|
|
|
ptssockapair, ptssockbpair = _asyncsockpair() |
|
|
|
|
|
#ptssockapair passed directly |
|
|
|
|
|
ptsbreader, ptsbwriter = await ptssockbpair |
|
|
|
|
|
|
|
|
|
|
|
clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) |
|
|
|
|
|
servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1])) |
|
|
|
|
|
|
|
|
|
|
|
# send a message |
|
|
|
|
|
msga = os.urandom(183) |
|
|
|
|
|
ptcawriter.write(msga) |
|
|
|
|
|
|
|
|
|
|
|
# make sure we get the same message |
|
|
|
|
|
self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) |
|
|
|
|
|
|
|
|
|
|
|
# send a second message |
|
|
|
|
|
msga = os.urandom(2834) |
|
|
|
|
|
ptcawriter.write(msga) |
|
|
|
|
|
|
|
|
|
|
|
# make sure we get the same message |
|
|
|
|
|
self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) |
|
|
|
|
|
|
|
|
|
|
|
# send a message larger than the block size |
|
|
|
|
|
msga = os.urandom(103958) |
|
|
|
|
|
ptcawriter.write(msga) |
|
|
|
|
|
|
|
|
|
|
|
# make sure we get the same message |
|
|
|
|
|
self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) |
|
|
|
|
|
|
|
|
|
|
|
# send a message the other direction |
|
|
|
|
|
msga = os.urandom(103958) |
|
|
|
|
|
ptsbwriter.write(msga) |
|
|
|
|
|
|
|
|
|
|
|
# make sure we get the same message |
|
|
|
|
|
self.assertEqual(msga, await ptcareader.readexactly(len(msga))) |
|
|
|
|
|
|
|
|
|
|
|
# close down the pt writers, the rest should follow |
|
|
|
|
|
ptsbwriter.write_eof() |
|
|
|
|
|
ptcawriter.write_eof() |
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual([ 'dec', 'enc' ], await clientnf) |
|
|
|
|
|
self.assertEqual([ 'dec', 'enc' ], await servnf) |