diff --git a/README.md b/README.md index 6e69c59..60795e7 100644 --- a/README.md +++ b/README.md @@ -21,3 +21,10 @@ Note that I have not tested this w/ pypy3, as when compiling the cryptography libraries, it would pick the wrong ones, despite setting CFLAGS and LDFLAGS. It is likely I could make this work, but do not know how to. + +TODO +---- + +- DoS protection. Limiting number of connections. Limit resource + consumption by opening connection and starting negotiation but not + completing it, etc. diff --git a/ntunnel.py b/ntunnel.py index 921a587..2ffa262 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -17,11 +17,42 @@ import shutil import socket import sys import tempfile +import time import threading import unittest _backend = default_backend() +def loadprivkey(fname): + with open(fname, encoding='ascii') as fp: + data = fp.read().encode('ascii') + key = load_pem_private_key(data, password=None, backend=default_backend()) + + return key + +def loadprivkeyraw(fname): + key = loadprivkey(fname) + + enc = serialization.Encoding.Raw + privformat = serialization.PrivateFormat.Raw + encalgo = serialization.NoEncryption() + + return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) + +def loadpubkeyraw(fname): + with open(fname, encoding='ascii') as fp: + lines = fp.readlines() + + # XXX + #self.assertEqual(len(lines), 1) + + keytype, keyvalue = lines[0].split() + + if keytype != 'ntun-x448': + raise RuntimeError + + return base64.urlsafe_b64decode(keyvalue) + def genkeypair(): '''Generates a keypair, and returns a tuple of (public, private). They are encoded as raw bytes, and sutible for use w/ Noise.''' @@ -188,7 +219,11 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): # traceback.print_exc() # raise finally: - writer.write_eof() + try: + writer.write_eof() + except OSError as e: + if e.errno != 57: + raise async def encses(): try: @@ -213,15 +248,23 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): return await asyncio.gather(decses(), encses()) # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code -# Slightly modified to timeout +# Slightly modified to timeout and to print trace back when canceled. +# This makes it easier to figure out what "froze". def async_test(f): def wrapper(*args, **kwargs): - coro = asyncio.coroutine(f) - future = coro(*args, **kwargs) + async def tbcapture(): + try: + return await f(*args, **kwargs) + except asyncio.CancelledError as e: + # if we are going to be cancelled, print out a tb + import traceback + traceback.print_exc() + raise + loop = asyncio.get_event_loop() - # timeout after 2 seconds - loop.run_until_complete(asyncio.wait_for(future, 2)) + # timeout after 4 seconds + loop.run_until_complete(asyncio.wait_for(tbcapture(), 4)) return wrapper class Tests_misc(unittest.TestCase): @@ -243,10 +286,42 @@ class Tests_misc(unittest.TestCase): self.assertEqual(len(msg), dec(enc(msg) + msg)) def cmd_client(args): - pass + privkey = loadprivkeyraw(args.clientkey) + pubkey = loadpubkeyraw(args.servkey) + async def runnf(rdr, wrr): + encpair = asyncio.create_task(connectsockstr(args.clienttarget)) + + a = await NoiseForwarder('init', + encpair, _makefut((rdr, wrr)), + priv_key=privkey, pub_key=pubkey) + + # Setup client listener + ssock = listensockstr(args.clientlisten, runnf) + + loop = asyncio.get_event_loop() + + obj = loop.run_until_complete(ssock) + + loop.run_until_complete(obj.serve_forever()) def cmd_server(args): - pass + privkey = loadprivkeyraw(args.servkey) + + async def runnf(rdr, wrr): + ptpair = asyncio.create_task(connectsockstr(args.servtarget)) + + a = await NoiseForwarder('resp', + _makefut((rdr, wrr)), ptpair, + priv_key=privkey) + + # Setup server listener + ssock = listensockstr(args.servlisten, runnf) + + loop = asyncio.get_event_loop() + + obj = loop.run_until_complete(ssock) + + loop.run_until_complete(obj.serve_forever()) def cmd_genkey(args): keypair = genkeypair() @@ -318,6 +393,12 @@ def _asyncsockpair(): return asyncio.open_connection(sock=socka), \ asyncio.open_connection(sock=sockb) +async def _awaitfile(fname): + while not os.path.exists(fname): + await asyncio.sleep(.01) + + return True + class TestMain(unittest.TestCase): def setUp(self): # setup temporary directory @@ -333,6 +414,7 @@ class TestMain(unittest.TestCase): os.chdir(self.tempdir) def tearDown(self): + #print('td:', time.time()) shutil.rmtree(self.basetempdir) self.tempdir = None @@ -359,45 +441,120 @@ class TestMain(unittest.TestCase): __file__, *args, **kwargs) async def genkey(self, name): - proc = await self.run_with_args('genkey', name) + proc = await self.run_with_args('genkey', name, pipes=False) await proc.wait() self.assertEqual(proc.returncode, 0) @async_test - async def test_both(self): + async def test_loadpubkey(self): + keypath = os.path.join(self.tempdir, 'loadpubkeytest') + await self.genkey(keypath) + + privkey = loadprivkey(keypath) + + enc = serialization.Encoding.Raw + pubformat = serialization.PublicFormat.Raw + + pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat) + + pubkey = loadpubkeyraw(keypath + '.pub') + + self.assertEqual(pubkeybytes, pubkey) + + privrawkey = loadprivkeyraw(keypath) + + enc = serialization.Encoding.Raw + privformat = serialization.PrivateFormat.Raw + encalgo = serialization.NoEncryption() + + rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo) + + self.assertEqual(rprivrawkey, privrawkey) + + @async_test + async def test_end2end(self): # Generate necessar keys - await self.genkey('server_key') - await self.genkey('client_key') + servkeypath = os.path.join(self.tempdir, 'server_key') + await self.genkey(servkeypath) + clientkeypath = os.path.join(self.tempdir, 'client_key') + await self.genkey(clientkeypath) - ptclientstr = _makeunix(os.path.join(self.tempdir, 'incclient.sock')) - incservstr = _makeunix(os.path.join(self.tempdir, 'incserv.sock')) - servtargstr = _makeunix(os.path.join(self.tempdir, 'servtarget.sock')) + await asyncio.sleep(.1) + #import pdb; pdb.set_trace() - # Setup pt target listener - pttarg = _makeunix('servtarget.sock') + # forwards connectsion to this socket (created by client) + ptclientpath = os.path.join(self.tempdir, 'incclient.sock') + ptclientstr = _makeunix(ptclientpath) + + # this is the socket server listen to + incservpath = os.path.join(self.tempdir, 'incserv.sock') + incservstr = _makeunix(incservpath) + + # to this socket, opened by server + servtargpath = os.path.join(self.tempdir, 'servtarget.sock') + servtargstr = _makeunix(servtargpath) + + # Setup server target listener ptsock = [] + ptsockevent = asyncio.Event() def ptsockaccept(reader, writer, ptsock=ptsock): ptsock.append((reader, writer)) + ptsockevent.set() # Bind to pt listener - lsock = await listensockstr(pttarg, ptsockaccept) + lsock = await listensockstr(servtargstr, ptsockaccept) # Startup the server server = await self.run_with_args('server', - '-c', 'client_key.pub', - 'server_key', incservstr, servtargstr, + '-c', clientkeypath + '.pub', + servkeypath, incservstr, servtargstr, pipes=False) # Startup the client - server = await self.run_with_args('client', - 'client_key', 'server_key.pub', ptclientstr, incservstr, + client = await self.run_with_args('client', + clientkeypath, servkeypath + '.pub', ptclientstr, incservstr, pipes=False) + # wait for server target to be created + await _awaitfile(servtargpath) + + # wait for server to start + await _awaitfile(incservpath) + + # wait for client to start + await _awaitfile(ptclientpath) + + await asyncio.sleep(.1) + # Connect to the client reader, writer = await connectsockstr(ptclientstr) + # send a message + ptmsg = b'this is a message for testing' + writer.write(ptmsg) + + # make sure that we got the conenction + await ptsockevent.wait() + + # get the connection + endrdr, endwrr = ptsock[0] + + # make sure we can read back what we sent + self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg))) + + # test some additional messages + for i in [ 129, 1287, 28792, 129872 ]: + # in on direction + msg = os.urandom(i) + writer.write(msg) + self.assertEqual(msg, await endrdr.readexactly(len(msg))) + + # and the other + endwrr.write(msg) + self.assertEqual(msg, await reader.readexactly(len(msg))) + @async_test async def test_genkey(self): # that it can generate a key @@ -418,10 +575,8 @@ class TestMain(unittest.TestCase): self.assertEqual(keytype, 'ntun-x448') key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue)) - with open('somefile', encoding='ascii') as fp: - data = fp.read().encode('ascii') - key = load_pem_private_key(data, password=None, backend=default_backend()) - self.assertIsInstance(key, x448.X448PrivateKey) + key = loadprivkey('somefile') + self.assertIsInstance(key, x448.X448PrivateKey) # that a second call fails proc = await self.run_with_args('genkey', 'somefile')