diff --git a/ntunnel.py b/ntunnel.py index 178039a..e5aaf3d 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -90,31 +90,47 @@ def _genciphfun(hash, ad): return encfun, decfun -async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): +async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): rdr, wrr = rdrwrr proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) + if pub_key is not None: + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key) - proto.set_as_responder() + if mode == 'resp': + proto.set_as_responder() + elif mode == 'init': + proto.set_as_initiator() proto.start_handshake() - proto.read_message(await rdr.readexactly(_handshakelens[0])) + if mode == 'resp': + proto.read_message(await rdr.readexactly(_handshakelens[0])) + + wrr.write(proto.write_message()) + + proto.read_message(await rdr.readexactly(_handshakelens[2])) + elif mode == 'init': + wrr.write(proto.write_message()) - wrr.write(proto.write_message()) + proto.read_message(await rdr.readexactly(_handshakelens[1])) - proto.read_message(await rdr.readexactly(_handshakelens[2])) + wrr.write(proto.write_message()) if not proto.handshake_finished: # pragma: no cover raise RuntimeError('failed to finish handshake') # generate the keys for lengths - _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') - enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') + if mode == 'resp': + _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') + elif mode == 'init': + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') + _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') - reader, writer = await connectsockstr(ptsockstr) + reader, writer = await ptpair async def decses(): try: @@ -130,6 +146,10 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): tmsg = msg[2:] + rmsg writer.write(proto.decrypt(tmsg)) await writer.drain() + #except: + # import traceback + # traceback.print_exc() + # raise finally: writer.write_eof() @@ -145,17 +165,17 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): wrr.write(enclenfun(encmsg)) wrr.write(encmsg) await wrr.drain() + #except: + # import traceback + # traceback.print_exc() + # raise finally: wrr.write_eof() - return asyncio.gather(decses(), encses()) - -class TestListenSocket(unittest.TestCase): - def test_listensockstr(self): - # XXX write test - pass + return await asyncio.gather(decses(), encses()) # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code +# Slightly modified to timeout def async_test(f): def wrapper(*args, **kwargs): coro = asyncio.coroutine(f) @@ -167,6 +187,10 @@ def async_test(f): return wrapper class Tests_misc(unittest.TestCase): + def test_listensockstr(self): + # XXX write test + pass + def test_genciphfun(self): enc, dec = _genciphfun(b'0' * 32, b'foobar') @@ -180,6 +204,15 @@ class Tests_misc(unittest.TestCase): msg = os.urandom(i) self.assertEqual(len(msg), dec(enc(msg) + msg)) +def _asyncsockpair(): + '''Create a pair of sockets that are bound to each other. + The function will return a tuple of two coroutine's, that + each, when await'ed upon, will return the reader/writer pair.''' + + socka, sockb = socket.socketpair() + + return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb) + class Tests(unittest.TestCase): def setUp(self): # setup temporary directory @@ -222,7 +255,8 @@ class Tests(unittest.TestCase): event = asyncio.Event() async def runnf(rdr, wrr): - a = await NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg) + ptpair = asyncio.create_task(connectsockstr(pttarg)) + a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1]) nfs.append(a) event.set() @@ -323,4 +357,59 @@ class Tests(unittest.TestCase): await event.wait() - self.assertEqual(await nfs[0], [ 'dec', 'enc' ]) + self.assertEqual(nfs[0], [ 'dec', 'enc' ]) + + @async_test + async def test_serverclient(self): + # plumbing: + # + # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb + # + + ptcsockapair, ptcsockbpair = _asyncsockpair() + ptcareader, ptcawriter = await ptcsockapair + #ptcsockbpair passed directly + clssockapair, clssockbpair = _asyncsockpair() + clsapair = await clssockapair + clsbpair = await clssockbpair + ptssockapair, ptssockbpair = _asyncsockpair() + #ptssockapair passed directly + ptsbreader, ptsbwriter = await ptssockbpair + + clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) + servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1])) + + # send a message + msga = os.urandom(183) + ptcawriter.write(msga) + + # make sure we get the same message + self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) + + # send a second message + msga = os.urandom(2834) + ptcawriter.write(msga) + + # make sure we get the same message + self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) + + # send a message larger than the block size + msga = os.urandom(103958) + ptcawriter.write(msga) + + # make sure we get the same message + self.assertEqual(msga, await ptsbreader.readexactly(len(msga))) + + # send a message the other direction + msga = os.urandom(103958) + ptsbwriter.write(msga) + + # make sure we get the same message + self.assertEqual(msga, await ptcareader.readexactly(len(msga))) + + # close down the pt writers, the rest should follow + ptsbwriter.write_eof() + ptcawriter.write_eof() + + self.assertEqual([ 'dec', 'enc' ], await clientnf) + self.assertEqual([ 'dec', 'enc' ], await servnf)