diff --git a/ntunnel.py b/ntunnel.py index 6a53135..c9de14c 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -117,13 +117,23 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): reader, writer = await connectsockstr(ptsockstr) async def decses(): - while True: - msg = await rdr.readexactly(2 + 16) - tlen = declenfun(msg) - rmsg = await rdr.readexactly(tlen - 16) - tmsg = msg[2:] + rmsg - writer.write(proto.decrypt(tmsg)) - await writer.drain() + try: + while True: + try: + msg = await rdr.readexactly(2 + 16) + except asyncio.streams.IncompleteReadError: + if rdr.at_eof(): + return 'dec' + + tlen = declenfun(msg) + rmsg = await rdr.readexactly(tlen - 16) + tmsg = msg[2:] + rmsg + writer.write(proto.decrypt(tmsg)) + await writer.drain() + finally: + print('foo') + # XXX - how to test + #writer.write_eof() async def encses(): while True: @@ -133,11 +143,15 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr): wrr.write(encmsg) await wrr.drain() - r = await asyncio.gather(decses(), encses(), return_exceptions=True) + done, pending = await asyncio.wait((decses(), encses()), return_when=asyncio.FIRST_COMPLETED) + for i in done: + print('v:', repr(await i)) - print(repr(r)) + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for i in done: + print('v:', repr(await i)) - return r + return done class TestListenSocket(unittest.TestCase): def test_listensockstr(self): @@ -203,8 +217,21 @@ class Tests(unittest.TestCase): # Bind to pt listener lsock = await listensockstr(pttarg, ptsockaccept) + nfs = [] + event = asyncio.Event() + + async def runnf(rdr, wrr): + print('a') + a = await NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg) + + print('b') + nfs.append(a) + print('c') + event.set() + print('d') + # Setup server listener - ssock = await listensockstr(servarg, lambda rdr, wrr: NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg)) + ssock = await listensockstr(servarg, runnf) # Connect to server reader, writer = await connectsockstr(servarg) @@ -283,8 +310,11 @@ class Tests(unittest.TestCase): self.assertEqual(rptmsg, ptmsg) # shut everything down - ptsock[0][1].write_eof() writer.write_eof() + #ptsock[0][1].write_eof() # XXX - how to sync? - await asyncio.sleep(1) + await asyncio.sleep(.1) + + await event.wait() + print(repr(nfs))