| @@ -1,3 +1,4 @@ | |||||
| from contextlib import asynccontextmanager | |||||
| from cryptography.hazmat.backends import default_backend | from cryptography.hazmat.backends import default_backend | ||||
| from cryptography.hazmat.primitives import hashes | from cryptography.hazmat.primitives import hashes | ||||
| from cryptography.hazmat.primitives import serialization | from cryptography.hazmat.primitives import serialization | ||||
| @@ -651,34 +652,40 @@ class TestMain(unittest.TestCase): | |||||
| shutil.rmtree(self.basetempdir) | shutil.rmtree(self.basetempdir) | ||||
| self.tempdir = None | self.tempdir = None | ||||
| @async_test | |||||
| async def test_noargs(self): | |||||
| proc = await self.run_with_args() | |||||
| await proc.wait() | |||||
| # XXX - not checking error message | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(proc.returncode, 5) | |||||
| def run_with_args(self, *args, pipes=True): | |||||
| @asynccontextmanager | |||||
| async def run_with_args(self, *args, pipes=True): | |||||
| kwargs = {} | kwargs = {} | ||||
| if pipes: | if pipes: | ||||
| kwargs.update(dict( | kwargs.update(dict( | ||||
| stdout=asyncio.subprocess.PIPE, | stdout=asyncio.subprocess.PIPE, | ||||
| stderr=asyncio.subprocess.PIPE)) | stderr=asyncio.subprocess.PIPE)) | ||||
| return asyncio.create_subprocess_exec(sys.executable, | |||||
| aproc = asyncio.create_subprocess_exec(sys.executable, | |||||
| # XXX - figure out how to add coverage data on these runs | # XXX - figure out how to add coverage data on these runs | ||||
| #'-m', 'coverage', 'run', '-p', | #'-m', 'coverage', 'run', '-p', | ||||
| __file__, *args, **kwargs) | __file__, *args, **kwargs) | ||||
| async def genkey(self, name): | |||||
| proc = await self.run_with_args('genkey', name, pipes=False) | |||||
| try: | |||||
| proc = await aproc | |||||
| yield proc | |||||
| finally: | |||||
| if proc.returncode is None: | |||||
| proc.terminate() | |||||
| await proc.wait() | |||||
| @async_test | |||||
| async def test_noargs(self): | |||||
| async with self.run_with_args() as proc: | |||||
| await proc.wait() | |||||
| self.assertEqual(proc.returncode, 0) | |||||
| # XXX - not checking error message | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(proc.returncode, 5) | |||||
| async def genkey(self, name): | |||||
| async with self.run_with_args('genkey', name, pipes=False) as proc: | |||||
| await proc.wait() | |||||
| self.assertEqual(proc.returncode, 0) | |||||
| @async_test | @async_test | ||||
| async def test_loadpubkey(self): | async def test_loadpubkey(self): | ||||
| @@ -690,7 +697,8 @@ class TestMain(unittest.TestCase): | |||||
| enc = serialization.Encoding.Raw | enc = serialization.Encoding.Raw | ||||
| pubformat = serialization.PublicFormat.Raw | pubformat = serialization.PublicFormat.Raw | ||||
| pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat) | |||||
| pubkeybytes = privkey.public_key().public_bytes(encoding=enc, | |||||
| format=pubformat) | |||||
| pubkey = loadpubkeyraw(keypath + '.pub') | pubkey = loadpubkeyraw(keypath + '.pub') | ||||
| @@ -702,7 +710,8 @@ class TestMain(unittest.TestCase): | |||||
| privformat = serialization.PrivateFormat.Raw | privformat = serialization.PrivateFormat.Raw | ||||
| encalgo = serialization.NoEncryption() | encalgo = serialization.NoEncryption() | ||||
| rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) | |||||
| rprivrawkey = privkey.private_bytes(encoding=enc, | |||||
| format=privformat, encryption_algorithm=encalgo) | |||||
| self.assertEqual(rprivrawkey, privrawkey) | self.assertEqual(rprivrawkey, privrawkey) | ||||
| @@ -738,53 +747,54 @@ class TestMain(unittest.TestCase): | |||||
| lsock = await listensockstr(servtargstr, None) | lsock = await listensockstr(servtargstr, None) | ||||
| # Startup the server | # Startup the server | ||||
| server = await self.run_with_args('server', | |||||
| wserver = self.run_with_args('server', | |||||
| '-c', clientkeypath + '.pub', | '-c', clientkeypath + '.pub', | ||||
| servkeypath, incservstr, servtargstr) | servkeypath, incservstr, servtargstr) | ||||
| # Startup the client with the "bad" key | # Startup the client with the "bad" key | ||||
| client = await self.run_with_args('client', | |||||
| badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr) | |||||
| wclient = self.run_with_args('client', badclientkeypath, | |||||
| servkeypath + '.pub', ptclientstr, incservstr) | |||||
| # wait for server target to be created | |||||
| await _awaitfile(servtargpath) | |||||
| async with wserver as server, wclient as client: | |||||
| # wait for server target to be created | |||||
| await _awaitfile(servtargpath) | |||||
| # wait for server to start | |||||
| await _awaitfile(incservpath) | |||||
| # wait for server to start | |||||
| await _awaitfile(incservpath) | |||||
| # wait for client to start | |||||
| await _awaitfile(ptclientpath) | |||||
| # wait for client to start | |||||
| await _awaitfile(ptclientpath) | |||||
| # Connect to the client | |||||
| reader, writer = await connectsockstr(ptclientstr) | |||||
| # Connect to the client | |||||
| reader, writer = await connectsockstr(ptclientstr) | |||||
| # XXX - this might not be the best test. | |||||
| with self.assertRaises(asyncio.futures.TimeoutError): | |||||
| # make sure that we don't get the conenction | |||||
| await asyncio.wait_for(ptsockevent.wait(), .5) | |||||
| # XXX - this might not be the best test. | |||||
| with self.assertRaises(asyncio.futures.TimeoutError): | |||||
| # make sure that we don't get the conenction | |||||
| await asyncio.wait_for(ptsockevent.wait(), .5) | |||||
| writer.close() | |||||
| writer.close() | |||||
| # Make sure that when the server is terminated | |||||
| server.terminate() | |||||
| # Make sure that when the server is terminated | |||||
| server.terminate() | |||||
| # that it's stderr | |||||
| stdout, stderr = await server.communicate() | |||||
| #print('s:', repr((stdout, stderr))) | |||||
| # that it's stderr | |||||
| stdout, stderr = await server.communicate() | |||||
| #print('s:', repr((stdout, stderr))) | |||||
| # doesn't have an exceptions never retrieved | |||||
| # even the example echo server has this same leak | |||||
| #self.assertNotIn(b'Task exception was never retrieved', stderr) | |||||
| # doesn't have an exceptions never retrieved | |||||
| # even the example echo server has this same leak | |||||
| #self.assertNotIn(b'Task exception was never retrieved', stderr) | |||||
| lsock.close() | |||||
| await lsock.wait_closed() | |||||
| lsock.close() | |||||
| await lsock.wait_closed() | |||||
| # Kill off the client | |||||
| client.terminate() | |||||
| # Kill off the client | |||||
| client.terminate() | |||||
| stdout, stderr = await client.communicate() | |||||
| #print('s:', repr((stdout, stderr))) | |||||
| # XXX - figure out how to clean up client properly | |||||
| stdout, stderr = await client.communicate() | |||||
| #print('s:', repr((stdout, stderr))) | |||||
| # XXX - figure out how to clean up client properly | |||||
| @async_test | @async_test | ||||
| async def test_end2end(self): | async def test_end2end(self): | ||||
| @@ -817,72 +827,73 @@ class TestMain(unittest.TestCase): | |||||
| lsock = await listensockstr(servtargstr, ptsockaccept) | lsock = await listensockstr(servtargstr, ptsockaccept) | ||||
| # Startup the server | # Startup the server | ||||
| server = await self.run_with_args('server', | |||||
| wserver = self.run_with_args('server', | |||||
| '-c', clientkeypath + '.pub', | '-c', clientkeypath + '.pub', | ||||
| servkeypath, incservstr, servtargstr, | servkeypath, incservstr, servtargstr, | ||||
| pipes=False) | pipes=False) | ||||
| # Startup the client | # Startup the client | ||||
| client = await self.run_with_args('client', | |||||
| clientkeypath, servkeypath + '.pub', ptclientstr, incservstr, | |||||
| pipes=False) | |||||
| wclient = self.run_with_args('client', | |||||
| clientkeypath, servkeypath + '.pub', ptclientstr, | |||||
| incservstr, pipes=False) | |||||
| # wait for server target to be created | |||||
| await _awaitfile(servtargpath) | |||||
| async with wserver as server, wclient as client: | |||||
| # wait for server target to be created | |||||
| await _awaitfile(servtargpath) | |||||
| # wait for server to start | |||||
| await _awaitfile(incservpath) | |||||
| # wait for server to start | |||||
| await _awaitfile(incservpath) | |||||
| # wait for client to start | |||||
| await _awaitfile(ptclientpath) | |||||
| # wait for client to start | |||||
| await _awaitfile(ptclientpath) | |||||
| # Connect to the client | |||||
| reader, writer = await connectsockstr(ptclientstr) | |||||
| # Connect to the client | |||||
| reader, writer = await connectsockstr(ptclientstr) | |||||
| # send a message | |||||
| ptmsg = b'this is a message for testing' | |||||
| writer.write(ptmsg) | |||||
| # send a message | |||||
| ptmsg = b'this is a message for testing' | |||||
| writer.write(ptmsg) | |||||
| # make sure that we got the conenction | |||||
| await ptsockevent.wait() | |||||
| # make sure that we got the conenction | |||||
| await ptsockevent.wait() | |||||
| # get the connection | |||||
| endrdr, endwrr = ptsock[0] | |||||
| # get the connection | |||||
| endrdr, endwrr = ptsock[0] | |||||
| # make sure we can read back what we sent | |||||
| self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg))) | |||||
| # make sure we can read back what we sent | |||||
| self.assertEqual(ptmsg, | |||||
| await endrdr.readexactly(len(ptmsg))) | |||||
| # test some additional messages | |||||
| for i in [ 129, 1287, 28792, 129872 ]: | |||||
| # in on direction | |||||
| msg = os.urandom(i) | |||||
| writer.write(msg) | |||||
| self.assertEqual(msg, await endrdr.readexactly(len(msg))) | |||||
| # test some additional messages | |||||
| for i in [ 129, 1287, 28792, 129872 ]: | |||||
| # in on direction | |||||
| msg = os.urandom(i) | |||||
| writer.write(msg) | |||||
| self.assertEqual(msg, | |||||
| await endrdr.readexactly(len(msg))) | |||||
| # and the other | |||||
| endwrr.write(msg) | |||||
| self.assertEqual(msg, await reader.readexactly(len(msg))) | |||||
| # and the other | |||||
| endwrr.write(msg) | |||||
| self.assertEqual(msg, | |||||
| await reader.readexactly(len(msg))) | |||||
| writer.close() | |||||
| endwrr.close() | |||||
| writer.close() | |||||
| endwrr.close() | |||||
| lsock.close() | |||||
| await lsock.wait_closed() | |||||
| lsock.close() | |||||
| await lsock.wait_closed() | |||||
| server.terminate() | |||||
| client.terminate() | |||||
| # XXX - more clean up testing | |||||
| # XXX - more testing that things exited properly | |||||
| @async_test | @async_test | ||||
| async def test_genkey(self): | async def test_genkey(self): | ||||
| # that it can generate a key | # that it can generate a key | ||||
| proc = await self.run_with_args('genkey', 'somefile') | |||||
| await proc.wait() | |||||
| async with self.run_with_args('genkey', 'somefile') as proc: | |||||
| await proc.wait() | |||||
| #print(await proc.communicate()) | |||||
| #print(await proc.communicate()) | |||||
| self.assertEqual(proc.returncode, 0) | |||||
| self.assertEqual(proc.returncode, 0) | |||||
| with open('somefile.pub', encoding='ascii') as fp: | with open('somefile.pub', encoding='ascii') as fp: | ||||
| lines = fp.readlines() | lines = fp.readlines() | ||||
| @@ -891,23 +902,25 @@ class TestMain(unittest.TestCase): | |||||
| keytype, keyvalue = lines[0].split() | keytype, keyvalue = lines[0].split() | ||||
| self.assertEqual(keytype, 'ntun-x448') | self.assertEqual(keytype, 'ntun-x448') | ||||
| key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) | |||||
| key = x448.X448PublicKey.from_public_bytes( | |||||
| base64.urlsafe_b64decode(keyvalue)) | |||||
| key = loadprivkey('somefile') | key = loadprivkey('somefile') | ||||
| self.assertIsInstance(key, x448.X448PrivateKey) | self.assertIsInstance(key, x448.X448PrivateKey) | ||||
| # that a second call fails | # that a second call fails | ||||
| proc = await self.run_with_args('genkey', 'somefile') | |||||
| await proc.wait() | |||||
| async with self.run_with_args('genkey', 'somefile') as proc: | |||||
| await proc.wait() | |||||
| stdoutdata, stderrdata = await proc.communicate() | |||||
| stdoutdata, stderrdata = await proc.communicate() | |||||
| self.assertFalse(stdoutdata) | |||||
| self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata) | |||||
| self.assertFalse(stdoutdata) | |||||
| self.assertEqual( | |||||
| b'failed to create somefile.pub, file exists.\n', | |||||
| stderrdata) | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(proc.returncode, 2) | |||||
| # And that it exited w/ the correct code | |||||
| self.assertEqual(proc.returncode, 2) | |||||
| class TestNoiseFowarder(unittest.TestCase): | class TestNoiseFowarder(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||