Browse Source

use an asynccontextmanager to make sure that subprocesses are terminated

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
74ff15da8c
1 changed files with 114 additions and 101 deletions
  1. +114
    -101
      ntunnel.py

+ 114
- 101
ntunnel.py View File

@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
@@ -651,34 +652,40 @@ class TestMain(unittest.TestCase):
shutil.rmtree(self.basetempdir)
self.tempdir = None

@async_test
async def test_noargs(self):
proc = await self.run_with_args()

await proc.wait()

# XXX - not checking error message

# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 5)

def run_with_args(self, *args, pipes=True):
@asynccontextmanager
async def run_with_args(self, *args, pipes=True):
kwargs = {}
if pipes:
kwargs.update(dict(
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE))
return asyncio.create_subprocess_exec(sys.executable,
aproc = asyncio.create_subprocess_exec(sys.executable,
# XXX - figure out how to add coverage data on these runs
#'-m', 'coverage', 'run', '-p',
__file__, *args, **kwargs)

async def genkey(self, name):
proc = await self.run_with_args('genkey', name, pipes=False)
try:
proc = await aproc
yield proc
finally:
if proc.returncode is None:
proc.terminate()

await proc.wait()
@async_test
async def test_noargs(self):
async with self.run_with_args() as proc:
await proc.wait()

self.assertEqual(proc.returncode, 0)
# XXX - not checking error message

# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 5)

async def genkey(self, name):
async with self.run_with_args('genkey', name, pipes=False) as proc:
await proc.wait()

self.assertEqual(proc.returncode, 0)

@async_test
async def test_loadpubkey(self):
@@ -690,7 +697,8 @@ class TestMain(unittest.TestCase):
enc = serialization.Encoding.Raw
pubformat = serialization.PublicFormat.Raw

pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)
pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
format=pubformat)

pubkey = loadpubkeyraw(keypath + '.pub')

@@ -702,7 +710,8 @@ class TestMain(unittest.TestCase):
privformat = serialization.PrivateFormat.Raw
encalgo = serialization.NoEncryption()

rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
rprivrawkey = privkey.private_bytes(encoding=enc,
format=privformat, encryption_algorithm=encalgo)

self.assertEqual(rprivrawkey, privrawkey)

@@ -738,53 +747,54 @@ class TestMain(unittest.TestCase):
lsock = await listensockstr(servtargstr, None)

# Startup the server
server = await self.run_with_args('server',
wserver = self.run_with_args('server',
'-c', clientkeypath + '.pub',
servkeypath, incservstr, servtargstr)

# Startup the client with the "bad" key
client = await self.run_with_args('client',
badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)
wclient = self.run_with_args('client', badclientkeypath,
servkeypath + '.pub', ptclientstr, incservstr)

# wait for server target to be created
await _awaitfile(servtargpath)
async with wserver as server, wclient as client:
# wait for server target to be created
await _awaitfile(servtargpath)

# wait for server to start
await _awaitfile(incservpath)
# wait for server to start
await _awaitfile(incservpath)

# wait for client to start
await _awaitfile(ptclientpath)
# wait for client to start
await _awaitfile(ptclientpath)

# Connect to the client
reader, writer = await connectsockstr(ptclientstr)
# Connect to the client
reader, writer = await connectsockstr(ptclientstr)

# XXX - this might not be the best test.
with self.assertRaises(asyncio.futures.TimeoutError):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)
# XXX - this might not be the best test.
with self.assertRaises(asyncio.futures.TimeoutError):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), .5)

writer.close()
writer.close()

# Make sure that when the server is terminated
server.terminate()
# Make sure that when the server is terminated
server.terminate()

# that it's stderr
stdout, stderr = await server.communicate()
#print('s:', repr((stdout, stderr)))
# that it's stderr
stdout, stderr = await server.communicate()
#print('s:', repr((stdout, stderr)))

# doesn't have an exceptions never retrieved
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)
# doesn't have an exceptions never retrieved
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)

lsock.close()
await lsock.wait_closed()
lsock.close()
await lsock.wait_closed()

# Kill off the client
client.terminate()
# Kill off the client
client.terminate()

stdout, stderr = await client.communicate()
#print('s:', repr((stdout, stderr)))
# XXX - figure out how to clean up client properly
stdout, stderr = await client.communicate()
#print('s:', repr((stdout, stderr)))
# XXX - figure out how to clean up client properly

@async_test
async def test_end2end(self):
@@ -817,72 +827,73 @@ class TestMain(unittest.TestCase):
lsock = await listensockstr(servtargstr, ptsockaccept)

# Startup the server
server = await self.run_with_args('server',
wserver = self.run_with_args('server',
'-c', clientkeypath + '.pub',
servkeypath, incservstr, servtargstr,
pipes=False)

# Startup the client
client = await self.run_with_args('client',
clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
pipes=False)
wclient = self.run_with_args('client',
clientkeypath, servkeypath + '.pub', ptclientstr,
incservstr, pipes=False)

# wait for server target to be created
await _awaitfile(servtargpath)
async with wserver as server, wclient as client:
# wait for server target to be created
await _awaitfile(servtargpath)

# wait for server to start
await _awaitfile(incservpath)
# wait for server to start
await _awaitfile(incservpath)

# wait for client to start
await _awaitfile(ptclientpath)
# wait for client to start
await _awaitfile(ptclientpath)

# Connect to the client
reader, writer = await connectsockstr(ptclientstr)
# Connect to the client
reader, writer = await connectsockstr(ptclientstr)

# send a message
ptmsg = b'this is a message for testing'
writer.write(ptmsg)
# send a message
ptmsg = b'this is a message for testing'
writer.write(ptmsg)

# make sure that we got the conenction
await ptsockevent.wait()
# make sure that we got the conenction
await ptsockevent.wait()

# get the connection
endrdr, endwrr = ptsock[0]
# get the connection
endrdr, endwrr = ptsock[0]

# make sure we can read back what we sent
self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))
# make sure we can read back what we sent
self.assertEqual(ptmsg,
await endrdr.readexactly(len(ptmsg)))

# test some additional messages
for i in [ 129, 1287, 28792, 129872 ]:
# in on direction
msg = os.urandom(i)
writer.write(msg)
self.assertEqual(msg, await endrdr.readexactly(len(msg)))
# test some additional messages
for i in [ 129, 1287, 28792, 129872 ]:
# in on direction
msg = os.urandom(i)
writer.write(msg)
self.assertEqual(msg,
await endrdr.readexactly(len(msg)))

# and the other
endwrr.write(msg)
self.assertEqual(msg, await reader.readexactly(len(msg)))
# and the other
endwrr.write(msg)
self.assertEqual(msg,
await reader.readexactly(len(msg)))

writer.close()
endwrr.close()
writer.close()
endwrr.close()

lsock.close()
await lsock.wait_closed()
lsock.close()
await lsock.wait_closed()

server.terminate()
client.terminate()
# XXX - more clean up testing
# XXX - more testing that things exited properly

@async_test
async def test_genkey(self):
# that it can generate a key
proc = await self.run_with_args('genkey', 'somefile')

await proc.wait()
async with self.run_with_args('genkey', 'somefile') as proc:
await proc.wait()

#print(await proc.communicate())
#print(await proc.communicate())

self.assertEqual(proc.returncode, 0)
self.assertEqual(proc.returncode, 0)

with open('somefile.pub', encoding='ascii') as fp:
lines = fp.readlines()
@@ -891,23 +902,25 @@ class TestMain(unittest.TestCase):
keytype, keyvalue = lines[0].split()

self.assertEqual(keytype, 'ntun-x448')
key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))
key = x448.X448PublicKey.from_public_bytes(
base64.urlsafe_b64decode(keyvalue))

key = loadprivkey('somefile')
self.assertIsInstance(key, x448.X448PrivateKey)

# that a second call fails
proc = await self.run_with_args('genkey', 'somefile')

await proc.wait()
async with self.run_with_args('genkey', 'somefile') as proc:
await proc.wait()

stdoutdata, stderrdata = await proc.communicate()
stdoutdata, stderrdata = await proc.communicate()

self.assertFalse(stdoutdata)
self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)
self.assertFalse(stdoutdata)
self.assertEqual(
b'failed to create somefile.pub, file exists.\n',
stderrdata)

# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 2)
# And that it exited w/ the correct code
self.assertEqual(proc.returncode, 2)

class TestNoiseFowarder(unittest.TestCase):
def setUp(self):


Loading…
Cancel
Save