| @@ -1,3 +1,4 @@ | |||
| from contextlib import asynccontextmanager | |||
| from cryptography.hazmat.backends import default_backend | |||
| from cryptography.hazmat.primitives import hashes | |||
| from cryptography.hazmat.primitives import serialization | |||
| @@ -651,34 +652,40 @@ class TestMain(unittest.TestCase): | |||
| shutil.rmtree(self.basetempdir) | |||
| self.tempdir = None | |||
| @async_test | |||
| async def test_noargs(self): | |||
| proc = await self.run_with_args() | |||
| await proc.wait() | |||
| # XXX - not checking error message | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(proc.returncode, 5) | |||
| def run_with_args(self, *args, pipes=True): | |||
| @asynccontextmanager | |||
| async def run_with_args(self, *args, pipes=True): | |||
| kwargs = {} | |||
| if pipes: | |||
| kwargs.update(dict( | |||
| stdout=asyncio.subprocess.PIPE, | |||
| stderr=asyncio.subprocess.PIPE)) | |||
| return asyncio.create_subprocess_exec(sys.executable, | |||
| aproc = asyncio.create_subprocess_exec(sys.executable, | |||
| # XXX - figure out how to add coverage data on these runs | |||
| #'-m', 'coverage', 'run', '-p', | |||
| __file__, *args, **kwargs) | |||
| async def genkey(self, name): | |||
| proc = await self.run_with_args('genkey', name, pipes=False) | |||
| try: | |||
| proc = await aproc | |||
| yield proc | |||
| finally: | |||
| if proc.returncode is None: | |||
| proc.terminate() | |||
| await proc.wait() | |||
| @async_test | |||
| async def test_noargs(self): | |||
| async with self.run_with_args() as proc: | |||
| await proc.wait() | |||
| self.assertEqual(proc.returncode, 0) | |||
| # XXX - not checking error message | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(proc.returncode, 5) | |||
| async def genkey(self, name): | |||
| async with self.run_with_args('genkey', name, pipes=False) as proc: | |||
| await proc.wait() | |||
| self.assertEqual(proc.returncode, 0) | |||
| @async_test | |||
| async def test_loadpubkey(self): | |||
| @@ -690,7 +697,8 @@ class TestMain(unittest.TestCase): | |||
| enc = serialization.Encoding.Raw | |||
| pubformat = serialization.PublicFormat.Raw | |||
| pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat) | |||
| pubkeybytes = privkey.public_key().public_bytes(encoding=enc, | |||
| format=pubformat) | |||
| pubkey = loadpubkeyraw(keypath + '.pub') | |||
| @@ -702,7 +710,8 @@ class TestMain(unittest.TestCase): | |||
| privformat = serialization.PrivateFormat.Raw | |||
| encalgo = serialization.NoEncryption() | |||
| rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) | |||
| rprivrawkey = privkey.private_bytes(encoding=enc, | |||
| format=privformat, encryption_algorithm=encalgo) | |||
| self.assertEqual(rprivrawkey, privrawkey) | |||
| @@ -738,53 +747,54 @@ class TestMain(unittest.TestCase): | |||
| lsock = await listensockstr(servtargstr, None) | |||
| # Startup the server | |||
| server = await self.run_with_args('server', | |||
| wserver = self.run_with_args('server', | |||
| '-c', clientkeypath + '.pub', | |||
| servkeypath, incservstr, servtargstr) | |||
| # Startup the client with the "bad" key | |||
| client = await self.run_with_args('client', | |||
| badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr) | |||
| wclient = self.run_with_args('client', badclientkeypath, | |||
| servkeypath + '.pub', ptclientstr, incservstr) | |||
| # wait for server target to be created | |||
| await _awaitfile(servtargpath) | |||
| async with wserver as server, wclient as client: | |||
| # wait for server target to be created | |||
| await _awaitfile(servtargpath) | |||
| # wait for server to start | |||
| await _awaitfile(incservpath) | |||
| # wait for server to start | |||
| await _awaitfile(incservpath) | |||
| # wait for client to start | |||
| await _awaitfile(ptclientpath) | |||
| # wait for client to start | |||
| await _awaitfile(ptclientpath) | |||
| # Connect to the client | |||
| reader, writer = await connectsockstr(ptclientstr) | |||
| # Connect to the client | |||
| reader, writer = await connectsockstr(ptclientstr) | |||
| # XXX - this might not be the best test. | |||
| with self.assertRaises(asyncio.futures.TimeoutError): | |||
| # make sure that we don't get the conenction | |||
| await asyncio.wait_for(ptsockevent.wait(), .5) | |||
| # XXX - this might not be the best test. | |||
| with self.assertRaises(asyncio.futures.TimeoutError): | |||
| # make sure that we don't get the conenction | |||
| await asyncio.wait_for(ptsockevent.wait(), .5) | |||
| writer.close() | |||
| writer.close() | |||
| # Make sure that when the server is terminated | |||
| server.terminate() | |||
| # Make sure that when the server is terminated | |||
| server.terminate() | |||
| # that it's stderr | |||
| stdout, stderr = await server.communicate() | |||
| #print('s:', repr((stdout, stderr))) | |||
| # that it's stderr | |||
| stdout, stderr = await server.communicate() | |||
| #print('s:', repr((stdout, stderr))) | |||
| # doesn't have an exceptions never retrieved | |||
| # even the example echo server has this same leak | |||
| #self.assertNotIn(b'Task exception was never retrieved', stderr) | |||
| # doesn't have an exceptions never retrieved | |||
| # even the example echo server has this same leak | |||
| #self.assertNotIn(b'Task exception was never retrieved', stderr) | |||
| lsock.close() | |||
| await lsock.wait_closed() | |||
| lsock.close() | |||
| await lsock.wait_closed() | |||
| # Kill off the client | |||
| client.terminate() | |||
| # Kill off the client | |||
| client.terminate() | |||
| stdout, stderr = await client.communicate() | |||
| #print('s:', repr((stdout, stderr))) | |||
| # XXX - figure out how to clean up client properly | |||
| stdout, stderr = await client.communicate() | |||
| #print('s:', repr((stdout, stderr))) | |||
| # XXX - figure out how to clean up client properly | |||
| @async_test | |||
| async def test_end2end(self): | |||
| @@ -817,72 +827,73 @@ class TestMain(unittest.TestCase): | |||
| lsock = await listensockstr(servtargstr, ptsockaccept) | |||
| # Startup the server | |||
| server = await self.run_with_args('server', | |||
| wserver = self.run_with_args('server', | |||
| '-c', clientkeypath + '.pub', | |||
| servkeypath, incservstr, servtargstr, | |||
| pipes=False) | |||
| # Startup the client | |||
| client = await self.run_with_args('client', | |||
| clientkeypath, servkeypath + '.pub', ptclientstr, incservstr, | |||
| pipes=False) | |||
| wclient = self.run_with_args('client', | |||
| clientkeypath, servkeypath + '.pub', ptclientstr, | |||
| incservstr, pipes=False) | |||
| # wait for server target to be created | |||
| await _awaitfile(servtargpath) | |||
| async with wserver as server, wclient as client: | |||
| # wait for server target to be created | |||
| await _awaitfile(servtargpath) | |||
| # wait for server to start | |||
| await _awaitfile(incservpath) | |||
| # wait for server to start | |||
| await _awaitfile(incservpath) | |||
| # wait for client to start | |||
| await _awaitfile(ptclientpath) | |||
| # wait for client to start | |||
| await _awaitfile(ptclientpath) | |||
| # Connect to the client | |||
| reader, writer = await connectsockstr(ptclientstr) | |||
| # Connect to the client | |||
| reader, writer = await connectsockstr(ptclientstr) | |||
| # send a message | |||
| ptmsg = b'this is a message for testing' | |||
| writer.write(ptmsg) | |||
| # send a message | |||
| ptmsg = b'this is a message for testing' | |||
| writer.write(ptmsg) | |||
| # make sure that we got the conenction | |||
| await ptsockevent.wait() | |||
| # make sure that we got the conenction | |||
| await ptsockevent.wait() | |||
| # get the connection | |||
| endrdr, endwrr = ptsock[0] | |||
| # get the connection | |||
| endrdr, endwrr = ptsock[0] | |||
| # make sure we can read back what we sent | |||
| self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg))) | |||
| # make sure we can read back what we sent | |||
| self.assertEqual(ptmsg, | |||
| await endrdr.readexactly(len(ptmsg))) | |||
| # test some additional messages | |||
| for i in [ 129, 1287, 28792, 129872 ]: | |||
| # in on direction | |||
| msg = os.urandom(i) | |||
| writer.write(msg) | |||
| self.assertEqual(msg, await endrdr.readexactly(len(msg))) | |||
| # test some additional messages | |||
| for i in [ 129, 1287, 28792, 129872 ]: | |||
| # in on direction | |||
| msg = os.urandom(i) | |||
| writer.write(msg) | |||
| self.assertEqual(msg, | |||
| await endrdr.readexactly(len(msg))) | |||
| # and the other | |||
| endwrr.write(msg) | |||
| self.assertEqual(msg, await reader.readexactly(len(msg))) | |||
| # and the other | |||
| endwrr.write(msg) | |||
| self.assertEqual(msg, | |||
| await reader.readexactly(len(msg))) | |||
| writer.close() | |||
| endwrr.close() | |||
| writer.close() | |||
| endwrr.close() | |||
| lsock.close() | |||
| await lsock.wait_closed() | |||
| lsock.close() | |||
| await lsock.wait_closed() | |||
| server.terminate() | |||
| client.terminate() | |||
| # XXX - more clean up testing | |||
| # XXX - more testing that things exited properly | |||
| @async_test | |||
| async def test_genkey(self): | |||
| # that it can generate a key | |||
| proc = await self.run_with_args('genkey', 'somefile') | |||
| await proc.wait() | |||
| async with self.run_with_args('genkey', 'somefile') as proc: | |||
| await proc.wait() | |||
| #print(await proc.communicate()) | |||
| #print(await proc.communicate()) | |||
| self.assertEqual(proc.returncode, 0) | |||
| self.assertEqual(proc.returncode, 0) | |||
| with open('somefile.pub', encoding='ascii') as fp: | |||
| lines = fp.readlines() | |||
| @@ -891,23 +902,25 @@ class TestMain(unittest.TestCase): | |||
| keytype, keyvalue = lines[0].split() | |||
| self.assertEqual(keytype, 'ntun-x448') | |||
| key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) | |||
| key = x448.X448PublicKey.from_public_bytes( | |||
| base64.urlsafe_b64decode(keyvalue)) | |||
| key = loadprivkey('somefile') | |||
| self.assertIsInstance(key, x448.X448PrivateKey) | |||
| # that a second call fails | |||
| proc = await self.run_with_args('genkey', 'somefile') | |||
| await proc.wait() | |||
| async with self.run_with_args('genkey', 'somefile') as proc: | |||
| await proc.wait() | |||
| stdoutdata, stderrdata = await proc.communicate() | |||
| stdoutdata, stderrdata = await proc.communicate() | |||
| self.assertFalse(stdoutdata) | |||
| self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata) | |||
| self.assertFalse(stdoutdata) | |||
| self.assertEqual( | |||
| b'failed to create somefile.pub, file exists.\n', | |||
| stderrdata) | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(proc.returncode, 2) | |||
| # And that it exited w/ the correct code | |||
| self.assertEqual(proc.returncode, 2) | |||
| class TestNoiseFowarder(unittest.TestCase): | |||
| def setUp(self): | |||