- from contextlib import asynccontextmanager
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives.kdf.hkdf import HKDF
- from cryptography.hazmat.primitives.serialization import load_pem_private_key
- from noise.connection import NoiseConnection, Keypair
- #import tracemalloc; tracemalloc.start(100)
- import argparse
- import asyncio
- import base64
- import os.path
- import shutil
- import socket
- import sys
- import tempfile
- import time
- import threading
- import unittest
# Shared cryptography backend, passed to the calls in this module that
# accept a backend argument.
_backend = default_backend()

def loadprivkey(fname):
    '''Load and return the unencrypted PEM private key stored in the
    file fname (the format written by the genkey subcommand).'''
    # PEM is already ASCII armored bytes, so read them directly rather
    # than decoding to text and encoding back again.
    with open(fname, 'rb') as fp:
        data = fp.read()

    return load_pem_private_key(data, password=None, backend=_backend)
def loadprivkeyraw(fname):
    '''Return the raw, unencrypted bytes of the private key in fname.

    The key is loaded w/ loadprivkey and re-serialized in the Raw
    encoding/format, the form that Noise consumes.'''
    return loadprivkey(fname).private_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption())
def loadpubkeyraw(fname):
    '''Load a public key file (as written by the genkey subcommand) and
    return the raw key bytes.

    The first line of the file must be ``ntun-x448 <key>``, where <key>
    is the URL safe base64 encoding of the raw public key.

    Raises RuntimeError if the file is empty, if the first line does not
    have exactly two fields, or if the key type is not ntun-x448.
    '''
    with open(fname, encoding='ascii') as fp:
        lines = fp.readlines()

    # XXX - only the first line is used, any further lines are ignored.
    if not lines:
        raise RuntimeError('no key found in %s' % repr(fname))

    fields = lines[0].split()
    if len(fields) != 2:
        raise RuntimeError('malformed key line in %s' % repr(fname))

    keytype, keyvalue = fields
    if keytype != 'ntun-x448':
        raise RuntimeError('unsupported key type %s in %s' %
            (repr(keytype), repr(fname)))

    return base64.urlsafe_b64decode(keyvalue)
class ConnectionValidator:
    '''Interface NoiseForwarder uses to authorize the remote peer and to
    obtain the connection that the session's data is forwarded to.

    Implementations (subclasses, or any object w/ the same methods) must
    override both coroutine methods below.'''

    async def validatekey(self, hash, key):
        '''Decide whether key, the remote peer's public key, may connect.

        hash identifies the session, so that the authorization can be
        checked again later.  Raising an exception rejects the key;
        returning normally accepts it.'''
        raise NotImplementedError()

    async def getconnection(self, hash, key, **kwargs):
        '''Provide the connection that this client's session should be
        forwarded to.'''
        raise NotImplementedError()
class GenericConnValidator(object):
    '''A simple implementation of the ConnectionValidator interface that
    covers most cases.  It authorizes a key by checking membership in a
    collection of keys.  It obtains the connection by calling/awaiting
    the function provided and returning its value.'''

    def __init__(self, keys, connfun):
        '''The parameter keys must be an object that supports the ``in``
        operator (``__contains__``).  If the key is in the keys object,
        the connection will proceed.

        The parameter connfun is called w/o arguments and must return an
        awaitable (a coroutine from an async function, or a future) that
        resolves to the StreamReader, StreamWriter pair of the connection
        that the session is supposed to use.'''
        self._keys = keys
        self._connfun = connfun

    async def validatekey(self, hash, key):
        '''Raise ValueError unless key is one of the authorized keys.'''
        if key not in self._keys:
            raise ValueError('key not authorized: %s' % repr(key))

    async def getconnection(self, hash, key, **kwargs):
        '''Return the result of awaiting connfun().

        Extra keyword arguments are accepted to match the
        ConnectionValidator.getconnection signature (reserved for future
        protocol versions) and are currently ignored.'''
        return await self._connfun()
def genkeypair():
    '''Create a new X448 keypair and return it as (public, private).

    Both halves are raw, unencrypted bytes, ready to hand to Noise.'''
    raw = serialization.Encoding.Raw
    privkey = x448.X448PrivateKey.generate()

    pubbytes = privkey.public_key().public_bytes(
        encoding=raw, format=serialization.PublicFormat.Raw)
    privbytes = privkey.private_bytes(
        encoding=raw, format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption())

    return pubbytes, privbytes
def _makefut(obj):
    '''Return a future on the running event loop that is already
    resolved w/ obj, so awaiting it yields obj immediately.

    Must be called while an event loop is running.'''
    future = asyncio.get_running_loop().create_future()
    future.set_result(obj)
    return future
def _makeunix(path):
    '''Build the unix domain socket string (see parsesockstr) for path.'''
    return f'unix:{path}'
# Make sure any additions are reflected by tests in test_parsesockstr
# Maps each supported proto to its allowed parameters, and each
# parameter to the callable that converts its string value.
_allowedparameters = {
    'unix': {
        'path': str,
    },
    'tcp': {
        'host': str,
        'port': int,
    },
}

def _parsedefaultparam(proto, rem):
    '''Return the args dict for the bare 'proto:value' form (rem is the
    non-empty text after the colon and contains no '=').'''
    if proto == 'unix':
        if rem[0] != '/':
            raise ValueError('bare path MUST start w/ a slash (/).')
        return { 'path': rem }

    if proto == 'tcp':
        # host[:port]; only split when there is exactly one colon, so a
        # bare IPv6 address is taken as the host.
        host, colon, port = rem.rpartition(':')
        if colon and ':' not in host:
            return { 'host': host, 'port': port }
        return { 'host': rem }

    raise ValueError('proto %s has no default parameter' % repr(proto))

def parsesockstr(sockstr):
    '''Parse a socket string to its parts.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.  If
    there needs to be an equals '=', then you MUST use the extended
    version.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a
            slash if it is used as a default parameter.

        tcp:
            Default parameter is host[:port].
            The host parameter specifies the host, and the
            port parameter specifies the port of the
            connection.  In the default form, a value w/ more
            than one colon (e.g. a bare IPv6 address) is used
            as the host alone.

    Returns a (proto, args) tuple.  args maps each parameter name to its
    value, converted by the callable in _allowedparameters.  For example,
    'tcp:localhost:80' -> ('tcp', {'host': 'localhost', 'port': 80}).

    Raises ValueError in these cases:
      - sockstr has no colon;
      - the proto is unsupported;
      - no value is given;
      - a parameter is malformed or not allowed for the proto;
      - a value cannot be converted.
    '''
    proto, sep, rem = sockstr.partition(':')
    if not sep:
        raise ValueError('missing colon (:) after proto: %s' %
            repr(sockstr))

    try:
        allowed = _allowedparameters[proto]
    except KeyError:
        raise ValueError('unsupported proto: %s' % repr(proto)) from None

    if not rem:
        raise ValueError('no parameters for proto %s' % repr(proto))

    if '=' not in rem:
        args = _parsedefaultparam(proto, rem)
    else:
        args = {}
        for item in rem.split(','):
            param, eq, value = item.partition('=')
            if not eq:
                raise ValueError('parameter %s is missing an equals (=)' %
                    repr(item))
            args[param] = value

    extrakeys = args.keys() - allowed.keys()
    if extrakeys:
        raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))

    return proto, { k: allowed[k](v) for k, v in args.items() }
async def connectsockstr(sockstr):
    '''Wrapper for asyncio.open_connection/asyncio.open_unix_connection.

    For the format of sockstr, please see parsesockstr.

    Returns the (StreamReader, StreamWriter) pair of the new connection.
    Raises ValueError if sockstr cannot be parsed, or if it names a proto
    that has no connect function here.
    '''
    proto, args = parsesockstr(sockstr)

    # Look the function up only for the proto in use:
    # asyncio.open_unix_connection does not exist on every platform.
    if proto == 'unix':
        fun = asyncio.open_unix_connection
    elif proto == 'tcp':
        fun = asyncio.open_connection
    else:
        # Only reachable if _allowedparameters gains a proto w/o a
        # matching branch above.
        raise ValueError('no connect function for proto: %s' % repr(proto))

    reader, writer = await fun(**args)
    return reader, writer
async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_server/asyncio.start_unix_server.

    For the format of sockstr, please see parsesockstr.

    The cb parameter is passed to asyncio's start_server or related
    calls.  Per those docs, the cb parameter is called or scheduled
    as a task when a client establishes a connection.  It is called
    with two arguments, the reader and writer streams.  For more
    information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server

    Returns the Server object created by asyncio.  Raises ValueError if
    sockstr cannot be parsed, or if it names a proto that has no listen
    function here.
    '''
    proto, args = parsesockstr(sockstr)

    # Look the function up only for the proto in use:
    # asyncio.start_unix_server does not exist on every platform.
    if proto == 'unix':
        fun = asyncio.start_unix_server
    elif proto == 'tcp':
        fun = asyncio.start_server
    else:
        # Only reachable if _allowedparameters gains a proto w/o a
        # matching branch above.
        raise ValueError('no listen function for proto: %s' % repr(proto))

    return await fun(cb, **args)
# Byte lengths of the three Noise XK handshake messages, in order:
# initiator -> responder, responder -> initiator, initiator -> responder.
# NoiseForwarder reads exactly these many bytes for each message it
# receives (its messages are written w/o a payload).
# NOTE(review): the next comment looks like a vim `!!` filter command
# that regenerates the list below, which is presumably why the list sits
# alone on a continuation line -- keep that layout.
# !!python makemessagelengths.py
_handshakelens = \
    [72, 72, 88]
def _genciphfun(hash, ad):
    '''Return (encfun, decfun), a pair of functions that obscure and
    recover the 2 byte length prefix of a framed message.

    A 32 byte AES key is derived from hash (NoiseForwarder passes the
    session's handshake hash) using HKDF-SHA256 w/ a fixed salt and ad
    as the info.  Each direction therefore gets its own key (NoiseForwarder
    uses ad=b'toresp' / b'toinit').

    encfun(data) -> 2 bytes: len(data) XORed w/ a mask.  The mask is taken
    from the AES-ECB encryption of data's first 16 bytes.

    decfun(data) -> int: the inverse.  data must be the 2 obscured length
    bytes followed by the first 16 bytes of the message they describe.

    Both functions assume at least 16 bytes of message data.  NoiseForwarder
    always frames ciphertexts, which it treats as carrying a 16 byte tag.
    NOTE(review): ECB works on whole 16 byte blocks; confirm how the
    encryptor handles a shorter update before relaxing that assumption.

    This only obscures lengths; integrity and confidentiality come from
    the Noise session itself.
    '''
    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    # A single encryptor is used by both functions: the mask is always
    # computed by *encrypting* ciphertext bytes, on both sides.
    enctor = cipher.encryptor()
    def encfun(data):
        # Returns the two bytes for length
        val = len(data)
        encbytes = enctor.update(data[:16])
        # NOTE(review): `& 0xff` keeps only the low byte of the mask, so
        # the high byte of the length goes out unobscured.  Changing it
        # alters the wire format for both peers (needs a protocol version
        # bump).
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
        return (val ^ mask).to_bytes(length=2, byteorder='big')
    def decfun(data):
        # takes off the data and returns the total
        # length
        val = int.from_bytes(data[:2], byteorder='big')
        # Skip the 2 length bytes; the mask comes from the next 16 bytes,
        # the same bytes encfun encrypted on the sending side.
        encbytes = enctor.update(data[2:2 + 16])
        mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
        return val ^ mask
    return encfun, decfun
async def NoiseForwarder(mode, encrdrwrr, connvalid, priv_key, pub_key=None):
    '''A function that forwards data between the plain text pair of
    streams and the encrypted session.

    The mode parameter must be one of 'init' or 'resp' for initiator
    and responder.  Any other value raises ValueError.

    The encrdrwrr parameter is an awaitable object that will return a
    tuple of the reader and writer streams for the encrypted side of
    the connection.

    The connvalid parameter is an instance of the ConnectionValidator
    class, or one that implements its methods.  The validatekey method
    will be passed the session hash, and the remote public key of the
    client or server.  If the key is not authorized, an exception
    must be raised.  Any non-exception return from the function means
    that the key is authorized, and that the session should continue.

    In the case of the initiator, the server's key will be passed,
    despite the fact that it was already validated by the XK Noise
    Protocol.  This is just to keep the calling convention the same,
    and it supports moving to an XX protocol possibly in the future
    with minimal changes.

    Then the getconnection method will be called.  Its expected return
    is the (reader, writer) pair of the connection to forward the data
    on to.  The kwargs may be used in future protocol versions to allow
    the client to request a specific resource, or something similar.

    In the case of the initiator, pub_key must be provided and will
    be used to authenticate the responder side of the connection.

    The priv_key parameter is used to authenticate this side of the
    session.

    Both priv_key and pub_key parameters must be 56 bytes.  For example,
    the pair that is returned by genkeypair.

    If anything fails before forwarding starts, the encrypted writer is
    closed and the exception is re-raised.  That covers the mode check,
    the handshake, key validation, protocol version negotiation and
    getconnection.

    Returns ['dec', 'enc'] once both directions have seen EOF.  Both
    writers are closed before returning, and also if forwarding fails.
    '''
    import errno

    # Send a protocol version so that in the future we can change how
    # we interface, and possibly be able to send control messages,
    # allow the client to pass some misc data to the callback, or to
    # allow a reverse tunnel, where the client talks to the server,
    # and waits for the server to "connect" to the client w/ a
    # connection, e.g. reverse tunnel out behind a nat to allow
    # incoming connections.
    #
    # Add protocol version to getconnection when it gets bumped
    #
    protocol_version = 0

    def write_eof_quietly(stream):
        '''Half close stream, ignoring the error raised when the socket
        is no longer connected.'''
        try:
            stream.write_eof()
        except OSError as e:
            # ENOTCONN is 57 on macOS/BSD but 107 on Linux, so compare
            # against the symbolic constant.
            if e.errno != errno.ENOTCONN:
                raise

    rdr, wrr = await encrdrwrr

    try:
        if mode not in ('resp', 'init'):
            raise ValueError('mode must be \'init\' or \'resp\': %s' %
                repr(mode))

        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

        proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
        if pub_key is not None:
            proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                pub_key)

        if mode == 'resp':
            proto.set_as_responder()

            proto.start_handshake()

            proto.read_message(await rdr.readexactly(_handshakelens[0]))

            wrr.write(proto.write_message())

            proto.read_message(await rdr.readexactly(_handshakelens[2]))
        else:
            proto.set_as_initiator()

            proto.start_handshake()

            wrr.write(proto.write_message())

            proto.read_message(await rdr.readexactly(_handshakelens[1]))

            wrr.write(proto.write_message())

        if not proto.handshake_finished: # pragma: no cover
            raise RuntimeError('failed to finish handshake')

        sesshash = proto.get_handshake_hash()
        clientkey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC),
            'public_bytes', None)

        await connvalid.validatekey(sesshash, clientkey)

        # generate the keys for lengths
        # XXX - get_handshake_hash is probably not the best option, but
        # this is only to obscure lengths, it is not required to be secure
        # as the underlying NoiseProtocol securely validates everything.
        # It is marginally useful as writing patterns likely expose the
        # true length.  Adding padding could marginally help w/ this.
        if mode == 'resp':
            _, declenfun = _genciphfun(sesshash, b'toresp')
            enclenfun, _ = _genciphfun(sesshash, b'toinit')
        else:
            enclenfun, _ = _genciphfun(sesshash, b'toresp')
            _, declenfun = _genciphfun(sesshash, b'toinit')

        # protocol negotiation

        # send first, then wait for the response
        pvmsg = protocol_version.to_bytes(1, byteorder='big')
        encmsg = proto.encrypt(pvmsg)
        wrr.write(enclenfun(encmsg))
        wrr.write(encmsg)

        # get the protocol version: 2 length bytes plus the first 16
        # ciphertext bytes, which declenfun needs to recover the length
        msg = await rdr.readexactly(2 + 16)
        tlen = declenfun(msg)
        rmsg = await rdr.readexactly(tlen - 16)
        tmsg = msg[2:] + rmsg
        rpv = proto.decrypt(tmsg)

        rempv = int.from_bytes(rpv, byteorder='big')
        if rempv != protocol_version:
            raise RuntimeError('unsupported protocol version received: %d' %
                rempv)

        reader, writer = await connvalid.getconnection(sesshash, clientkey)
    except Exception:
        wrr.close()
        raise

    async def decses():
        # encrypted side -> plain text side
        try:
            while True:
                try:
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError:
                    if rdr.at_eof():
                        return 'dec'
                    # never continue w/ a stale or unbound msg
                    raise

                tlen = declenfun(msg)
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg

                writer.write(proto.decrypt(tmsg))
                await writer.drain()
        finally:
            write_eof_quietly(writer)

    async def encses():
        # plain text side -> encrypted side
        try:
            while True:
                # largest plain text whose ciphertext (plus 16 byte tag)
                # still fits the 2 byte length
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)
                await wrr.drain()
        finally:
            write_eof_quietly(wrr)

    try:
        res = await asyncio.gather(decses(), encses())

        await wrr.drain() # not sure if needed
        await writer.drain() # not sure if needed
    finally:
        # close both sides even if forwarding failed
        wrr.close()
        writer.close()

    return res
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
# This makes it easier to figure out what "froze".
def async_test(f):
    '''Decorator that runs the coroutine function f to completion on the
    current event loop, so it can be used as a unittest test method.

    The run times out after 4 seconds.  If the coroutine is cancelled
    (e.g. by that timeout), a traceback is printed before the
    cancellation propagates.  The wrapper returns None.'''
    import functools
    import traceback

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def tbcapture():
            try:
                return await f(*args, **kwargs)
            except asyncio.CancelledError:
                # if we are going to be cancelled, print out a tb
                traceback.print_exc()
                raise

        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Pythons raise here instead of implicitly creating a
            # loop when none is set; create and install one ourselves.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        # timeout after 4 seconds
        loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
    return wrapper
class Tests_misc(unittest.TestCase):
    def setUp(self):
        # Remember where we were so tearDown can go back.  A previous
        # test may have left the process in a directory that has since
        # been deleted, so fall back to the system temp dir.
        try:
            self.oldcwd = os.getcwd()
        except FileNotFoundError:
            self.oldcwd = tempfile.gettempdir()

        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        os.chdir(self.tempdir)

    def tearDown(self):
        #print('td:', time.time())
        # Leave the temp dir before deleting it, so the process isn't
        # stranded in a removed cwd (some platforms refuse to remove the
        # current directory at all).
        os.chdir(self.oldcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    def test_parsesockstr_bad(self):
        badstrs = [
            'unix:ff',
            'randomnocolon',
            'unix:somethingelse=bogus',
            'tcp:port=bogus',
        ]

        for i in badstrs:
            with self.assertRaises(ValueError,
                msg='Should have failed processing: %s' % repr(i)):
                parsesockstr(i)

    def test_parsesockstr(self):
        results = {
            # Not all of these are valid when passed to a *sockstr
            # function
            'unix:/apath': ('unix', { 'path': '/apath' }),
            'unix:path=apath': ('unix', { 'path': 'apath' }),
            'tcp:host=apath': ('tcp', { 'host': 'apath' }),
            'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
                'port': 5 }),
        }

        for s, r in results.items():
            self.assertEqual(parsesockstr(s), r)

    @async_test
    async def test_listensockstr_bad(self):
        with self.assertRaises(ValueError):
            await listensockstr('bogus:some=arg', None)

        with self.assertRaises(ValueError):
            await connectsockstr('bogus:some=arg')

    @async_test
    async def test_listenconnectsockstr(self):
        msgsent = b'this is a test message'
        msgrcv = b'testing message for receive'

        # That when a connection is received and receives and sends
        async def servconfhandle(rdr, wrr):
            msg = await rdr.readexactly(len(msgsent))
            self.assertEqual(msg, msgsent)
            #print(repr(wrr.get_extra_info('sockname')))
            wrr.write(msgrcv)
            await wrr.drain()
            wrr.close()

            return True

        # Test listensockstr
        for sstr, confun in [
            ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
            ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
            ]:
            # that listensockstr will bind to the correct path, can call cb
            ls = await listensockstr(sstr, servconfhandle)

            # that we open a connection to the path
            rdr, wrr = await confun()

            # and send a message
            wrr.write(msgsent)

            # and receive the message
            rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
            self.assertEqual(rcv, msgrcv)

            wrr.close()

            # Now test that connectsockstr works similarly.
            rdr, wrr = await connectsockstr(sstr)

            # and send a message
            wrr.write(msgsent)

            # and receive the message
            rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
            self.assertEqual(rcv, msgrcv)

            wrr.close()

            ls.close()
            await ls.wait_closed()

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))
def cmd_client(args):
    '''Run the client subcommand.

    Listens on args.clientlisten.  Each accepted connection is forwarded
    through an encrypted Noise session (initiator side) opened to
    args.clienttarget.  This side authenticates w/ the private key in
    args.clientkey, and the server is authenticated w/ the public key in
    args.servkey.  Runs until the process is stopped.'''
    privkey = loadprivkeyraw(args.clientkey)
    pubkey = loadpubkeyraw(args.servkey)

    async def runnf(rdr, wrr):
        connval = GenericConnValidator([ pubkey ],
            lambda: _makefut((rdr, wrr)))

        encpair = asyncio.create_task(connectsockstr(args.clienttarget))

        try:
            await NoiseForwarder('init', encpair, connval,
                priv_key=privkey, pub_key=pubkey)
        finally:
            # NoiseForwarder closes this writer itself once forwarding
            # has started.  Closing it here also covers sessions that fail
            # before that (closing twice is harmless).
            wrr.close()

    async def serve():
        # Setup client listener
        server = await listensockstr(args.clientlisten, runnf)
        async with server:
            await server.serve_forever()

    asyncio.run(serve())
def cmd_server(args):
    '''Run the server subcommand.

    Listens on args.servlisten and completes the responder side of a
    Noise session for each accepted connection.  The decrypted data is
    forwarded to a new connection to args.servtarget.  The server
    authenticates w/ the private key in args.servkey.  Only clients whose
    public key was loaded from one of the args.clientkey files are
    accepted.  Runs until the process is stopped.'''
    privkey = loadprivkeyraw(args.servkey)
    pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]

    async def runnf(rdr, wrr):
        connval = GenericConnValidator(pubkeys,
            lambda: connectsockstr(args.servtarget))

        try:
            await NoiseForwarder('resp', _makefut((rdr, wrr)),
                connval, priv_key=privkey)
        finally:
            # Make sure the encrypted connection is closed however the
            # session ends (closing twice is harmless).
            wrr.close()

    async def serve():
        # Setup server listener
        server = await listensockstr(args.servlisten, runnf)
        async with server:
            await server.serve_forever()

    asyncio.run(serve())
def cmd_genkey(args):
    '''Run the genkey subcommand: create a new X448 key pair.

    The public key is written to args.fname + '.pub' as a single
    ``ntun-x448 <urlsafe base64>`` line.  The private key is written to
    args.fname as unencrypted PKCS8 PEM, readable only by the owner.

    Neither file is overwritten.  If either one already exists, an error
    is printed to stderr and the process exits w/ status 2.  In that case
    a public key file created by this run is removed again.
    '''
    key = x448.X448PrivateKey.generate()

    # public key part
    pub = key.public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw)

    pubfname = args.fname + '.pub'
    try:
        with open(pubfname, 'x', encoding='ascii') as fp:
            print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
    except FileExistsError:
        print('failed to create %s, file exists.' % pubfname, file=sys.stderr)
        sys.exit(2)

    # private key part
    privpem = key.private_bytes(encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption())

    # Create exclusively, w/ owner only permissions, as the key is
    # stored unencrypted.
    try:
        fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL,
            0o600)
    except FileExistsError:
        # don't leave a public key w/o its private key behind
        os.unlink(pubfname)
        print('failed to create %s, file exists.' % args.fname, file=sys.stderr)
        sys.exit(2)

    with os.fdopen(fd, 'w', encoding='ascii') as fp:
        fp.write(privpem.decode('ascii'))
def main():
    '''Command line entry point: parse sys.argv and dispatch to the
    selected subcommand's cmd_* function.

    If no subcommand is given, the usage is printed and the process
    exits w/ status 5.'''
    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')

    parser_gk = subparsers.add_parser('genkey', help='generate keys')
    parser_gk.add_argument('fname', type=str, help='file name for the key')
    parser_gk.set_defaults(func=cmd_genkey)

    parser_serv = subparsers.add_parser('server', help='run a server')
    # required: w/o any authorized client key, cmd_server has nothing to
    # authorize against (and would fail iterating None).
    parser_serv.add_argument('--clientkey', '-c', action='append', type=str, required=True, help='file of authorized client keys, or a .pub file')
    parser_serv.add_argument('servkey', type=str, help='file name for the server key')
    parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
    parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
    parser_serv.set_defaults(func=cmd_server)

    parser_client = subparsers.add_parser('client', help='run a client')
    parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
    parser_client.add_argument('servkey', type=str, help='file name for the server public key')
    parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
    parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
    parser_client.set_defaults(func=cmd_client)

    args = parser.parse_args()

    # No subcommand given: subparsers don't set func.
    try:
        fun = args.func
    except AttributeError:
        parser.print_usage()
        sys.exit(5)

    fun(args)

if __name__ == '__main__': # pragma: no cover
    main()
def _asyncsockpair():
    '''Make two connected sockets and wrap each one for asyncio.

    Returns a 2-tuple of coroutines.  Awaiting either one yields the
    (reader, writer) pair for its end of the connection.'''
    ends = socket.socketpair()
    return tuple(asyncio.open_connection(sock=end) for end in ends)
async def _awaitfile(fname):
    '''Poll every 10ms until fname exists, then return True.  There is
    no timeout of its own.'''
    exists = os.path.exists
    while True:
        if exists(fname):
            return True
        await asyncio.sleep(.01)
class TestMain(unittest.TestCase):
    def setUp(self):
        # Remember where we were so tearDown can go back.  A previous
        # test may have left the process in a directory that has since
        # been deleted, so fall back to the system temp dir.
        try:
            self.oldcwd = os.getcwd()
        except FileNotFoundError:
            self.oldcwd = tempfile.gettempdir()

        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        os.chdir(self.tempdir)

    def tearDown(self):
        #print('td:', time.time())
        # Leave the temp dir before deleting it (see setUp).
        os.chdir(self.oldcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @asynccontextmanager
    async def run_with_args(self, *args, pipes=True):
        '''Run this file as a subprocess w/ args and yield the process.

        When pipes is True, stdout and stderr are captured.  On exit the
        process is terminated if it is still running, then waited on.'''
        kwargs = {}
        if pipes:
            kwargs.update(dict(
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE))

        # Spawned outside the try: if spawning fails, there is no process
        # to clean up, and the original error must not be masked.
        proc = await asyncio.create_subprocess_exec(sys.executable,
            # XXX - figure out how to add coverage data on these runs
            #'-m', 'coverage', 'run', '-p',
            __file__, *args, **kwargs)

        try:
            yield proc
        finally:
            if proc.returncode is None:
                try:
                    proc.terminate()
                except ProcessLookupError:
                    # it exited on its own after the check above
                    pass

            # Make sure that process exits before continuing
            await proc.wait()

    @async_test
    async def test_noargs(self):
        async with self.run_with_args() as proc:
            # communicate drains the pipes while waiting, so a chatty
            # child cannot block on a full pipe
            await proc.communicate()

            # XXX - not checking error message

            # And that it exited w/ the correct code
            self.assertEqual(proc.returncode, 5)

    async def genkey(self, name):
        async with self.run_with_args('genkey', name, pipes=False) as proc:
            await proc.wait()
            self.assertEqual(proc.returncode, 0)

    @async_test
    async def test_loadpubkey(self):
        keypath = os.path.join(self.tempdir, 'loadpubkeytest')
        await self.genkey(keypath)

        privkey = loadprivkey(keypath)

        enc = serialization.Encoding.Raw
        pubformat = serialization.PublicFormat.Raw
        pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
            format=pubformat)

        pubkey = loadpubkeyraw(keypath + '.pub')

        self.assertEqual(pubkeybytes, pubkey)

        privrawkey = loadprivkeyraw(keypath)

        enc = serialization.Encoding.Raw
        privformat = serialization.PrivateFormat.Raw
        encalgo = serialization.NoEncryption()
        rprivrawkey = privkey.private_bytes(encoding=enc,
            format=privformat, encryption_algorithm=encalgo)

        self.assertEqual(rprivrawkey, privrawkey)

    @async_test
    async def test_clientkeymismatch(self):
        # make sure that if there's a client key mismatch, we
        # don't connect

        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)
        badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
        await self.genkey(badclientkeypath)

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket server listen to
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsockevent = asyncio.Event()

        def ptsockaccept(reader, writer):
            # Reaching the target means the bad key was accepted; record
            # it so the check below can fail.
            ptsockevent.set()
            writer.close()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, ptsockaccept)

        # Startup the server
        wserver = self.run_with_args('server',
            '-c', clientkeypath + '.pub',
            servkeypath, incservstr, servtargstr)

        # Startup the client with the "bad" key
        wclient = self.run_with_args('client', badclientkeypath,
            servkeypath + '.pub', ptclientstr, incservstr)

        async with wserver as server, wclient as client:
            # wait for server target to be created
            await _awaitfile(servtargpath)

            # wait for server to start
            await _awaitfile(incservpath)

            # wait for client to start
            await _awaitfile(ptclientpath)

            # Connect to the client
            reader, writer = await connectsockstr(ptclientstr)

            # XXX - this might not be the best test.
            with self.assertRaises(asyncio.TimeoutError):
                # make sure that we don't get the connection
                await asyncio.wait_for(ptsockevent.wait(), .5)

            writer.close()

            # Make sure that when the server is terminated
            server.terminate()

            # that it's stderr
            stdout, stderr = await server.communicate()
            #print('s:', repr((stdout, stderr)))

            # doesn't have an exceptions never retrieved
            # even the example echo server has this same leak
            #self.assertNotIn(b'Task exception was never retrieved', stderr)

            lsock.close()
            await lsock.wait_closed()

            # Kill off the client
            client.terminate()
            stdout, stderr = await client.communicate()
            #print('s:', repr((stdout, stderr)))

            # XXX - figure out how to clean up client properly

    @async_test
    async def test_end2end(self):
        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket server listen to
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, ptsockaccept)

        # Startup the server
        wserver = self.run_with_args('server',
            '-c', clientkeypath + '.pub',
            servkeypath, incservstr, servtargstr,
            pipes=False)

        # Startup the client
        wclient = self.run_with_args('client',
            clientkeypath, servkeypath + '.pub', ptclientstr,
            incservstr, pipes=False)

        async with wserver as server, wclient as client:
            # wait for server target to be created
            await _awaitfile(servtargpath)

            # wait for server to start
            await _awaitfile(incservpath)

            # wait for client to start
            await _awaitfile(ptclientpath)

            # Connect to the client
            reader, writer = await connectsockstr(ptclientstr)

            # send a message
            ptmsg = b'this is a message for testing'
            writer.write(ptmsg)

            # make sure that we got the connection
            await ptsockevent.wait()

            # get the connection
            endrdr, endwrr = ptsock[0]

            # make sure we can read back what we sent
            self.assertEqual(ptmsg,
                await endrdr.readexactly(len(ptmsg)))

            # test some additional messages
            for i in [ 129, 1287, 28792, 129872 ]:
                # in on direction
                msg = os.urandom(i)
                writer.write(msg)
                self.assertEqual(msg,
                    await endrdr.readexactly(len(msg)))

                # and the other
                endwrr.write(msg)
                self.assertEqual(msg,
                    await reader.readexactly(len(msg)))

            writer.close()
            endwrr.close()

            lsock.close()
            await lsock.wait_closed()

            # XXX - more testing that things exited properly

    @async_test
    async def test_genkey(self):
        # that it can generate a key
        async with self.run_with_args('genkey', 'somefile') as proc:
            await proc.communicate()
            self.assertEqual(proc.returncode, 0)

        with open('somefile.pub', encoding='ascii') as fp:
            lines = fp.readlines()
            self.assertEqual(len(lines), 1)

            keytype, keyvalue = lines[0].split()

            self.assertEqual(keytype, 'ntun-x448')
            key = x448.X448PublicKey.from_public_bytes(
                base64.urlsafe_b64decode(keyvalue))

        key = loadprivkey('somefile')
        self.assertIsInstance(key, x448.X448PrivateKey)

        # that a second call fails
        async with self.run_with_args('genkey', 'somefile') as proc:
            # communicate waits for exit and collects the output
            stdoutdata, stderrdata = await proc.communicate()

            self.assertFalse(stdoutdata)
            self.assertEqual(
                b'failed to create somefile.pub, file exists.\n',
                stderrdata)

            # And that it exited w/ the correct code
            self.assertEqual(proc.returncode, 2)
- class TestNoiseFowarder(unittest.TestCase):
- def setUp(self):
- # setup temporary directory
- d = os.path.realpath(tempfile.mkdtemp())
- self.basetempdir = d
- self.tempdir = os.path.join(d, 'subdir')
- os.mkdir(self.tempdir)
- # Generate key pairs
- self.server_key_pair = genkeypair()
- self.client_key_pair = genkeypair()
- def tearDown(self):
- shutil.rmtree(self.basetempdir)
- self.tempdir = None
    @async_test
    async def test_clientkeymissmatch(self):
        # A responder whose validator authorizes no keys must reject an
        # initiator that completes the XK handshake w/ an unknown static
        # key: NoiseForwarder is expected to raise ValueError (the error
        # GenericConnValidator.validatekey raises).

        # generate a key that is incorrect
        wrongclient_key_pair = genkeypair()

        # the secure socket
        clssockapair, clssockbpair = _asyncsockpair()
        reader, writer = await clssockapair

        # create the server
        # GenericConnValidator([], None) authorizes no keys.  connfun is
        # None because validatekey rejects the key before getconnection
        # would call it.
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, GenericConnValidator([], None),
            priv_key=self.server_key_pair[1]))

        # Create client
        # (the initiator is driven by hand rather than w/ NoiseForwarder)
        proto = NoiseConnection.from_name(
            b'Noise_XK_448_ChaChaPoly_SHA256')

        proto.set_as_initiator()

        # Setup wrong client key
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            wrongclient_key_pair[1])

        # but the correct server key
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        # (this message carries the initiator's static key in XK, which is
        # what the responder then validates)
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        with self.assertRaises(ValueError):
            await servnf

        writer.close()
    @async_test
    async def test_server(self):
        # End to end test of the responder: the initiator's half of the
        # protocol (handshake, length-obscured framing, protocol version)
        # is driven by hand, so this pins the wire format that
        # NoiseForwarder('resp') expects.

        # Test is plumbed:
        # (reader, writer) -> servsock ->
        # (rdr, wrr) NoiseForward (reader, writer) ->
        # servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        # results returned by NoiseForwarder, set once it finishes
        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            connval = GenericConnValidator(
                [ self.client_key_pair[0] ],
                lambda: connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), connval,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        # (as the initiator: b'toresp' obscures what we send, b'toinit'
        # recovers what we receive, mirroring NoiseForwarder's 'init' mode)
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
            b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(),
            b'toinit')

        pversion = 0
        # Send the protocol version string first
        # Each frame is the 2 obscured length bytes followed by the
        # ciphertext.
        encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # Read the peer's protocol version
        # find out how much we need to read
        # (2 length bytes + the first 16 ciphertext bytes, which
        # declenfun needs to recover the length)
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)
        self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)

        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # wait for the connection to arrive
        # (NoiseForwarder only calls getconnection after the version
        # exchange, so the target connection appears around now)
        await ptsockevent.wait()

        ptreader, ptwriter = ptsock[0]

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))
        self.assertEqual(rptmsg, ptmsg)

        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))
        self.assertEqual(rptmsg, ptmsg)

        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)
        self.assertEqual(rptmsg, ptmsg)

        # shut down sending
        writer.write_eof()

        # so pt reader should be shut down
        self.assertEqual(b'', await ptreader.read(1))
        self.assertTrue(ptreader.at_eof())

        # shut down pt
        ptwriter.write_eof()

        # make sure the enc reader is eof
        self.assertEqual(b'', await reader.read(1))
        self.assertTrue(reader.at_eof())

        await event.wait()

        # gather() order: decses() result first, then encses()
        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

        writer.close()
        ptwriter.close()

        lsock.close()
        ssock.close()
        await lsock.wait_closed()
        await ssock.wait_closed()
@async_test
async def test_protocolversionmismatch(self):
    '''Check that a peer announcing an unsupported (future) protocol
    version is rejected.

    The client completes the Noise XK handshake and then sends
    protocol version 1.  The responder must still reply with its own
    version (0), and the server side NoiseForwarder must fail with a
    RuntimeError without connecting to the plain text target.

    Test is plumbed:
    (reader, writer) -> servsock ->
    (rdr, wrr) NoiseForwarder (reader, writer) ->
    servptsock -> (ptsock[0], ptsock[1])
    '''
    # Path that the server will sit on
    servsockpath = os.path.join(self.tempdir, 'servsock')
    servarg = _makeunix(servsockpath)

    # Path that the server would send pt data to
    servptpath = os.path.join(self.tempdir, 'servptsock')

    # Setup pt target listener; on a version mismatch nothing should
    # ever connect to it.
    pttarg = _makeunix(servptpath)
    ptsock = []
    ptsockevent = asyncio.Event()

    def ptsockaccept(reader, writer, ptsock=ptsock):
        ptsock.append((reader, writer))
        ptsockevent.set()

    # Bind to pt listener
    lsock = await listensockstr(pttarg, ptsockaccept)

    # Outcome of the server side NoiseForwarder: its return value, or
    # the exception it raised.
    nfs = []
    event = asyncio.Event()

    async def runnf(rdr, wrr):
        # Give the validator a lazy connection factory, as the other
        # tests do.  Starting the connect task eagerly here would open
        # the pt connection even though the forwarder is expected to
        # reject the peer, and would leave a task nobody awaits.
        connval = GenericConnValidator(
            [ self.client_key_pair[0] ],
            lambda: connectsockstr(pttarg))
        try:
            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), connval,
                priv_key=self.server_key_pair[1])
        except Exception as e:
            # Record any failure so the assertion below reports it,
            # instead of the test hanging on event.wait().
            nfs.append(e)
        else:
            nfs.append(a)
        finally:
            event.set()

    # Setup server listener
    ssock = await listensockstr(servarg, runnf)

    # Connect to server
    reader, writer = await connectsockstr(servarg)

    # Create client
    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
    proto.set_as_initiator()

    # Setup required keys
    proto.set_keypair_from_private_bytes(Keypair.STATIC,
        self.client_key_pair[1])
    proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
        self.server_key_pair[0])

    proto.start_handshake()

    # Send first message
    message = proto.write_message()
    self.assertEqual(len(message), _handshakelens[0])
    writer.write(message)

    # Get response
    respmsg = await reader.readexactly(_handshakelens[1])
    proto.read_message(respmsg)

    # Send final reply
    message = proto.write_message()
    writer.write(message)

    # Make sure handshake has completed
    self.assertTrue(proto.handshake_finished)

    # generate the keys for lengths
    enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
        b'toresp')
    _, declenfun = _genciphfun(proto.get_handshake_hash(),
        b'toinit')

    # A version the responder does not support
    pversion = 1

    # Send the protocol version string first
    encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
    writer.write(enclenfun(encmsg))
    writer.write(encmsg)

    # Read the peer's protocol version
    # find out how much we need to read
    encmsg = await reader.readexactly(2 + 16)
    tlen = declenfun(encmsg)

    # read the rest of the message
    rencmsg = await reader.readexactly(tlen - 16)
    tmsg = encmsg[2:] + rencmsg
    rptmsg = proto.decrypt(tmsg)

    # The responder still answers with its own version, 0
    self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)

    # The forwarder must have failed with a RuntimeError
    await event.wait()
    self.assertIsInstance(nfs[0], RuntimeError)

    # and must not have connected to the pt target
    self.assertFalse(ptsockevent.is_set())
    self.assertEqual(ptsock, [])

    # Release the sockets and listeners, like the other tests do
    writer.close()
    lsock.close()
    ssock.close()
    await lsock.wait_closed()
    await ssock.wait_closed()
@async_test
async def test_serverclient(self):
    '''Forward data end to end through a client (initiator) and a
    server (responder) NoiseForwarder connected back to back.

    The test writes on one plain text end and reads the same bytes on
    the other end, in both directions, including a message larger than
    the block size.  It then shuts down both plain text writers and
    checks that EOF propagates and that both forwarder tasks finish.
    '''
    # plumbing:
    #
    # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
    #
    # Each _asyncsockpair() half is awaited to obtain a (reader, writer)
    # pair; the halves handed to the forwarders are passed un-awaited.
    ptcsockapair, ptcsockbpair = _asyncsockpair()
    ptcareader, ptcawriter = await ptcsockapair
    #ptcsockbpair passed directly
    clssockapair, clssockbpair = _asyncsockpair()
    #both passed directly
    ptssockapair, ptssockbpair = _asyncsockpair()
    #ptssockapair passed directly
    ptsbreader, ptsbwriter = await ptssockbpair
    # Each validator is given the peer's public key (presumably the set
    # of keys allowed to connect -- confirm against GenericConnValidator)
    # and a factory returning that side's plain text socket half.
    validateclientside = GenericConnValidator(
        [ self.server_key_pair[0] ], lambda: ptcsockbpair)
    validateserverside = GenericConnValidator(
        [ self.client_key_pair[0] ], lambda: ptssockapair)
    # In Noise XK the initiator must already know the responder's
    # static public key, hence pub_key on the 'init' side only.
    clientnf = asyncio.create_task(NoiseForwarder('init',
        clssockapair, validateclientside,
        priv_key=self.client_key_pair[1],
        pub_key=self.server_key_pair[0]))
    servnf = asyncio.create_task(NoiseForwarder('resp',
        clssockbpair, validateserverside,
        priv_key=self.server_key_pair[1]))
    # send a message
    msga = os.urandom(183)
    ptcawriter.write(msga)
    # make sure we get the same message
    self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
    # send a second message
    msga = os.urandom(2834)
    ptcawriter.write(msga)
    # make sure we get the same message
    self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
    # send a message larger than the block size
    msga = os.urandom(103958)
    ptcawriter.write(msga)
    # make sure we get the same message
    self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
    # send a message the other direction
    msga = os.urandom(103958)
    ptsbwriter.write(msga)
    # make sure we get the same message
    self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
    # close down the pt writers, the rest should follow
    ptsbwriter.write_eof()
    ptcawriter.write_eof()
    # make sure they are closed, and there is no more data
    self.assertEqual(b'', await ptsbreader.read(1))
    self.assertTrue(ptsbreader.at_eof())
    self.assertEqual(b'', await ptcareader.read(1))
    self.assertTrue(ptcareader.at_eof())
    # Both forwarders finish and return ['dec', 'enc'] (presumably
    # the order in which their two directions completed -- confirm
    # against NoiseForwarder).
    self.assertEqual([ 'dec', 'enc' ], await clientnf)
    self.assertEqual([ 'dec', 'enc' ], await servnf)
    # flush anything still buffered before closing the pt writers
    await ptsbwriter.drain()
    await ptcawriter.drain()
    ptsbwriter.close()
    ptcawriter.close()