diff --git a/ntunnel.py b/ntunnel.py index 2ffa262..cd97c73 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -151,8 +151,35 @@ def _genciphfun(hash, ad): return encfun, decfun -async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): - rdr, wrr = await rdrwrr +async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): + '''A function that forwards data between the plain text pair of + streams to the encrypted session. + + The mode paramater must be one of 'init' or 'resp' for initiator + and responder. + + The encrdrwrr is an await object that will return a tunle of the + reader and writer streams for the encrypted side of the + connection. + + The ptpairfun parameter is a function that will be passed the + public key bytes for the remote client. This can be used to + both validate that the correct client is connecting, and to + pass back the correct plain text reader/writer objects that + match the provided static key. The function must be an async + function. + + In the case of the initiator, pub_key must be provided and will + be used to authenticate the responder side of the connection. + + The priv_key parameter is used to authenticate this side of the + session. + + Both priv_key and pub_key parameters must be 56 bytes. For example, + the pair that is returned by genkeypair. + ''' + + rdr, wrr = await encrdrwrr proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') @@ -185,6 +212,9 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): if not proto.handshake_finished: # pragma: no cover raise RuntimeError('failed to finish handshake') + reader, writer = await ptpairfun(getattr(proto.get_keypair( + Keypair.REMOTE_STATIC), 'public_bytes', None)) + # generate the keys for lengths # XXX - get_handshake_hash is probably not the best option, but # this is only to obscure lengths, it is not required to be secure @@ -198,8 +228,6 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') - reader, writer = await ptpair - async def decses(): try: while True: @@ -292,7 +320,7 @@ def cmd_client(args): encpair = asyncio.create_task(connectsockstr(args.clienttarget)) a = await NoiseForwarder('init', - encpair, _makefut((rdr, wrr)), + encpair, lambda x: _makefut((rdr, wrr)), priv_key=privkey, pub_key=pubkey) # Setup client listener @@ -308,10 +336,10 @@ def cmd_server(args): privkey = loadprivkeyraw(args.servkey) async def runnf(rdr, wrr): - ptpair = asyncio.create_task(connectsockstr(args.servtarget)) + ptpairfun = asyncio.create_task(connectsockstr(args.servtarget)) a = await NoiseForwarder('resp', - _makefut((rdr, wrr)), ptpair, + _makefut((rdr, wrr)), lambda x: ptpairfun, priv_key=privkey) # Setup server listener @@ -607,6 +635,57 @@ class TestNoiseFowarder(unittest.TestCase): shutil.rmtree(self.basetempdir) self.tempdir = None + @async_test + async def test_clientkeymissmatch(self): + # generate a key that is incorrect + wrongclient_key_pair = genkeypair() + + # the secure socket + clssockapair, clssockbpair = _asyncsockpair() + reader, writer = await clssockapair + + async def wrongkey(v): + raise ValueError('no key matches') + + # create the server + servnf = asyncio.create_task(NoiseForwarder('resp', + clssockbpair, wrongkey, + priv_key=self.server_key_pair[1])) + + # Create client + proto = NoiseConnection.from_name( + b'Noise_XK_448_ChaChaPoly_SHA256') + proto.set_as_initiator() + + # Setup wrong client key + proto.set_keypair_from_private_bytes(Keypair.STATIC, + wrongclient_key_pair[1]) + + # but the correct server key + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + self.server_key_pair[0]) + + proto.start_handshake() + + # Send first message + message = proto.write_message() + self.assertEqual(len(message), _handshakelens[0]) + writer.write(message) + + # Get response + respmsg = await reader.readexactly(_handshakelens[1]) + proto.read_message(respmsg) + + # Send final reply + message = proto.write_message() + writer.write(message) + + # Make sure handshake has completed + self.assertTrue(proto.handshake_finished) + + with self.assertRaises(ValueError): + await servnf + @async_test async def test_server(self): # Test is plumbed: @@ -633,10 +712,10 @@ class TestNoiseFowarder(unittest.TestCase): event = asyncio.Event() async def runnf(rdr, wrr): - ptpair = asyncio.create_task(connectsockstr(pttarg)) + ptpairfun = asyncio.create_task(connectsockstr(pttarg)) a = await NoiseForwarder('resp', - _makefut((rdr, wrr)), ptpair, + _makefut((rdr, wrr)), lambda x: ptpairfun, priv_key=self.server_key_pair[1]) nfs.append(a) @@ -760,12 +839,18 @@ class TestNoiseFowarder(unittest.TestCase): #ptssockapair passed directly ptsbreader, ptsbwriter = await ptssockbpair + async def validateclientkey(pubkey): + if pubkey != self.client_key_pair[0]: + raise ValueError('invalid key') + + return await ptssockapair + clientnf = asyncio.create_task(NoiseForwarder('init', - clssockapair, ptcsockbpair, + clssockapair, lambda x: ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) servnf = asyncio.create_task(NoiseForwarder('resp', - clssockbpair, ptssockapair, + clssockbpair, validateclientkey, priv_key=self.server_key_pair[1])) # send a message diff --git a/requirements.txt b/requirements.txt index 8c1878d..fbf6982 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ coverage --e git+https://github.com/jmgurney/noiseprotocol.git@ab6f8ebe0e28f5a4105928c13baddcfdc43b7e82#egg=noiseprotocol +-e git+https://github.com/jmgurney/noiseprotocol.git@f1c048242c807328724c8119505293975fe7c614#egg=noiseprotocol