from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair

#import tracemalloc; tracemalloc.start(100)

import argparse
import asyncio
import base64
import errno
import os.path
import shutil
import socket
import sys
import tempfile
import time
import threading
import unittest

_backend = default_backend()

def loadprivkey(fname):
    '''Load the PEM encoded, unencrypted private key stored in the
    file fname, and return it as a cryptography private key object.'''

    with open(fname, encoding='ascii') as fp:
        data = fp.read().encode('ascii')

    key = load_pem_private_key(data, password=None, backend=_backend)

    return key

def loadprivkeyraw(fname):
    '''Load the private key stored in the file fname (see loadprivkey),
    and return it as raw bytes, suitable for use w/ Noise.'''

    key = loadprivkey(fname)

    enc = serialization.Encoding.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    return key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

def loadpubkeyraw(fname):
    '''Load a public key file, as written by the genkey command, and
    return the raw public key bytes.

    The first line of the file must be of the form:
    'ntun-x448 <urlsafe base64 encoded raw public key>'

    Raises RuntimeError if the file is not in that format.'''

    with open(fname, encoding='ascii') as fp:
        lines = fp.readlines()

    # XXX - only the first line is used, any additional lines are ignored
    try:
        keytype, keyvalue = lines[0].split()
    except (IndexError, ValueError):
        raise RuntimeError('malformed public key file: %s' %
            repr(fname)) from None

    if keytype != 'ntun-x448':
        raise RuntimeError('unsupported key type %s in: %s' %
            (repr(keytype), repr(fname)))

    return base64.urlsafe_b64decode(keyvalue)

def genkeypair():
    '''Generates a keypair, and returns a tuple of (public, private).
    They are encoded as raw bytes, and suitable for use w/ Noise.'''

    key = x448.X448PrivateKey.generate()

    enc = serialization.Encoding.Raw
    pubformat = serialization.PublicFormat.Raw
    privformat = serialization.PrivateFormat.Raw
    encalgo = serialization.NoEncryption()

    pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
    priv = key.private_bytes(encoding=enc, format=privformat,
        encryption_algorithm=encalgo)

    return pub, priv

def _makefut(obj):
    '''Return a future, on the running loop, that is already resolved
    to obj.  Must be called from a coroutine.'''

    loop = asyncio.get_running_loop()
    fut = loop.create_future()
    fut.set_result(obj)

    return fut

def _makeunix(path):
    '''Make a properly formed unix path socket string.'''

    return 'unix:%s' % path

# Make sure any additions are reflected by tests in test_parsesockstr
# Maps each supported proto to its allowed parameters, and the function
# used to convert each parameter's string value.
_allowedparameters = {
    'unix': {
        'path': str,
    },
    'tcp': {
        'host': str,
        'port': int,
    },
}

def parsesockstr(sockstr):
    '''Parse a socket string to its parts.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'.  This is only allowed when the
    value can unambiguously be determined not to be a param.  If the
    value needs to contain an equals '=', then you MUST use the
    extended version.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore.  The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket.  The path MUST start w/ a
            slash if it is used as a default parameter.
        tcp:
            Default parameter is host[:port].
            The host parameter specifies the host, and the
            port parameter specifies the port of the
            connection.  When used as a default parameter,
            the port is split off at the last colon, so a
            bare IPv6 address must use the extended form,
            e.g. 'tcp:host=::1,port=80'.

    Returns a tuple of the proto and a dict of its parameters,
    converted to their proper types.

    Raises ValueError if sockstr is malformed, the proto is not
    supported, or a parameter is unknown or has an invalid value.
    '''

    try:
        proto, rem = sockstr.split(':', 1)
    except ValueError:
        raise ValueError('missing proto in: %s' % repr(sockstr)) from None

    try:
        allowed = _allowedparameters[proto]
    except KeyError:
        raise ValueError('unsupported proto: %s' % repr(proto)) from None

    if '=' not in rem:
        # default parameter form
        if proto == 'unix':
            if not rem.startswith('/'):
                raise ValueError('bare path MUST start w/ a slash (/).')

            args = { 'path': rem }
        elif proto == 'tcp':
            host, sep, port = rem.rpartition(':')
            if sep:
                args = { 'host': host, 'port': port }
            else:
                args = { 'host': rem }
        else: # pragma: no cover
            raise ValueError('proto %s has no default parameter' %
                repr(proto))
    else:
        # a parameter w/o an '=' makes dict() raise ValueError
        args = dict(i.split('=', 1) for i in rem.split(','))

    extrakeys = args.keys() - allowed.keys()
    if extrakeys:
        raise ValueError('keys for proto %s not allowed: %s' %
            (repr(proto), extrakeys))

    # conversion errors (e.g. int('bogus')) raise ValueError
    return proto, { k: allowed[k](v) for k, v in args.items() }

async def connectsockstr(sockstr):
    '''Wrapper for asyncio.open_*_connection.

    For the format of sockstr, please see parsesockstr.  Returns the
    reader and writer streams of the new connection.'''

    proto, args = parsesockstr(sockstr)

    if proto == 'unix':
        fun = asyncio.open_unix_connection
    elif proto == 'tcp':
        fun = asyncio.open_connection
    else: # pragma: no cover
        raise ValueError('unsupported proto: %s' % repr(proto))

    reader, writer = await fun(**args)

    return reader, writer

async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    For the format of sockstr, please see parsesockstr.

    The cb parameter is passed to asyncio's start_server or related
    calls.  Per those docs, the cb parameter is called or scheduled as
    a task when a client establishes a connection.  It is called with
    two arguments, the reader and writer streams.  For more
    information, see:
    https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server

    Returns the server object.
    '''

    proto, args = parsesockstr(sockstr)

    if proto == 'unix':
        fun = asyncio.start_unix_server
    elif proto == 'tcp':
        fun = asyncio.start_server
    else: # pragma: no cover
        raise ValueError('unsupported proto: %s' % repr(proto))

    return await fun(cb, **args)

# !!python makemessagelengths.py
# Lengths of the three XK handshake messages (-> e, es; <- e, ee;
# -> s, se).
_handshakelens = \
    [72, 72, 88]

def _genciphfun(hash, ad):
    '''Generate a pair of functions used to obscure the two byte
    length prefix that precedes each encrypted message.

    The hash parameter is the key material (NoiseForwarder passes the
    handshake hash).  The ad parameter is used as the HKDF info, so
    that each direction can use a different key.

    Returns a tuple (encfun, decfun):

    encfun(data) takes an encrypted message and returns the two bytes
    to send before it.

    decfun(data) takes the two length bytes followed by at least the
    first 16 bytes of the encrypted message, and returns the length of
    the encrypted message.

    The mask is derived from the first 16 bytes of the encrypted
    message.  So every message MUST be at least 16 bytes long, and a
    ValueError is raised otherwise.  Messages from
    NoiseConnection.encrypt always are, because of the 16 byte
    authentication tag.'''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    enctor = cipher.encryptor()

    def _mask(block):
        # Compute the length mask from exactly one AES block.
        #
        # ECB encrypts each block independently.  Feeding exactly one
        # full block through the shared encryptor returns exactly one
        # block and leaves nothing buffered.  A short block would be
        # buffered and corrupt every later mask, so reject it.
        if len(block) != 16:
            raise ValueError('need at least 16 bytes of the encrypted '
                'message to compute the length mask')

        encbytes = enctor.update(block)

        # NOTE(review): only the low byte of the two byte length is
        # masked, the high byte is sent in the clear.  Changing this
        # would change the wire format, so it is left as is.
        return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)

        return (val ^ _mask(data[:16])).to_bytes(length=2, byteorder='big')

    def decfun(data):
        # takes off the data and returns the total
        # length
        val = int.from_bytes(data[:2], byteorder='big')

        return val ^ _mask(data[2:2 + 16])

    return encfun, decfun

async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
    '''Forward data between a plain text pair of streams and an
    encrypted (Noise protocol) session.

    The mode parameter must be one of 'init' or 'resp' for initiator
    and responder.

    The encrdrwrr parameter is an awaitable object that will return a
    tuple of the reader and writer streams for the encrypted side of
    the connection.

    The ptpairfun parameter is a function that will be passed the
    public key bytes for the remote client.  This can be used to both
    validate that the correct client is connecting, and to pass back
    the correct plain text reader/writer objects that match the
    provided static key.  The function must return an awaitable (e.g.
    be an async function) that returns a tuple of the plain text
    reader and writer streams.

    In the case of the initiator, pub_key must be provided and will be
    used to authenticate the responder side of the connection.

    The priv_key parameter is used to authenticate this side of the
    session.

    Both priv_key and pub_key parameters must be 56 bytes.  For
    example, the pair that is returned by genkeypair.

    If the handshake, ptpairfun, or the protocol version negotiation
    fails, the encrypted writer is closed, along w/ the plain text
    writer if it was already obtained, and the exception is
    propagated.  A RuntimeError is raised if the remote side sends an
    unsupported protocol version.

    Once both directions have finished, the streams are closed and the
    list [ 'dec', 'enc' ] is returned.
    '''

    # Send a protocol version so that in the future we can change how
    # we interface, and possibly be able to send control messages,
    # allow the client to pass some misc data to the callback, or to
    # allow a reverse tunnel, where the client talks to the server,
    # and waits for the server to "connect" to the client w/ a
    # connection, e.g. reverse tunnel out behind a nat to allow
    # incoming connections.
    protocol_version = 0

    rdr, wrr = await encrdrwrr

    # The plain text writer.  It is only set once ptpairfun succeeds,
    # so the error handler below knows whether it needs closing.
    writer = None
    try:
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

        proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
        if pub_key is not None:
            proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                pub_key)

        if mode == 'resp':
            proto.set_as_responder()

            proto.start_handshake()

            proto.read_message(await rdr.readexactly(_handshakelens[0]))

            wrr.write(proto.write_message())

            proto.read_message(await rdr.readexactly(_handshakelens[2]))
        elif mode == 'init':
            proto.set_as_initiator()

            proto.start_handshake()

            wrr.write(proto.write_message())

            proto.read_message(await rdr.readexactly(_handshakelens[1]))

            wrr.write(proto.write_message())

        if not proto.handshake_finished: # pragma: no cover
            raise RuntimeError('failed to finish handshake')

        remotekey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC),
            'public_bytes', None)
        reader, writer = await ptpairfun(remotekey)

        # generate the keys for lengths
        # XXX - get_handshake_hash is probably not the best option, but
        # this is only to obscure lengths, it is not required to be secure
        # as the underlying NoiseProtocol securely validates everything.
        # It is marginally useful as writing patterns likely expose the
        # true length.  Adding padding could marginally help w/ this.
        if mode == 'resp':
            _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
            enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
        elif mode == 'init':
            enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
            _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        # protocol negotiation

        # send first, then wait for the response
        pvmsg = protocol_version.to_bytes(1, byteorder='big')
        encmsg = proto.encrypt(pvmsg)
        wrr.write(enclenfun(encmsg))
        wrr.write(encmsg)

        # get the protocol version
        msg = await rdr.readexactly(2 + 16)
        tlen = declenfun(msg)
        rmsg = await rdr.readexactly(tlen - 16)
        tmsg = msg[2:] + rmsg
        rpv = proto.decrypt(tmsg)

        rempv = int.from_bytes(rpv, byteorder='big')
        if rempv != protocol_version:
            raise RuntimeError('unsupported protocol version received: %d' %
                rempv)
    except BaseException:
        # don't leak the connections if we can't set up forwarding
        wrr.close()
        if writer is not None:
            writer.close()
        raise

    async def decses():
        # Decrypt messages from the encrypted side, and write them to
        # the plain text side until the encrypted side hits EOF.
        try:
            while True:
                try:
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError:
                    # NOTE(review): a truncated length header is also
                    # treated as a clean EOF here.
                    if rdr.at_eof():
                        return 'dec'

                    # never continue w/ an unbound or stale msg
                    raise

                tlen = declenfun(msg)
                rmsg = await rdr.readexactly(tlen - 16)
                tmsg = msg[2:] + rmsg
                writer.write(proto.decrypt(tmsg))
                await writer.drain()
        finally:
            try:
                writer.write_eof()
            except OSError as e:
                # the plain text side may already be disconnected
                if e.errno != errno.ENOTCONN:
                    raise

    async def encses():
        # Encrypt data from the plain text side, and write it to the
        # encrypted side until the plain text side hits EOF.
        try:
            while True:
                # largest message
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)
                await wrr.drain()
        finally:
            wrr.write_eof()

    res = await asyncio.gather(decses(), encses())

    await wrr.drain() # not sure if needed
    wrr.close()

    await writer.drain() # not sure if needed
    writer.close()

    return res

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
# This makes it easier to figure out what "froze".
# The event loop shared by all the async tests.  It is created
# explicitly, instead of relying on asyncio.get_event_loop() to create
# one implicitly, which is deprecated (and an error in newer Python
# versions).
_testloop = None

def _gettestloop():
    '''Return the event loop shared by the async tests.  If there is
    none yet, or it was closed, create a new one and set it as the
    current event loop.'''

    global _testloop

    if _testloop is None or _testloop.is_closed():
        _testloop = asyncio.new_event_loop()
        asyncio.set_event_loop(_testloop)

    return _testloop

def async_test(f):
    '''Decorator that runs the coroutine function f (e.g. an async test
    method) to completion on the shared test event loop.

    The test fails w/ a timeout after 4 seconds.  When the test is
    cancelled, as happens on the timeout, the traceback of where it
    was waiting is printed.'''

    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def tbcapture():
            try:
                return await f(*args, **kwargs)
            except asyncio.CancelledError:
                # if we are going to be cancelled, print out a tb
                import traceback
                traceback.print_exc()
                raise

        loop = _gettestloop()

        # timeout after 4 seconds
        loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))

    return wrapper

class Tests_misc(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # remember where we were, so tearDown can go back
        self.origcwd = os.getcwd()
        os.chdir(self.tempdir)

    def tearDown(self):
        #print('td:', time.time())
        # don't leave the process in a deleted directory
        os.chdir(self.origcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    def test_parsesockstr_bad(self):
        badstrs = [
            'unix:ff',
            'randomnocolon',
            'unix:somethingelse=bogus',
            'tcp:port=bogus',
        ]

        for i in badstrs:
            with self.assertRaises(ValueError,
                msg='Should have failed processing: %s' % repr(i)):
                parsesockstr(i)

    def test_parsesockstr(self):
        results = {
            # Not all of these are valid when passed to a *sockstr
            # function
            'unix:/apath': ('unix', { 'path': '/apath' }),
            'unix:path=apath': ('unix', { 'path': 'apath' }),
            'tcp:host=apath': ('tcp', { 'host': 'apath' }),
            'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
                'port': 5 }),
        }

        for s, r in results.items():
            self.assertEqual(parsesockstr(s), r)

    @async_test
    async def test_listensockstr_bad(self):
        with self.assertRaises(ValueError):
            await listensockstr('bogus:some=arg', None)

        with self.assertRaises(ValueError):
            await connectsockstr('bogus:some=arg')

    @async_test
    async def test_listenconnectsockstr(self):
        msgsent = b'this is a test message'
        msgrcv = b'testing message for receive'

        # That when a connection is received and receives and sends
        async def servconfhandle(rdr, wrr):
            msg = await rdr.readexactly(len(msgsent))
            self.assertEqual(msg, msgsent)
            #print(repr(wrr.get_extra_info('sockname')))
            wrr.write(msgrcv)
            await wrr.drain()
            wrr.close()

            return True

        # Test listensockstr
        for sstr, confun in [
            ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
            ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
            ]:
            # that listensockstr will bind to the correct path, can call cb
            ls = await listensockstr(sstr, servconfhandle)

            # that we open a connection to the path
            rdr, wrr = await confun()

            # and send a message
            wrr.write(msgsent)

            # and receive the message
            rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
            self.assertEqual(rcv, msgrcv)

            wrr.close()

            # Now test that connectsockstr works similarly.
            rdr, wrr = await connectsockstr(sstr)

            # and send a message
            wrr.write(msgsent)

            # and receive the message
            rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
            self.assertEqual(rcv, msgrcv)

            wrr.close()

            ls.close()
            await ls.wait_closed()

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))

def cmd_client(args):
    '''Run the client.  Listen for plain text connections on
    args.clientlisten, and forward each one over a new encrypted
    connection to args.clienttarget.  The server is authenticated w/
    the public key in the file args.servkey.  The private key in the
    file args.clientkey is used to authenticate to the server.'''

    privkey = loadprivkeyraw(args.clientkey)
    pubkey = loadpubkeyraw(args.servkey)

    async def runnf(rdr, wrr):
        encpair = asyncio.create_task(connectsockstr(args.clienttarget))

        await NoiseForwarder('init',
            encpair, lambda x: _makefut((rdr, wrr)),
            priv_key=privkey, pub_key=pubkey)

    async def serve():
        # Setup client listener
        server = await listensockstr(args.clientlisten, runnf)

        await server.serve_forever()

    asyncio.run(serve())

def cmd_server(args):
    '''Run the server.  Listen for encrypted connections on
    args.servlisten, and forward each authenticated one to
    args.servtarget.  Only clients whose public key is in one of the
    args.clientkey files are accepted.  The private key in the file
    args.servkey authenticates the server.'''

    privkey = loadprivkeyraw(args.servkey)

    # argparse leaves clientkey as None if no -c is given
    pubkeys = { loadpubkeyraw(x) for x in (args.clientkey or ()) }

    async def runnf(rdr, wrr):
        async def checkclientfun(clientkey):
            if clientkey not in pubkeys:
                raise RuntimeError('invalid key provided')

            return await connectsockstr(args.servtarget)

        await NoiseForwarder('resp',
            _makefut((rdr, wrr)), checkclientfun,
            priv_key=privkey)

    async def serve():
        # Setup server listener
        server = await listensockstr(args.servlisten, runnf)

        await server.serve_forever()

    asyncio.run(serve())

def cmd_genkey(args):
    '''Generate a new x448 keypair.

    The public key is written to args.fname + '.pub' in the format read
    by loadpubkeyraw.  The private key is written PEM (PKCS8) encoded
    to args.fname, created w/ mode 0o600 (readable only by its owner).

    Exits w/ status 2 if either file already exists.  In that case,
    neither file is left behind by this run.'''

    key = x448.X448PrivateKey.generate()

    # public key part
    pub = key.public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw)

    # private key part
    privpem = key.private_bytes(encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption()).decode('ascii')

    pubfname = args.fname + '.pub'
    try:
        with open(pubfname, 'x', encoding='ascii') as fp:
            print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'),
                file=fp)
    except FileExistsError:
        print('failed to create %s, file exists.' % pubfname,
            file=sys.stderr)
        sys.exit(2)

    try:
        # O_EXCL gives the same "must not exist" semantics as mode 'x'.
        # The mode makes sure the private key is only readable by its
        # owner.
        fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL,
            0o600)
    except FileExistsError:
        # don't leave a public key behind that doesn't match the
        # existing private key
        os.unlink(pubfname)
        print('failed to create %s, file exists.' % args.fname,
            file=sys.stderr)
        sys.exit(2)

    with os.fdopen(fd, 'w', encoding='ascii') as fp:
        fp.write(privpem)

def main():
    '''Parse the command line, and run the selected subcommand.  Exits
    w/ status 5 if no subcommand was given.'''

    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(title='subcommands',
        description='valid subcommands', help='additional help')

    parser_gk = subparsers.add_parser('genkey', help='generate keys')
    parser_gk.add_argument('fname', type=str, help='file name for the key')
    parser_gk.set_defaults(func=cmd_genkey)

    parser_serv = subparsers.add_parser('server', help='run a server')
    parser_serv.add_argument('--clientkey', '-c', action='append',
        type=str, required=True,
        help='file of authorized client keys, or a .pub file')
    parser_serv.add_argument('servkey', type=str,
        help='file name for the server key')
    parser_serv.add_argument('servlisten', type=str,
        help='Connection that the server listens on')
    parser_serv.add_argument('servtarget', type=str,
        help='Connection that the server connects to')
    parser_serv.set_defaults(func=cmd_server)

    parser_client = subparsers.add_parser('client', help='run a client')
    parser_client.add_argument('clientkey', type=str,
        help='file name for the client private key')
    parser_client.add_argument('servkey', type=str,
        help='file name for the server public key')
    parser_client.add_argument('clientlisten', type=str,
        help='Connection that the client listens on')
    parser_client.add_argument('clienttarget', type=str,
        help='Connection that the client connects to')
    parser_client.set_defaults(func=cmd_client)

    args = parser.parse_args()

    try:
        fun = args.func
    except AttributeError:
        parser.print_usage()
        sys.exit(5)

    fun(args)

if __name__ == '__main__': # pragma: no cover
    main()

def _asyncsockpair():
    '''Create a pair of sockets that are bound to each other.
    The function will return a tuple of two coroutines that each,
    when awaited, will return the reader/writer pair.'''

    socka, sockb = socket.socketpair()

    return asyncio.open_connection(sock=socka), \
        asyncio.open_connection(sock=sockb)

async def _awaitfile(fname):
    '''Poll every 10ms until the path fname exists, then return True.'''

    while not os.path.exists(fname):
        await asyncio.sleep(.01)

    return True

class TestMain(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        # remember where we were, so tearDown can go back
        self.origcwd = os.getcwd()
        os.chdir(self.tempdir)

    def tearDown(self):
        #print('td:', time.time())
        # don't leave the process in a deleted directory
        os.chdir(self.origcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_noargs(self):
        proc = await self.run_with_args()

        await proc.wait()

        # XXX - not checking error message

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 5)

    def run_with_args(self, *args, pipes=True):
        '''Return a coroutine that starts this file as a subprocess w/
        the command line arguments args.  If pipes is True, stdout and
        stderr are captured.'''

        kwargs = {}
        if pipes:
            kwargs.update(dict(
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE))
        return asyncio.create_subprocess_exec(sys.executable,
            # XXX - figure out how to add coverage data on these runs
            #'-m', 'coverage', 'run', '-p',
            __file__,
            *args, **kwargs)

    async def genkey(self, name):
        proc = await self.run_with_args('genkey', name, pipes=False)

        await proc.wait()

        self.assertEqual(proc.returncode, 0)

    @async_test
    async def test_loadpubkey(self):
        keypath = os.path.join(self.tempdir, 'loadpubkeytest')

        await self.genkey(keypath)

        privkey = loadprivkey(keypath)

        enc = serialization.Encoding.Raw
        pubformat = serialization.PublicFormat.Raw
        pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
            format=pubformat)

        pubkey = loadpubkeyraw(keypath + '.pub')

        self.assertEqual(pubkeybytes, pubkey)

        privrawkey = loadprivkeyraw(keypath)

        enc = serialization.Encoding.Raw
        privformat = serialization.PrivateFormat.Raw
        encalgo = serialization.NoEncryption()
        rprivrawkey = privkey.private_bytes(encoding=enc,
            format=privformat, encryption_algorithm=encalgo)

        self.assertEqual(rprivrawkey, privrawkey)

    @async_test
    async def test_clientkeymismatch(self):
        # make sure that if there's a client key mismatch, we
        # don't connect

        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)
        badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
        await self.genkey(badclientkeypath)

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket server listen to
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsockevent = asyncio.Event()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, None)

        # Startup the server
        server = await self.run_with_args('server',
            '-c', clientkeypath + '.pub',
            servkeypath, incservstr, servtargstr)

        # Startup the client with the "bad" key
        client = await self.run_with_args('client',
            badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)

        # wait for server target to be created
        await _awaitfile(servtargpath)

        # wait for server to start
        await _awaitfile(incservpath)

        # wait for client to start
        await _awaitfile(ptclientpath)

        # Connect to the client
        reader, writer = await connectsockstr(ptclientstr)

        # XXX - this might not be the best test.
        with self.assertRaises(asyncio.TimeoutError):
            # make sure that we don't get the connection
            await asyncio.wait_for(ptsockevent.wait(), .5)

        writer.close()

        # Make sure that when the server is terminated
        server.terminate()

        # that it's stderr
        stdout, stderr = await server.communicate()
        #print('s:', repr((stdout, stderr)))

        # doesn't have an exceptions never retrieved
        # even the example echo server has this same leak
        #self.assertNotIn(b'Task exception was never retrieved', stderr)

        lsock.close()
        await lsock.wait_closed()

        # Kill off the client
        client.terminate()
        stdout, stderr = await client.communicate()
        #print('s:', repr((stdout, stderr)))

        # XXX - figure out how to clean up client properly

    @async_test
    async def test_end2end(self):
        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket server listen to
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, ptsockaccept)

        # Startup the server
        server = await self.run_with_args('server',
            '-c', clientkeypath + '.pub',
            servkeypath, incservstr, servtargstr,
            pipes=False)

        # Startup the client
        client = await self.run_with_args('client',
            clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
            pipes=False)

        # wait for server target to be created
        await _awaitfile(servtargpath)

        # wait for server to start
        await _awaitfile(incservpath)

        # wait for client to start
        await _awaitfile(ptclientpath)

        # Connect to the client
        reader, writer = await connectsockstr(ptclientstr)

        # send a message
        ptmsg = b'this is a message for testing'
        writer.write(ptmsg)

        # make sure that we got the connection
        await ptsockevent.wait()

        # get the connection
        endrdr, endwrr = ptsock[0]

        # make sure we can read back what we sent
        self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))

        # test some additional messages
        for i in [ 129, 1287, 28792, 129872 ]:
            # in on direction
            msg = os.urandom(i)
            writer.write(msg)
            self.assertEqual(msg, await endrdr.readexactly(len(msg)))

            # and the other
            endwrr.write(msg)
            self.assertEqual(msg, await reader.readexactly(len(msg)))

        writer.close()
        endwrr.close()

        lsock.close()
        await lsock.wait_closed()

        server.terminate()
        client.terminate()

        # reap the processes so they don't outlive the test
        await server.wait()
        await client.wait()

        # XXX - more clean up testing

    @async_test
    async def test_genkey(self):
        # that it can generate a key
        proc = await self.run_with_args('genkey', 'somefile')

        await proc.wait()

        #print(await proc.communicate())
        self.assertEqual(proc.returncode, 0)

        with open('somefile.pub', encoding='ascii') as fp:
            lines = fp.readlines()
            self.assertEqual(len(lines), 1)

            keytype, keyvalue = lines[0].split()

            self.assertEqual(keytype, 'ntun-x448')
            key = x448.X448PublicKey.from_public_bytes(
                base64.urlsafe_b64decode(keyvalue))

        key = loadprivkey('somefile')
        self.assertIsInstance(key, x448.X448PrivateKey)

        # that a second call fails
        proc = await self.run_with_args('genkey', 'somefile')

        await proc.wait()

        stdoutdata, stderrdata = await proc.communicate()

        self.assertFalse(stdoutdata)
        self.assertEqual(
            b'failed to create somefile.pub, file exists.\n',
            stderrdata)

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 2)

class TestNoiseFowarder(unittest.TestCase):
    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_clientkeymissmatch(self):
        # generate a key that is incorrect
        wrongclient_key_pair = genkeypair()

        # the secure socket
        clssockapair, clssockbpair = _asyncsockpair()

        reader, writer = await clssockapair

        async def wrongkey(v):
            raise ValueError('no key matches')

        # create the server
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, wrongkey,
            priv_key=self.server_key_pair[1]))

        # Create client
        proto = NoiseConnection.from_name(
            b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup wrong client key
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            wrongclient_key_pair[1])

        # but the correct server key
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        with self.assertRaises(ValueError):
            await servnf

        writer.close()

    @async_test
    async def test_server(self):
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #    (rdr, wrr) NoiseForward (reader, writer) ->
        #    servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpairfun = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), lambda x: ptpairfun,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        pversion = 0

        # Send the protocol version string first
        encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # Read the peer's protocol version

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)

        # write a test message
        ptmsg = b'this is a test message that should be a little in length'
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # wait for the connection to arrive
        await ptsockevent.wait()

        ptreader, ptwriter = ptsock[0]

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # write a different message
        ptmsg = os.urandom(2843)
        encmsg = proto.encrypt(ptmsg)
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # read the test message
        rptmsg = await ptreader.readexactly(len(ptmsg))

        self.assertEqual(rptmsg, ptmsg)

        # now try the other way
        ptmsg = os.urandom(912)
        ptwriter.write(ptmsg)

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(rptmsg, ptmsg)

        # shut down sending
        writer.write_eof()

        # so pt reader should be shut down
        self.assertEqual(b'', await ptreader.read(1))
        self.assertTrue(ptreader.at_eof())

        # shut down pt
        ptwriter.write_eof()

        # make sure the enc reader is eof
        self.assertEqual(b'', await reader.read(1))
        self.assertTrue(reader.at_eof())

        await event.wait()

        self.assertEqual(nfs[0], [ 'dec', 'enc' ])

        writer.close()
        ptwriter.close()

        lsock.close()
        ssock.close()
        await lsock.wait_closed()
        await ssock.wait_closed()

    @async_test
    async def test_protocolversionmismatch(self):
        # make sure that if we send a future version, that we
        # still get a protocol version, and that the connection
        # is closed w/o establishing a connection to the remote
        # side

        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #    (rdr, wrr) NoiseForward (reader, writer) ->
        #    servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener
        pttarg = _makeunix(servptpath)
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpairfun = asyncio.create_task(connectsockstr(pttarg))

            try:
                a = await NoiseForwarder('resp',
                    _makefut((rdr, wrr)), lambda x: ptpairfun,
                    priv_key=self.server_key_pair[1])
            except RuntimeError as e:
                nfs.append(e)
                event.set()
                return

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        # Connect to server
        reader, writer = await connectsockstr(servarg)

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        # generate the keys for lengths
        enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
        _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

        pversion = 1

        # Send the protocol version string first
        encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
        writer.write(enclenfun(encmsg))
        writer.write(encmsg)

        # Read the peer's protocol version

        # find out how much we need to read
        encmsg = await reader.readexactly(2 + 16)
        tlen = declenfun(encmsg)

        # read the rest of the message
        rencmsg = await reader.readexactly(tlen - 16)
        tmsg = encmsg[2:] + rencmsg
        rptmsg = proto.decrypt(tmsg)

        self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)

        await event.wait()

        self.assertIsInstance(nfs[0], RuntimeError)

    @async_test
    async def test_serverclient(self):
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        async def validateclientkey(pubkey):
            self.assertEqual(pubkey, self.client_key_pair[0])

            return await ptssockapair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, lambda x: ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, validateclientkey,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())

        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)

        await ptsbwriter.drain()
        await ptcawriter.drain()

        ptsbwriter.close()
        ptcawriter.close()