diff --git a/ntunnel.py b/ntunnel.py index e5aaf3d..e012e4d 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -15,6 +15,13 @@ import unittest _backend = default_backend() +def _makefut(obj): + loop = asyncio.get_running_loop() + fut = loop.create_future() + fut.set_result(obj) + + return fut + def _makeunix(path): '''Make a properly formed unix path socket string.''' @@ -91,28 +98,30 @@ def _genciphfun(hash, ad): return encfun, decfun async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): - rdr, wrr = rdrwrr + rdr, wrr = await rdrwrr proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) if pub_key is not None: - proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key) + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + pub_key) if mode == 'resp': proto.set_as_responder() - elif mode == 'init': - proto.set_as_initiator() - proto.start_handshake() + proto.start_handshake() - if mode == 'resp': proto.read_message(await rdr.readexactly(_handshakelens[0])) wrr.write(proto.write_message()) proto.read_message(await rdr.readexactly(_handshakelens[2])) elif mode == 'init': + proto.set_as_initiator() + + proto.start_handshake() + wrr.write(proto.write_message()) proto.read_message(await rdr.readexactly(_handshakelens[1])) @@ -156,7 +165,8 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): async def encses(): try: while True: - ptmsg = await reader.read(65535 - 16) # largest message + # largest message + ptmsg = await reader.read(65535 - 16) if not ptmsg: # eof return 'enc' @@ -211,7 +221,8 @@ def _asyncsockpair(): socka, sockb = socket.socketpair() - return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb) + return asyncio.open_connection(sock=socka), \ + asyncio.open_connection(sock=sockb) class Tests(unittest.TestCase): def setUp(self): @@ -256,7 +267,10 @@ class Tests(unittest.TestCase): async def runnf(rdr, wrr): ptpair = asyncio.create_task(connectsockstr(pttarg)) - a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1]) + + a = await NoiseForwarder('resp', + _makefut((rdr, wrr)), ptpair, + priv_key=self.server_key_pair[1]) nfs.append(a) event.set() @@ -272,8 +286,10 @@ class Tests(unittest.TestCase): proto.set_as_initiator() # Setup required keys - proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) - proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) + proto.set_keypair_from_private_bytes(Keypair.STATIC, + self.client_key_pair[1]) + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + self.server_key_pair[0]) proto.start_handshake() @@ -294,8 +310,10 @@ class Tests(unittest.TestCase): self.assertTrue(proto.handshake_finished) # generate the keys for lengths - enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') - _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), + b'toresp') + _, declenfun = _genciphfun(proto.get_handshake_hash(), + b'toinit') # write a test message ptmsg = b'this is a test message that should be a little in length' @@ -345,14 +363,14 @@ class Tests(unittest.TestCase): writer.write_eof() # so pt reader should be shut down - r = await ptreader.read(1) + self.assertEqual(b'', await ptreader.read(1)) self.assertTrue(ptreader.at_eof()) # shut down pt ptwriter.write_eof() # make sure the enc reader is eof - r = await reader.read(1) + self.assertEqual(b'', await reader.read(1)) self.assertTrue(reader.at_eof()) await event.wait() @@ -370,14 +388,18 @@ class Tests(unittest.TestCase): ptcareader, ptcawriter = await ptcsockapair #ptcsockbpair passed directly clssockapair, clssockbpair = _asyncsockpair() - clsapair = await clssockapair - clsbpair = await clssockbpair + #both passed directly ptssockapair, ptssockbpair = _asyncsockpair() #ptssockapair passed directly ptsbreader, ptsbwriter = await ptssockbpair - clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) - servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1])) + clientnf = asyncio.create_task(NoiseForwarder('init', + clssockapair, ptcsockbpair, + priv_key=self.client_key_pair[1], + pub_key=self.server_key_pair[0])) + servnf = asyncio.create_task(NoiseForwarder('resp', + clssockbpair, ptssockapair, + priv_key=self.server_key_pair[1])) # send a message msga = os.urandom(183) @@ -411,5 +433,11 @@ class Tests(unittest.TestCase): ptsbwriter.write_eof() ptcawriter.write_eof() + # make sure they are closed, and there is no more data + self.assertEqual(b'', await ptsbreader.read(1)) + self.assertTrue(ptsbreader.at_eof()) + self.assertEqual(b'', await ptcareader.read(1)) + self.assertTrue(ptcareader.at_eof()) + self.assertEqual([ 'dec', 'enc' ], await clientnf) self.assertEqual([ 'dec', 'enc' ], await servnf)