Browse Source

use an asynccontextmanager to make sure that subprocesses are terminated

tags/v0.1.0
John-Mark Gurney 6 years ago
parent
commit
74ff15da8c
1 changed files with 114 additions and 101 deletions
  1. +114
    -101
      ntunnel.py

+ 114
- 101
ntunnel.py View File

@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
@@ -651,34 +652,40 @@ class TestMain(unittest.TestCase):
shutil.rmtree(self.basetempdir) shutil.rmtree(self.basetempdir)
self.tempdir = None self.tempdir = None


@async_test
async def test_noargs(self):
proc = await self.run_with_args()

await proc.wait()

# XXX - not checking error message

# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 5)

def run_with_args(self, *args, pipes=True):
@asynccontextmanager
async def run_with_args(self, *args, pipes=True):
kwargs = {} kwargs = {}
if pipes: if pipes:
kwargs.update(dict( kwargs.update(dict(
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)) stderr=asyncio.subprocess.PIPE))
return asyncio.create_subprocess_exec(sys.executable,
aproc = asyncio.create_subprocess_exec(sys.executable,
# XXX - figure out how to add coverage data on these runs # XXX - figure out how to add coverage data on these runs
#'-m', 'coverage', 'run', '-p', #'-m', 'coverage', 'run', '-p',
__file__, *args, **kwargs) __file__, *args, **kwargs)


async def genkey(self, name):
proc = await self.run_with_args('genkey', name, pipes=False)
try:
proc = await aproc
yield proc
finally:
if proc.returncode is None:
proc.terminate()


await proc.wait()
@async_test
async def test_noargs(self):
async with self.run_with_args() as proc:
await proc.wait()


self.assertEqual(proc.returncode, 0)
# XXX - not checking error message

# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 5)

async def genkey(self, name):
async with self.run_with_args('genkey', name, pipes=False) as proc:
await proc.wait()

self.assertEqual(proc.returncode, 0)


@async_test @async_test
async def test_loadpubkey(self): async def test_loadpubkey(self):
@@ -690,7 +697,8 @@ class TestMain(unittest.TestCase):
enc = serialization.Encoding.Raw enc = serialization.Encoding.Raw
pubformat = serialization.PublicFormat.Raw pubformat = serialization.PublicFormat.Raw


pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)
pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
format=pubformat)


pubkey = loadpubkeyraw(keypath + '.pub') pubkey = loadpubkeyraw(keypath + '.pub')


@@ -702,7 +710,8 @@ class TestMain(unittest.TestCase):
privformat = serialization.PrivateFormat.Raw privformat = serialization.PrivateFormat.Raw
encalgo = serialization.NoEncryption() encalgo = serialization.NoEncryption()


rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
rprivrawkey = privkey.private_bytes(encoding=enc,
format=privformat, encryption_algorithm=encalgo)


self.assertEqual(rprivrawkey, privrawkey) self.assertEqual(rprivrawkey, privrawkey)


@@ -738,53 +747,54 @@ class TestMain(unittest.TestCase):
lsock = await listensockstr(servtargstr, None) lsock = await listensockstr(servtargstr, None)


# Startup the server # Startup the server
server = await self.run_with_args('server',
wserver = self.run_with_args('server',
'-c', clientkeypath + '.pub', '-c', clientkeypath + '.pub',
servkeypath, incservstr, servtargstr) servkeypath, incservstr, servtargstr)


# Startup the client with the "bad" key # Startup the client with the "bad" key
client = await self.run_with_args('client',
badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)
wclient = self.run_with_args('client', badclientkeypath,
servkeypath + '.pub', ptclientstr, incservstr)


# wait for server target to be created
await _awaitfile(servtargpath)
async with wserver as server, wclient as client:
# wait for server target to be created
await _awaitfile(servtargpath)


# wait for server to start
await _awaitfile(incservpath)
# wait for server to start
await _awaitfile(incservpath)


# wait for client to start
await _awaitfile(ptclientpath)
# wait for client to start
await _awaitfile(ptclientpath)


# Connect to the client
reader, writer = await connectsockstr(ptclientstr)
# Connect to the client
reader, writer = await connectsockstr(ptclientstr)


# XXX - this might not be the best test.
with self.assertRaises(asyncio.futures.TimeoutError):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)
# XXX - this might not be the best test.
with self.assertRaises(asyncio.futures.TimeoutError):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)


writer.close()
writer.close()


# Make sure that when the server is terminated
server.terminate()
# Make sure that when the server is terminated
server.terminate()


# that it's stderr
stdout, stderr = await server.communicate()
#print('s:', repr((stdout, stderr)))
# that it's stderr
stdout, stderr = await server.communicate()
#print('s:', repr((stdout, stderr)))


# doesn't have an exceptions never retrieved
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)
# doesn't have an exceptions never retrieved
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)


lsock.close()
await lsock.wait_closed()
lsock.close()
await lsock.wait_closed()


# Kill off the client
client.terminate()
# Kill off the client
client.terminate()


stdout, stderr = await client.communicate()
#print('s:', repr((stdout, stderr)))
# XXX - figure out how to clean up client properly
stdout, stderr = await client.communicate()
#print('s:', repr((stdout, stderr)))
# XXX - figure out how to clean up client properly


@async_test @async_test
async def test_end2end(self): async def test_end2end(self):
@@ -817,72 +827,73 @@ class TestMain(unittest.TestCase):
lsock = await listensockstr(servtargstr, ptsockaccept) lsock = await listensockstr(servtargstr, ptsockaccept)


# Startup the server # Startup the server
server = await self.run_with_args('server',
wserver = self.run_with_args('server',
'-c', clientkeypath + '.pub', '-c', clientkeypath + '.pub',
servkeypath, incservstr, servtargstr, servkeypath, incservstr, servtargstr,
pipes=False) pipes=False)


# Startup the client # Startup the client
client = await self.run_with_args('client',
clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
pipes=False)
wclient = self.run_with_args('client',
clientkeypath, servkeypath + '.pub', ptclientstr,
incservstr, pipes=False)


# wait for server target to be created
await _awaitfile(servtargpath)
async with wserver as server, wclient as client:
# wait for server target to be created
await _awaitfile(servtargpath)


# wait for server to start
await _awaitfile(incservpath)
# wait for server to start
await _awaitfile(incservpath)


# wait for client to start
await _awaitfile(ptclientpath)
# wait for client to start
await _awaitfile(ptclientpath)


# Connect to the client
reader, writer = await connectsockstr(ptclientstr)
# Connect to the client
reader, writer = await connectsockstr(ptclientstr)


# send a message
ptmsg = b'this is a message for testing'
writer.write(ptmsg)
# send a message
ptmsg = b'this is a message for testing'
writer.write(ptmsg)


# make sure that we got the conenction
await ptsockevent.wait()
# make sure that we got the conenction
await ptsockevent.wait()


# get the connection
endrdr, endwrr = ptsock[0]
# get the connection
endrdr, endwrr = ptsock[0]


# make sure we can read back what we sent
self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))
# make sure we can read back what we sent
self.assertEqual(ptmsg,
await endrdr.readexactly(len(ptmsg)))


# test some additional messages
for i in [ 129, 1287, 28792, 129872 ]:
# in on direction
msg = os.urandom(i)
writer.write(msg)
self.assertEqual(msg, await endrdr.readexactly(len(msg)))
# test some additional messages
for i in [ 129, 1287, 28792, 129872 ]:
# in on direction
msg = os.urandom(i)
writer.write(msg)
self.assertEqual(msg,
await endrdr.readexactly(len(msg)))


# and the other
endwrr.write(msg)
self.assertEqual(msg, await reader.readexactly(len(msg)))
# and the other
endwrr.write(msg)
self.assertEqual(msg,
await reader.readexactly(len(msg)))


writer.close()
endwrr.close()
writer.close()
endwrr.close()


lsock.close()
await lsock.wait_closed()
lsock.close()
await lsock.wait_closed()


server.terminate()
client.terminate()
# XXX - more clean up testing
# XXX - more testing that things exited properly


@async_test @async_test
async def test_genkey(self): async def test_genkey(self):
# that it can generate a key # that it can generate a key
proc = await self.run_with_args('genkey', 'somefile')

await proc.wait()
async with self.run_with_args('genkey', 'somefile') as proc:
await proc.wait()


#print(await proc.communicate())
#print(await proc.communicate())


self.assertEqual(proc.returncode, 0)
self.assertEqual(proc.returncode, 0)


with open('somefile.pub', encoding='ascii') as fp: with open('somefile.pub', encoding='ascii') as fp:
lines = fp.readlines() lines = fp.readlines()
@@ -891,23 +902,25 @@ class TestMain(unittest.TestCase):
keytype, keyvalue = lines[0].split() keytype, keyvalue = lines[0].split()


self.assertEqual(keytype, 'ntun-x448') self.assertEqual(keytype, 'ntun-x448')
key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))
key = x448.X448PublicKey.from_public_bytes(
base64.urlsafe_b64decode(keyvalue))


key = loadprivkey('somefile') key = loadprivkey('somefile')
self.assertIsInstance(key, x448.X448PrivateKey) self.assertIsInstance(key, x448.X448PrivateKey)


# that a second call fails # that a second call fails
proc = await self.run_with_args('genkey', 'somefile')

await proc.wait()
async with self.run_with_args('genkey', 'somefile') as proc:
await proc.wait()


stdoutdata, stderrdata = await proc.communicate()
stdoutdata, stderrdata = await proc.communicate()


self.assertFalse(stdoutdata)
self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)
self.assertFalse(stdoutdata)
self.assertEqual(
b'failed to create somefile.pub, file exists.\n',
stderrdata)


# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 2)
# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 2)


class TestNoiseFowarder(unittest.TestCase): class TestNoiseFowarder(unittest.TestCase):
def setUp(self): def setUp(self):


Loading…
Cancel
Save