diff --git a/ntunnel.py b/ntunnel.py index 3aea98b..c895c02 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -54,6 +54,47 @@ def loadpubkeyraw(fname): return base64.urlsafe_b64decode(keyvalue) +class ConnectionValidator(object): + '''This class is used to validate a connection, and initiate the + connection that will be used.''' + + async def validatekey(self, hash, key): + '''Validate that the key is authorized to connect. The + connection hash is passed in, so that the authorizate of + the key can be validated later.''' + + raise NotImplementedError + + async def getconnection(self, hash, key, **kwargs): + '''Return the connection that should be used by this + client.''' + + raise NotImplementedError + +class GenericConnValidator(object): + '''This is a simple implementation of a ConnectionValidator that + can be used w/ most cases. It checks against the list, and then + calls/awaits the function provided, and returns it's value.''' + + def __init__(self, keys, connfun): + '''The parameter keys must be an object that supports the + in operators, aka contains. If the key is in the keys + object, the connection will proceed. + + The parameter connfun must be an async function that + returns a StreamReader, StreamWriter pair of the + connection that they session is supposed to use.''' + + self._keys = keys + self._connfun = connfun + + async def validatekey(self, hash, key): + if key not in self._keys: + raise ValueError('key not authorized: %s' % repr(key)) + + async def getconnection(self, hash, key): + return await self._connfun() + def genkeypair(): '''Generates a keypair, and returns a tuple of (public, private). They are encoded as raw bytes, and sutible for use w/ Noise.''' @@ -214,7 +255,7 @@ def _genciphfun(hash, ad): return encfun, decfun -async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): +async def NoiseForwarder(mode, encrdrwrr, connvalid, priv_key, pub_key=None): '''A function that forwards data between the plain text pair of streams to the encrypted session. @@ -225,12 +266,23 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): reader and writer streams for the encrypted side of the connection. - The ptpairfun parameter is a function that will be passed the - public key bytes for the remote client. This can be used to - both validate that the correct client is connecting, and to - pass back the correct plain text reader/writer objects that - match the provided static key. The function must be an async - function. + The connvalid parameter is an instance of the ConnectionValidator + class, or one that implements it's methods. The validatekey method + will be passed the session hash, and the remote public key of the + client or server. If the key is not authorized, an exception + must be raised. Any non-exception return from the function means + that the key is authorized, and that the session should continue. + + In the case of the initiator, the server's key will be passed, + despite the fact that it was already validated by the XK Noise + Protocol. This is just to keep the calling convention the same, + and it supports moving to an XX protocol possibly in the future + with minimal changes. + + Then the getconnection method will be called. It's expected + return is the connection to forward the data on to. The kwargs + may be used in future protocol versions to allow the client to + request a specific resource, or something similar. In the case of the initiator, pub_key must be provided and will be used to authenticate the responder side of the connection. @@ -284,10 +336,13 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): if not proto.handshake_finished: # pragma: no cover raise RuntimeError('failed to finish handshake') + sesshash = proto.get_handshake_hash() + clientkey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC), + 'public_bytes', None) + try: - reader, writer = await ptpairfun(getattr(proto.get_keypair( - Keypair.REMOTE_STATIC), 'public_bytes', None)) - except: + await connvalid.validatekey(sesshash, clientkey) + except Exception: wrr.close() raise @@ -298,11 +353,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): # It is marginally useful as writing patterns likely expose the # true length. Adding padding could marginally help w/ this. if mode == 'resp': - _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp') - enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit') + _, declenfun = _genciphfun(sesshash, b'toresp') + enclenfun, _ = _genciphfun(sesshash, b'toinit') elif mode == 'init': - enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') - _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') + enclenfun, _ = _genciphfun(sesshash, b'toresp') + _, declenfun = _genciphfun(sesshash, b'toinit') # protocol negotiation @@ -324,6 +379,8 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): raise RuntimeError('unsupported protovol version received: %d' % rempv) + reader, writer = await connvalid.getconnection(sesshash, clientkey) + async def decses(): try: while True: @@ -519,10 +576,11 @@ def cmd_client(args): privkey = loadprivkeyraw(args.clientkey) pubkey = loadpubkeyraw(args.servkey) async def runnf(rdr, wrr): + connval = GenericConnValidator([ pubkey ], + lambda: _makefut((rdr, wrr))) encpair = asyncio.create_task(connectsockstr(args.clienttarget)) - a = await NoiseForwarder('init', - encpair, lambda x: _makefut((rdr, wrr)), + a = await NoiseForwarder('init', encpair, connval, priv_key=privkey, pub_key=pubkey) # Setup client listener @@ -539,14 +597,10 @@ def cmd_server(args): pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ] async def runnf(rdr, wrr): - async def checkclientfun(clientkey): - if clientkey not in pubkeys: - raise RuntimeError('invalid key provided') - - return await connectsockstr(args.servtarget) + connval = GenericConnValidator(pubkeys, lambda: connectsockstr(args.servtarget)) a = await NoiseForwarder('resp', _makefut((rdr, wrr)), - checkclientfun, priv_key=privkey) + connval, priv_key=privkey) # Setup server listener ssock = listensockstr(args.servlisten, runnf) @@ -950,12 +1004,9 @@ class TestNoiseFowarder(unittest.TestCase): clssockapair, clssockbpair = _asyncsockpair() reader, writer = await clssockapair - async def wrongkey(v): - raise ValueError('no key matches') - # create the server servnf = asyncio.create_task(NoiseForwarder('resp', - clssockbpair, wrongkey, + clssockbpair, GenericConnValidator([], None), priv_key=self.server_key_pair[1])) # Create client @@ -1022,10 +1073,12 @@ class TestNoiseFowarder(unittest.TestCase): event = asyncio.Event() async def runnf(rdr, wrr): - ptpairfun = asyncio.create_task(connectsockstr(pttarg)) + connval = GenericConnValidator( + [ self.client_key_pair[0] ], + lambda: connectsockstr(pttarg)) a = await NoiseForwarder('resp', - _makefut((rdr, wrr)), lambda x: ptpairfun, + _makefut((rdr, wrr)), connval, priv_key=self.server_key_pair[1]) nfs.append(a) @@ -1193,8 +1246,11 @@ class TestNoiseFowarder(unittest.TestCase): ptpairfun = asyncio.create_task(connectsockstr(pttarg)) try: + connval = GenericConnValidator( + [ self.client_key_pair[0] ], + lambda: ptpairfun) a = await NoiseForwarder('resp', - _makefut((rdr, wrr)), lambda x: ptpairfun, + _makefut((rdr, wrr)), connval, priv_key=self.server_key_pair[1]) except RuntimeError as e: nfs.append(e) @@ -1283,17 +1339,17 @@ class TestNoiseFowarder(unittest.TestCase): #ptssockapair passed directly ptsbreader, ptsbwriter = await ptssockbpair - async def validateclientkey(pubkey): - self.assertEqual(pubkey, self.client_key_pair[0]) - - return await ptssockapair + validateclientside = GenericConnValidator( + [ self.server_key_pair[0] ], lambda: ptcsockbpair) + validateserverside = GenericConnValidator( + [ self.client_key_pair[0] ], lambda: ptssockapair) clientnf = asyncio.create_task(NoiseForwarder('init', - clssockapair, lambda x: ptcsockbpair, + clssockapair, validateclientside, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0])) servnf = asyncio.create_task(NoiseForwarder('resp', - clssockbpair, validateclientkey, + clssockbpair, validateserverside, priv_key=self.server_key_pair[1])) # send a message