|
- from cryptography.hazmat.backends import default_backend
- from cryptography.hazmat.primitives import hashes
- from cryptography.hazmat.primitives import serialization
- from cryptography.hazmat.primitives.asymmetric import x448
- from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
- from cryptography.hazmat.primitives.kdf.hkdf import HKDF
- from cryptography.hazmat.primitives.serialization import load_pem_private_key
- from noise.connection import NoiseConnection, Keypair
-
- #import tracemalloc; tracemalloc.start()
-
- import argparse
- import asyncio
- import base64
- import os.path
- import shutil
- import socket
- import sys
- import tempfile
- import time
- import threading
- import unittest
-
# Shared cryptography backend instance used by the helpers in this module.
_backend = default_backend()

def loadprivkey(fname):
    '''Load an unencrypted PEM encoded private key from the file fname.

    The file is read as ASCII text (PEM is ASCII armored), so a file
    with non-ASCII content raises UnicodeDecodeError.

    Returns the key object produced by load_pem_private_key (an
    X448PrivateKey for keys written by the genkey subcommand).
    '''

    with open(fname, encoding='ascii') as fp:
        data = fp.read().encode('ascii')

    # Use the module wide backend, as the other helpers here do.
    return load_pem_private_key(data, password=None, backend=_backend)
-
def loadprivkeyraw(fname):
    '''Return the raw private key bytes of the PEM key stored in fname.

    The key is read with loadprivkey and re-serialized, unencrypted, in
    the Raw private format (56 bytes for the X448 keys produced by
    genkey), which is the form handed to Noise.'''

    return loadprivkey(fname).private_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption(),
    )
-
def loadpubkeyraw(fname):
    '''Load a public key file as written by the genkey subcommand.

    The file must contain exactly one non-blank line of the form
    'ntun-x448 <urlsafe base64 key>'. Returns the decoded raw key bytes.

    Raises RuntimeError if the file has no key line or more than one,
    if the line does not have exactly two fields, or if the key type
    is not ntun-x448.
    '''

    with open(fname, encoding='ascii') as fp:
        lines = [line for line in fp.read().splitlines() if line.strip()]

    if len(lines) != 1:
        raise RuntimeError('%s: expected exactly one key line, found %d' %
            (fname, len(lines)))

    fields = lines[0].split()
    if len(fields) != 2:
        raise RuntimeError('%s: malformed key line' % (fname,))

    keytype, keyvalue = fields
    if keytype != 'ntun-x448':
        raise RuntimeError('%s: unsupported key type %r' % (fname, keytype))

    return base64.urlsafe_b64decode(keyvalue)
-
def genkeypair():
    '''Generate a fresh X448 keypair.

    Returns a (public, private) tuple. Both halves are encoded as raw
    bytes, which makes them suitable for use with Noise.'''

    raw = serialization.Encoding.Raw
    privkey = x448.X448PrivateKey.generate()

    pubbytes = privkey.public_key().public_bytes(
        encoding=raw, format=serialization.PublicFormat.Raw)
    privbytes = privkey.private_bytes(
        encoding=raw, format=serialization.PrivateFormat.Raw,
        encryption_algorithm=serialization.NoEncryption())

    return pubbytes, privbytes
-
def _makefut(obj):
    '''Return a future, on the running loop, already resolved to obj.

    Handy for passing an already available value to an API that wants
    something awaitable. Must be called while an event loop is running.'''

    done = asyncio.get_running_loop().create_future()
    done.set_result(obj)
    return done
-
def _makeunix(path):
    '''Build the socket string for a unix domain socket at path.'''

    return f'unix:{path}'
-
def _parsesockstr(sockstr):
    '''Split a socket string into (proto, remainder).

    Only the first colon separates the two parts, so the remainder may
    itself contain colons ('unix:/a:b' -> ('unix', '/a:b')).

    Raises ValueError if sockstr contains no colon at all.
    '''

    proto, sep, rem = sockstr.partition(':')
    if not sep:
        raise ValueError('invalid socket string %r, expected proto:value' %
            (sockstr,))

    return proto, rem
-
async def connectsockstr(sockstr):
    '''Open a client connection described by the socket string sockstr.

    Only the unix protocol is supported: everything after 'unix:' is
    used as the path of the unix domain socket to connect to.

    Returns the (reader, writer) stream pair from
    asyncio.open_unix_connection.

    Raises ValueError if the protocol is not supported, rather than
    silently treating the remainder as a unix socket path.
    '''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in %r' % (proto, sockstr))

    reader, writer = await asyncio.open_unix_connection(rem)

    return reader, writer
-
async def listensockstr(sockstr, cb):
    '''Wrapper for asyncio.start_x_server.

    The format of sockstr is: 'proto:param=value[,param2=value2]'.
    If the proto has a default parameter, the value can be used
    directly, like: 'proto:value'. This is only allowed when the
    value can unambiguously be determined not to be a param.

    The characters that define 'param' must be all lower case ascii
    characters and may contain an underscore. The first character
    must not be an underscore.

    Supported protocols:
        unix:
            Default parameter is path.
            The path parameter specifies the path to the
            unix domain socket. The path MUST start w/ a
            slash if it is used as a default parameter.

    NOTE: the param=value form is not parsed yet. Everything after
    'unix:' is currently used verbatim as the socket path.

    cb is passed to asyncio.start_unix_server as the client connected
    callback. Returns the resulting asyncio Server object.

    Raises ValueError if the protocol is not supported.
    '''

    proto, rem = _parsesockstr(sockstr)

    if proto != 'unix':
        raise ValueError('unsupported protocol %r in %r' % (proto, sockstr))

    server = await asyncio.start_unix_server(cb, path=rem)

    return server
-
# !!python makemessagelengths.py
# Byte sizes of the three Noise_XK_448_ChaChaPoly_SHA256 handshake
# messages, in order: initiator -> responder, responder -> initiator,
# initiator -> responder. The receiving side reads each one with
# readexactly, so these must match what write_message() produces.
_handshakelens = [72, 72, 88]
-
def _genciphfun(hash, ad):
    '''Build a pair of functions that obscure frame lengths on the wire.

    A 32 byte AES key is derived from hash (the Noise handshake hash)
    with HKDF-SHA256, using ad as the HKDF info so that each direction
    gets its own key. AES-ECB under that key is applied to the first 16
    bytes of a frame, and the first two output bytes form a 16-bit mask
    that is XORed with the frame length.

    Returns (encfun, decfun):

        encfun(data) -> bytes: the 2 byte masked length prefix for the
            complete encrypted frame data. data must be at least 16
            bytes long; lengths above 65535 raise OverflowError.
        decfun(data) -> int: given the 2 byte prefix followed by at
            least the first 16 bytes of the frame, returns the length of
            the whole frame.

    Both raise ValueError if fewer than 16 bytes of frame data are
    available. This is obfuscation only: integrity is provided by the
    Noise AEAD, not by this mask.
    '''

    hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
        salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

    key = hkdf.derive(hash)
    cipher = Cipher(algorithms.AES(key), modes.ECB(),
        backend=_backend)
    enctor = cipher.encryptor()

    def getmask(block):
        # ECB only emits output for complete 16 byte blocks. A shorter
        # block would be buffered inside enctor and silently shift every
        # later mask, so refuse it outright.
        if len(block) < 16:
            raise ValueError('need 16 bytes of frame data for the length '
                'mask, got %d' % len(block))

        encbytes = enctor.update(block[:16])

        # Use both bytes, so that neither byte of the 16-bit length is
        # sent in the clear.
        return int.from_bytes(encbytes[:2], byteorder='big')

    def encfun(data):
        # Returns the two bytes for length
        val = len(data)

        return (val ^ getmask(data[:16])).to_bytes(length=2, byteorder='big')

    def decfun(data):
        # data is the 2 byte masked prefix followed by the start of the
        # frame. Returns the total frame length (excluding the prefix).
        val = int.from_bytes(data[:2], byteorder='big')

        return val ^ getmask(data[2:2 + 16])

    return encfun, decfun
-
async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
    '''A function that forwards data between the plain text pair of
    streams to the encrypted session.

    The mode parameter must be one of 'init' or 'resp' for initiator
    and responder.

    The encrdrwrr is an awaitable object that will return a tuple of the
    reader and writer streams for the encrypted side of the
    connection.

    The ptpairfun parameter is a function that will be passed the
    public key bytes for the remote client. This can be used to
    both validate that the correct client is connecting, and to
    pass back the correct plain text reader/writer objects that
    match the provided static key. The function must return an
    awaitable (e.g. be an async function). Any exception it raises
    propagates out of NoiseForwarder.

    In the case of the initiator, pub_key must be provided and will
    be used to authenticate the responder side of the connection.

    The priv_key parameter is used to authenticate this side of the
    session.

    Both priv_key and pub_key parameters must be 56 bytes. For example,
    the pair that is returned by genkeypair.

    After the handshake, each Noise message is sent as a 2 byte length
    prefix (obscured with the functions from _genciphfun) followed by
    the encrypted message.

    Returns ['dec', 'enc'] once both directions have reached EOF.

    Raises ValueError if mode is invalid or a received frame length is
    impossible. Raises asyncio.IncompleteReadError if the encrypted
    stream ends in the middle of a frame.
    '''
    import errno

    if mode not in ('init', 'resp'):
        raise ValueError("mode must be 'init' or 'resp', not %r" % (mode,))

    rdr, wrr = await encrdrwrr

    proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

    proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
    if pub_key is not None:
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            pub_key)

    if mode == 'resp':
        proto.set_as_responder()

        proto.start_handshake()

        proto.read_message(await rdr.readexactly(_handshakelens[0]))

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[2]))
    else:
        proto.set_as_initiator()

        proto.start_handshake()

        wrr.write(proto.write_message())

        proto.read_message(await rdr.readexactly(_handshakelens[1]))

        wrr.write(proto.write_message())

    if not proto.handshake_finished: # pragma: no cover
        raise RuntimeError('failed to finish handshake')

    reader, writer = await ptpairfun(getattr(proto.get_keypair(
        Keypair.REMOTE_STATIC), 'public_bytes', None))

    # generate the keys for lengths
    # XXX - get_handshake_hash is probably not the best option, but
    # this is only to obscure lengths, it is not required to be secure
    # as the underlying NoiseProtocol securely validates everything.
    # It is marginally useful as writing patterns likely expose the
    # true length. Adding padding could marginally help w/ this.
    hshash = proto.get_handshake_hash()
    if mode == 'resp':
        sendad, recvad = b'toinit', b'toresp'
    else:
        sendad, recvad = b'toresp', b'toinit'
    enclenfun, _ = _genciphfun(hshash, sendad)
    _, declenfun = _genciphfun(hshash, recvad)

    def shutdownwrite(strm):
        # Half-close strm. If the peer has already gone away the OS
        # reports ENOTCONN, which is harmless because we have nothing
        # more to send. Its numeric value differs between platforms
        # (57 on BSD/macOS, 107 on Linux), so compare by name.
        try:
            strm.write_eof()
        except OSError as e:
            if e.errno != errno.ENOTCONN:
                raise

    async def decses():
        # encrypted side -> plain text side
        try:
            while True:
                try:
                    msg = await rdr.readexactly(2 + 16)
                except asyncio.IncompleteReadError as e:
                    if e.partial:
                        # the stream ended part way through a frame
                        raise
                    return 'dec'

                tlen = declenfun(msg)
                if tlen < 16:
                    # every frame carries at least a 16 byte auth tag
                    raise ValueError('invalid frame length %d' % tlen)

                rmsg = await rdr.readexactly(tlen - 16)
                writer.write(proto.decrypt(msg[2:] + rmsg))
                await writer.drain()
        finally:
            shutdownwrite(writer)

    async def encses():
        # plain text side -> encrypted side
        try:
            while True:
                # largest message: 65535 bytes including the 16 byte tag
                ptmsg = await reader.read(65535 - 16)
                if not ptmsg:
                    # eof
                    return 'enc'

                encmsg = proto.encrypt(ptmsg)
                wrr.write(enclenfun(encmsg))
                wrr.write(encmsg)
                await wrr.drain()
        finally:
            shutdownwrite(wrr)

    return await asyncio.gather(decses(), encses())
-
# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout and to print trace back when canceled.
# This makes it easier to figure out what "froze".
def async_test(f):
    '''Decorator that runs the coroutine function f to completion.

    Each call runs f(*args, **kwargs) on a fresh event loop created by
    asyncio.run, bounded by a 4 second timeout, and returns its result.
    If the coroutine is cancelled (e.g. by that timeout), its traceback
    is printed before the CancelledError propagates, so it is easy to
    see where it was stuck.
    '''
    import functools
    import traceback

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def tbcapture():
            try:
                return await f(*args, **kwargs)
            except asyncio.CancelledError:
                # if we are going to be cancelled, print out a tb
                traceback.print_exc()
                raise

        # asyncio.run instead of the deprecated get_event_loop() when
        # no loop is running.
        return asyncio.run(asyncio.wait_for(tbcapture(), 4))

    return wrapper
-
class Tests_misc(unittest.TestCase):
    '''Unit tests for the small stand-alone helpers.'''

    def test_listensockstr(self):
        # Run on a private event loop rather than the shared default one,
        # so this test does not disturb the loop state other tests use.
        tmpdir = tempfile.mkdtemp()
        loop = asyncio.new_event_loop()
        try:
            sockpath = os.path.join(tmpdir, 'listen.sock')
            loop.run_until_complete(asyncio.wait_for(
                self._listenandconnect(sockpath), 4))
        finally:
            loop.close()
            shutil.rmtree(tmpdir)

    async def _listenandconnect(self, sockpath):
        # listensockstr must create a unix socket at sockpath that
        # connectsockstr can connect to, and data must flow across it.
        sockstr = _makeunix(sockpath)
        received = asyncio.get_running_loop().create_future()

        async def handler(rdr, wrr):
            received.set_result(await rdr.readexactly(5))
            wrr.close()

        server = await listensockstr(sockstr, handler)
        try:
            self.assertTrue(os.path.exists(sockpath))

            rdr, wrr = await connectsockstr(sockstr)
            wrr.write(b'hello')
            await wrr.drain()

            self.assertEqual(b'hello', await received)

            wrr.close()
        finally:
            server.close()

    def test_genciphfun(self):
        enc, dec = _genciphfun(b'0' * 32, b'foobar')

        msg = b'this is a bunch of data'

        tb = enc(msg)

        self.assertEqual(len(msg), dec(tb + msg))

        for i in [ 20, 1384, 64000, 23839, 65535 ]:
            msg = os.urandom(i)
            self.assertEqual(len(msg), dec(enc(msg) + msg))
-
def cmd_client(args):
    '''Run the client side of the tunnel until interrupted.

    Listens on args.clientlisten for plain text connections. For each
    one, an encrypted connection to args.clienttarget is opened and a
    Noise handshake is run as initiator with the private key from
    args.clientkey, authenticating the server against the public key
    in args.servkey. Data is then forwarded in both directions.
    '''

    privkey = loadprivkeyraw(args.clientkey)
    pubkey = loadpubkeyraw(args.servkey)

    async def runnf(rdr, wrr):
        encwriters = []

        async def encpair():
            encrdr, encwrr = await connectsockstr(args.clienttarget)
            encwriters.append(encwrr)
            return encrdr, encwrr

        try:
            await NoiseForwarder('init',
                encpair(), lambda x: _makefut((rdr, wrr)),
                priv_key=privkey, pub_key=pubkey)
        finally:
            # Both directions are done (or failed). Release the sockets
            # rather than leaking them for the life of the process.
            wrr.close()
            for encwrr in encwriters:
                encwrr.close()

    async def serve():
        # Setup client listener
        server = await listensockstr(args.clientlisten, runnf)
        async with server:
            await server.serve_forever()

    asyncio.run(serve())
-
def cmd_server(args):
    '''Run the server side of the tunnel until interrupted.

    Listens on args.servlisten for encrypted connections and runs the
    Noise handshake as responder with the private key in args.servkey.
    The client's static public key is then checked against the keys
    read from the -c files (args.c). Only after that check is a plain
    text connection opened to args.servtarget and data forwarded.

    Each -c file holds one or more 'ntun-x448 <urlsafe base64>' lines,
    as written by genkey. If no -c option is given, any client key is
    accepted, matching this command's historical behaviour.
    '''

    def loadauthkeys(fnames):
        keys = set()
        for fname in fnames:
            with open(fname, encoding='ascii') as fp:
                for line in fp:
                    fields = line.split()
                    if not fields:
                        continue
                    if len(fields) != 2 or fields[0] != 'ntun-x448':
                        raise RuntimeError('%s: invalid authorized key line'
                            % (fname,))
                    keys.add(base64.urlsafe_b64decode(fields[1]))
        return keys

    privkey = loadprivkeyraw(args.servkey)
    authkeys = None if args.c is None else loadauthkeys(args.c)

    async def runnf(rdr, wrr):
        ptwriters = []

        async def ptpairfun(remotekey):
            # Authenticate first, so that unauthorized clients never
            # cause a connection to the target.
            if authkeys is not None and remotekey not in authkeys:
                raise ValueError('client key not authorized')

            ptrdr, ptwrr = await connectsockstr(args.servtarget)
            ptwriters.append(ptwrr)
            return ptrdr, ptwrr

        try:
            await NoiseForwarder('resp',
                _makefut((rdr, wrr)), ptpairfun,
                priv_key=privkey)
        finally:
            # Both directions are done (or the client was rejected).
            # Release the sockets rather than leaking them.
            wrr.close()
            for ptwrr in ptwriters:
                ptwrr.close()

    async def serve():
        # Setup server listener
        server = await listensockstr(args.servlisten, runnf)
        async with server:
            await server.serve_forever()

    asyncio.run(serve())
-
def cmd_genkey(args):
    '''Generate a new X448 keypair and write it to disk.

    The public key goes to args.fname + '.pub' as a single line
    'ntun-x448 <urlsafe base64 raw key>'. The private key goes to
    args.fname as an unencrypted PKCS8 PEM file, created with
    permissions 0600 (further limited by the umask).

    Neither file may already exist. If one does, a message is printed
    to stderr, no file created by this run is left behind, and the
    process exits with status 2.
    '''

    key = x448.X448PrivateKey.generate()

    # public key part
    pub = key.public_key().public_bytes(
        encoding=serialization.Encoding.Raw,
        format=serialization.PublicFormat.Raw)

    # private key part
    privpem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption())

    pubfname = args.fname + '.pub'
    try:
        with open(pubfname, 'x', encoding='ascii') as fp:
            print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
    except FileExistsError:
        print('failed to create %s, file exists.' % pubfname, file=sys.stderr)
        sys.exit(2)

    # O_EXCL gives the same "must not exist" guarantee as open(..., 'x'),
    # and mode 0600 keeps the private key from being readable by others.
    try:
        fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
    except FileExistsError:
        # don't leave a public key behind without its private key
        os.unlink(pubfname)
        print('failed to create %s, file exists.' % args.fname, file=sys.stderr)
        sys.exit(2)

    with os.fdopen(fd, 'w', encoding='ascii') as fp:
        fp.write(privpem.decode('ascii'))
-
def main():
    '''Command line entry point.

    Builds the parser for the genkey, server and client subcommands and
    runs the chosen one. With no subcommand, prints usage and exits with
    status 5.'''

    parser = argparse.ArgumentParser()
    subcmds = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')

    genkey = subcmds.add_parser('genkey', help='generate keys')
    genkey.set_defaults(func=cmd_genkey)
    genkey.add_argument('fname', type=str, help='file name for the key')

    server = subcmds.add_parser('server', help='run a server')
    server.set_defaults(func=cmd_server)
    server.add_argument('-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
    for argname, helptext in (
            ('servkey', 'file name for the server key'),
            ('servlisten', 'Connection that the server listens on'),
            ('servtarget', 'Connection that the server connects to'),
            ):
        server.add_argument(argname, type=str, help=helptext)

    client = subcmds.add_parser('client', help='run a client')
    client.set_defaults(func=cmd_client)
    for argname, helptext in (
            ('clientkey', 'file name for the client private key'),
            ('servkey', 'file name for the server public key'),
            ('clientlisten', 'Connection that the client listens on'),
            ('clienttarget', 'Connection that the client connects to'),
            ):
        client.add_argument(argname, type=str, help=helptext)

    args = parser.parse_args()

    handler = getattr(args, 'func', None)
    if handler is None:
        # no subcommand was given
        parser.print_usage()
        sys.exit(5)

    handler(args)
-
# Script entry point. It sits before the test helpers and test classes
# below, so those are not yet defined while main() runs; main() does
# not use them.
if __name__ == '__main__': # pragma: no cover
    main()
-
def _asyncsockpair():
    '''Return a tuple of two coroutines for the ends of a connected
    socket pair.

    The sockets come from socket.socketpair(). Awaiting either coroutine
    yields the asyncio (reader, writer) pair for that end.'''

    ends = socket.socketpair()

    return tuple(asyncio.open_connection(sock=end) for end in ends)
-
async def _awaitfile(fname):
    '''Poll every 10ms until fname exists, then return True.

    There is no timeout here; the wait only ends when the file appears
    or the coroutine is cancelled.'''

    while True:
        if os.path.exists(fname):
            return True
        await asyncio.sleep(.01)
-
class TestMain(unittest.TestCase):
    '''Tests that run this file as a subprocess (see run_with_args).'''

    # Resolve the script path once, when the class is created. setUp
    # changes the working directory, so a relative __file__ would stop
    # pointing at this file by the time the subprocesses are started.
    _SCRIPT = os.path.abspath(__file__)

    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

        # Some tests (e.g. test_genkey) use relative file names, so run
        # inside the temporary directory. Remember the old cwd so that
        # tearDown can return to it before the directory is removed.
        self.oldcwd = os.getcwd()
        os.chdir(self.tempdir)

    def tearDown(self):
        # Restore the cwd first. Otherwise the process is left in a
        # deleted directory, which breaks later tests.
        os.chdir(self.oldcwd)
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_noargs(self):
        proc = await self.run_with_args()

        # communicate() drains the pipes while waiting. wait() alone can
        # deadlock if the child fills a pipe buffer.
        await proc.communicate()

        # XXX - not checking error message

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 5)

    def run_with_args(self, *args, pipes=True):
        '''Return the (not yet awaited) coroutine that starts this file
        as a subprocess with args. When pipes is True, stdout and
        stderr are captured.'''

        kwargs = {}
        if pipes:
            kwargs.update(dict(
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE))
        return asyncio.create_subprocess_exec(sys.executable,
            # XXX - figure out how to add coverage data on these runs
            #'-m', 'coverage', 'run', '-p',
            self._SCRIPT, *args, **kwargs)

    async def genkey(self, name):
        proc = await self.run_with_args('genkey', name, pipes=False)

        await proc.wait()

        self.assertEqual(proc.returncode, 0)

    @async_test
    async def test_loadpubkey(self):
        keypath = os.path.join(self.tempdir, 'loadpubkeytest')
        await self.genkey(keypath)

        privkey = loadprivkey(keypath)

        enc = serialization.Encoding.Raw
        pubformat = serialization.PublicFormat.Raw

        pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)

        pubkey = loadpubkeyraw(keypath + '.pub')

        self.assertEqual(pubkeybytes, pubkey)

        privrawkey = loadprivkeyraw(keypath)

        enc = serialization.Encoding.Raw
        privformat = serialization.PrivateFormat.Raw
        encalgo = serialization.NoEncryption()

        rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)

        self.assertEqual(rprivrawkey, privrawkey)

    @async_test
    async def test_end2end(self):
        # Generate necessary keys
        servkeypath = os.path.join(self.tempdir, 'server_key')
        await self.genkey(servkeypath)
        clientkeypath = os.path.join(self.tempdir, 'client_key')
        await self.genkey(clientkeypath)

        await asyncio.sleep(.1)

        # forwards connections to this socket (created by client)
        ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
        ptclientstr = _makeunix(ptclientpath)

        # this is the socket the server listens on
        incservpath = os.path.join(self.tempdir, 'incserv.sock')
        incservstr = _makeunix(incservpath)

        # to this socket, opened by server
        servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
        servtargstr = _makeunix(servtargpath)

        # Setup server target listener
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(servtargstr, ptsockaccept)

        # The server and client subprocesses serve forever, so track them
        # and stop them when the test ends, whether it passes or fails.
        procs = []
        try:
            # Startup the server
            procs.append(await self.run_with_args('server',
                '-c', clientkeypath + '.pub',
                servkeypath, incservstr, servtargstr,
                pipes=False))

            # Startup the client
            procs.append(await self.run_with_args('client',
                clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
                pipes=False))

            # wait for server target to be created
            await _awaitfile(servtargpath)

            # wait for server to start
            await _awaitfile(incservpath)

            # wait for client to start
            await _awaitfile(ptclientpath)

            await asyncio.sleep(.1)

            # Connect to the client
            reader, writer = await connectsockstr(ptclientstr)

            # send a message
            ptmsg = b'this is a message for testing'
            writer.write(ptmsg)

            # make sure that we got the connection
            await ptsockevent.wait()

            # get the connection
            endrdr, endwrr = ptsock[0]

            # make sure we can read back what we sent
            self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))

            # test some additional messages
            for i in [ 129, 1287, 28792, 129872 ]:
                # in one direction
                msg = os.urandom(i)
                writer.write(msg)
                self.assertEqual(msg, await endrdr.readexactly(len(msg)))

                # and the other
                endwrr.write(msg)
                self.assertEqual(msg, await reader.readexactly(len(msg)))

            writer.close()
        finally:
            for proc in procs:
                if proc.returncode is None:
                    try:
                        proc.terminate()
                    except ProcessLookupError:
                        # it already exited
                        pass
                    await proc.wait()

            lsock.close()

    @async_test
    async def test_genkey(self):
        # that it can generate a key
        proc = await self.run_with_args('genkey', 'somefile')

        # communicate() drains the pipes while waiting (see test_noargs)
        await proc.communicate()

        self.assertEqual(proc.returncode, 0)

        with open('somefile.pub', encoding='ascii') as fp:
            lines = fp.readlines()
            self.assertEqual(len(lines), 1)

            keytype, keyvalue = lines[0].split()

            self.assertEqual(keytype, 'ntun-x448')
            key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))

        key = loadprivkey('somefile')
        self.assertIsInstance(key, x448.X448PrivateKey)

        # that a second call fails
        proc = await self.run_with_args('genkey', 'somefile')

        stdoutdata, stderrdata = await proc.communicate()

        self.assertFalse(stdoutdata)
        self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)

        # And that it exited w/ the correct code
        self.assertEqual(proc.returncode, 2)
-
class TestNoiseFowarder(unittest.TestCase):
    '''Tests for NoiseForwarder, driven in-process over socket pairs and
    unix domain sockets.'''

    def setUp(self):
        # setup temporary directory
        d = os.path.realpath(tempfile.mkdtemp())
        self.basetempdir = d
        self.tempdir = os.path.join(d, 'subdir')
        os.mkdir(self.tempdir)

        # Generate key pairs
        self.server_key_pair = genkeypair()
        self.client_key_pair = genkeypair()

    def tearDown(self):
        shutil.rmtree(self.basetempdir)
        self.tempdir = None

    @async_test
    async def test_clientkeymissmatch(self):
        # generate a key that is incorrect
        wrongclient_key_pair = genkeypair()

        # the secure socket
        clssockapair, clssockbpair = _asyncsockpair()
        reader, writer = await clssockapair

        async def wrongkey(v):
            raise ValueError('no key matches')

        # create the server
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, wrongkey,
            priv_key=self.server_key_pair[1]))

        # Create client
        proto = NoiseConnection.from_name(
            b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup wrong client key
        proto.set_keypair_from_private_bytes(Keypair.STATIC,
            wrongclient_key_pair[1])

        # but the correct server key
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
            self.server_key_pair[0])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.assertEqual(len(message), _handshakelens[0])
        writer.write(message)

        # Get response
        respmsg = await reader.readexactly(_handshakelens[1])
        proto.read_message(respmsg)

        # Send final reply
        message = proto.write_message()
        writer.write(message)

        # Make sure handshake has completed
        self.assertTrue(proto.handshake_finished)

        with self.assertRaises(ValueError):
            await servnf

        writer.close()

    @async_test
    async def test_server(self):
        # Test is plumbed:
        # (reader, writer) -> servsock ->
        #   (rdr, wrr) NoiseForward (reader, writer) ->
        #   servptsock -> (ptsock[0], ptsock[1])
        # Path that the server will sit on
        servsockpath = os.path.join(self.tempdir, 'servsock')
        servarg = _makeunix(servsockpath)

        # Path that the server will send pt data to
        servptpath = os.path.join(self.tempdir, 'servptsock')

        # Setup pt target listener. The event lets the test wait for the
        # forwarder's connection instead of sleeping and hoping.
        pttarg = _makeunix(servptpath)
        ptsock = []
        ptsockevent = asyncio.Event()
        def ptsockaccept(reader, writer, ptsock=ptsock):
            ptsock.append((reader, writer))
            ptsockevent.set()

        # Bind to pt listener
        lsock = await listensockstr(pttarg, ptsockaccept)

        nfs = []
        event = asyncio.Event()

        async def runnf(rdr, wrr):
            ptpairfun = asyncio.create_task(connectsockstr(pttarg))

            a = await NoiseForwarder('resp',
                _makefut((rdr, wrr)), lambda x: ptpairfun,
                priv_key=self.server_key_pair[1])

            nfs.append(a)
            event.set()

        # Setup server listener
        ssock = await listensockstr(servarg, runnf)

        try:
            # Connect to server
            reader, writer = await connectsockstr(servarg)

            # Create client
            proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
            proto.set_as_initiator()

            # Setup required keys
            proto.set_keypair_from_private_bytes(Keypair.STATIC,
                self.client_key_pair[1])
            proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                self.server_key_pair[0])

            proto.start_handshake()

            # Send first message
            message = proto.write_message()
            self.assertEqual(len(message), _handshakelens[0])
            writer.write(message)

            # Get response
            respmsg = await reader.readexactly(_handshakelens[1])
            proto.read_message(respmsg)

            # Send final reply
            message = proto.write_message()
            writer.write(message)

            # Make sure handshake has completed
            self.assertTrue(proto.handshake_finished)

            # generate the keys for lengths
            enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
                b'toresp')
            _, declenfun = _genciphfun(proto.get_handshake_hash(),
                b'toinit')

            # write a test message
            ptmsg = b'this is a test message that should be a little in length'
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # wait for the forwarder to connect to the pt target
            await ptsockevent.wait()

            ptreader, ptwriter = ptsock[0]
            # read the test message
            rptmsg = await ptreader.readexactly(len(ptmsg))

            self.assertEqual(rptmsg, ptmsg)

            # write a different message
            ptmsg = os.urandom(2843)
            encmsg = proto.encrypt(ptmsg)
            writer.write(enclenfun(encmsg))
            writer.write(encmsg)

            # read the test message (readexactly waits for all of it)
            rptmsg = await ptreader.readexactly(len(ptmsg))

            self.assertEqual(rptmsg, ptmsg)

            # now try the other way
            ptmsg = os.urandom(912)
            ptwriter.write(ptmsg)

            # find out how much we need to read
            encmsg = await reader.readexactly(2 + 16)
            tlen = declenfun(encmsg)

            # read the rest of the message
            rencmsg = await reader.readexactly(tlen - 16)
            tmsg = encmsg[2:] + rencmsg
            rptmsg = proto.decrypt(tmsg)

            self.assertEqual(rptmsg, ptmsg)

            # shut down sending
            writer.write_eof()

            # so pt reader should be shut down
            self.assertEqual(b'', await ptreader.read(1))
            self.assertTrue(ptreader.at_eof())

            # shut down pt
            ptwriter.write_eof()

            # make sure the enc reader is eof
            self.assertEqual(b'', await reader.read(1))
            self.assertTrue(reader.at_eof())

            await event.wait()

            self.assertEqual(nfs[0], [ 'dec', 'enc' ])
        finally:
            # stop listening so the sockets do not outlive the test
            ssock.close()
            lsock.close()

    @async_test
    async def test_serverclient(self):
        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        async def validateclientkey(pubkey):
            if pubkey != self.client_key_pair[0]:
                raise ValueError('invalid key')

            return await ptssockapair

        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, lambda x: ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, validateclientkey,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())
        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)
|