From be2e5fd6dab051a9a11c2650c29685f471b145ee Mon Sep 17 00:00:00 2001 From: John-Mark Gurney Date: Fri, 25 Oct 2019 09:50:11 -0700 Subject: [PATCH] minor reorg work.. make both pairs come from an awaitable object, this will be needed for future work... keep lines under 80 chars... Minor reorg of the main loop to make it clearer... --- ntunnel.py | 66 ++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/ntunnel.py b/ntunnel.py index e5aaf3d..e012e4d 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -15,6 +15,13 @@ import unittest _backend = default_backend() +def _makefut(obj): + loop = asyncio.get_running_loop() + fut = loop.create_future() + fut.set_result(obj) + + return fut + def _makeunix(path): '''Make a properly formed unix path socket string.''' @@ -91,28 +98,30 @@ def _genciphfun(hash, ad): return encfun, decfun async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): - rdr, wrr = rdrwrr + rdr, wrr = await rdrwrr proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) if pub_key is not None: - proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key) + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + pub_key) if mode == 'resp': proto.set_as_responder() - elif mode == 'init': - proto.set_as_initiator() - proto.start_handshake() + proto.start_handshake() - if mode == 'resp': proto.read_message(await rdr.readexactly(_handshakelens[0])) wrr.write(proto.write_message()) proto.read_message(await rdr.readexactly(_handshakelens[2])) elif mode == 'init': + proto.set_as_initiator() + + proto.start_handshake() + wrr.write(proto.write_message()) proto.read_message(await rdr.readexactly(_handshakelens[1])) @@ -156,7 +165,8 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): async def encses(): try: while True: - ptmsg = await reader.read(65535 - 16) # largest message + # largest message + ptmsg = await reader.read(65535 - 16) if not ptmsg: # eof return 'enc' @@ -211,7 +221,8 @@ def _asyncsockpair(): socka, sockb = socket.socketpair() - return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb) + return asyncio.open_connection(sock=socka), \ + asyncio.open_connection(sock=sockb) class Tests(unittest.TestCase): def setUp(self): @@ -256,7 +267,10 @@ class Tests(unittest.TestCase): async def runnf(rdr, wrr): ptpair = asyncio.create_task(connectsockstr(pttarg)) - a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1]) + + a = await NoiseForwarder('resp', + _makefut((rdr, wrr)), ptpair, + priv_key=self.server_key_pair[1]) nfs.append(a) event.set() @@ -272,8 +286,10 @@ class Tests(unittest.TestCase): proto.set_as_initiator() # Setup required keys - proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1]) - proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0]) + proto.set_keypair_from_private_bytes(Keypair.STATIC, + self.client_key_pair[1]) + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + self.server_key_pair[0]) proto.start_handshake() @@ -294,8 +310,10 @@ class Tests(unittest.TestCase): self.assertTrue(proto.handshake_finished) # generate the keys for lengths - enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') - _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), + b'toresp') + _, declenfun = _genciphfun(proto.get_handshake_hash(), + b'toinit') # write a test message ptmsg = b'this is a test message that should be a little in length' @@ -345,14 +363,14 @@ class Tests(unittest.TestCase): writer.write_eof() # so pt reader should be shut down - r = await ptreader.read(1) + self.assertEqual(b'', await ptreader.read(1)) self.assertTrue(ptreader.at_eof()) # shut down pt ptwriter.write_eof() # make sure the enc reader is eof - r = await reader.read(1) + self.assertEqual(b'', await reader.read(1)) self.assertTrue(reader.at_eof()) await event.wait() @@ -370,14 +388,18 @@ class Tests(unittest.TestCase): ptcareader, ptcawriter = await ptcsockapair #ptcsockbpair passed directly clssockapair, clssockbpair = _asyncsockpair() - clsapair = await clssockapair - clsbpair = await clssockbpair + #both passed directly ptssockapair, ptssockbpair = _asyncsockpair() #ptssockapair passed directly ptsbreader, ptsbwriter = await ptssockbpair - clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) - servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1])) + clientnf = asyncio.create_task(NoiseForwarder('init', + clssockapair, ptcsockbpair, + priv_key=self.client_key_pair[1], + pub_key=self.server_key_pair[0])) + servnf = asyncio.create_task(NoiseForwarder('resp', + clssockbpair, ptssockapair, + priv_key=self.server_key_pair[1])) # send a message msga = os.urandom(183) @@ -411,5 +433,11 @@ class Tests(unittest.TestCase): ptsbwriter.write_eof() ptcawriter.write_eof() + # make sure they are closed, and there is no more data + self.assertEqual(b'', await ptsbreader.read(1)) + self.assertTrue(ptsbreader.at_eof()) + self.assertEqual(b'', await ptcareader.read(1)) + self.assertTrue(ptcareader.at_eof()) + self.assertEqual([ 'dec', 'enc' ], await clientnf) self.assertEqual([ 'dec', 'enc' ], await servnf)