@@ -54,6 +54,47 @@ def loadpubkeyraw(fname):
return base64.urlsafe_b64decode(keyvalue)
class ConnectionValidator(object):
'''This class is used to validate a connection, and initiate the
connection that will be used.'''
async def validatekey(self, hash, key):
'''Validate that the key is authorized to connect. The
connection hash is passed in, so that the authorizate of
the key can be validated later.'''
raise NotImplementedError
async def getconnection(self, hash, key, **kwargs):
'''Return the connection that should be used by this
client.'''
raise NotImplementedError
class GenericConnValidator(object):
'''This is a simple implementation of a ConnectionValidator that
can be used w/ most cases. It checks against the list, and then
calls/awaits the function provided, and returns it's value.'''
def __init__(self, keys, connfun):
'''The parameter keys must be an object that supports the
in operators, aka contains. If the key is in the keys
object, the connection will proceed.
The parameter connfun must be an async function that
returns a StreamReader, StreamWriter pair of the
connection that they session is supposed to use.'''
self._keys = keys
self._connfun = connfun
async def validatekey(self, hash, key):
if key not in self._keys:
raise ValueError('key not authorized: %s' % repr(key))
async def getconnection(self, hash, key):
return await self._connfun()
def genkeypair():
'''Generates a keypair, and returns a tuple of (public, private).
They are encoded as raw bytes, and sutible for use w/ Noise.'''
@@ -214,7 +255,7 @@ def _genciphfun(hash, ad):
return encfun, decfun
async def NoiseForwarder(mode, encrdrwrr, ptpairfun , priv_key, pub_key=None):
async def NoiseForwarder(mode, encrdrwrr, connvalid , priv_key, pub_key=None):
'''A function that forwards data between the plain text pair of
streams to the encrypted session.
@@ -225,12 +266,23 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
reader and writer streams for the encrypted side of the
connection.
The ptpairfun parameter is a function that will be passed the
public key bytes for the remote client. This can be used to
both validate that the correct client is connecting, and to
pass back the correct plain text reader/writer objects that
match the provided static key. The function must be an async
function.
The connvalid parameter is an instance of the ConnectionValidator
class, or one that implements it's methods. The validatekey method
will be passed the session hash, and the remote public key of the
client or server. If the key is not authorized, an exception
must be raised. Any non-exception return from the function means
that the key is authorized, and that the session should continue.
In the case of the initiator, the server's key will be passed,
despite the fact that it was already validated by the XK Noise
Protocol. This is just to keep the calling convention the same,
and it supports moving to an XX protocol possibly in the future
with minimal changes.
Then the getconnection method will be called. It's expected
return is the connection to forward the data on to. The kwargs
may be used in future protocol versions to allow the client to
request a specific resource, or something similar.
In the case of the initiator, pub_key must be provided and will
be used to authenticate the responder side of the connection.
@@ -284,10 +336,13 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake')
sesshash = proto.get_handshake_hash()
clientkey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC),
'public_bytes', None)
try:
reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))
except:
await connvalid.validatekey(sesshash, clientkey)
except Exception:
wrr.close()
raise
@@ -298,11 +353,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
# It is marginally useful as writing patterns likely expose the
# true length. Adding padding could marginally help w/ this.
if mode == 'resp':
_, declenfun = _genciphfun(proto.get_handshake_hash() , b'toresp')
enclenfun, _ = _genciphfun(proto.get_handshake_hash() , b'toinit')
_, declenfun = _genciphfun(sesshash , b'toresp')
enclenfun, _ = _genciphfun(sesshash , b'toinit')
elif mode == 'init':
enclenfun, _ = _genciphfun(proto.get_handshake_hash() , b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash() , b'toinit')
enclenfun, _ = _genciphfun(sesshash , b'toresp')
_, declenfun = _genciphfun(sesshash , b'toinit')
# protocol negotiation
@@ -324,6 +379,8 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
raise RuntimeError('unsupported protovol version received: %d' %
rempv)
reader, writer = await connvalid.getconnection(sesshash, clientkey)
async def decses():
try:
while True:
@@ -519,10 +576,11 @@ def cmd_client(args):
privkey = loadprivkeyraw(args.clientkey)
pubkey = loadpubkeyraw(args.servkey)
async def runnf(rdr, wrr):
connval = GenericConnValidator([ pubkey ],
lambda: _makefut((rdr, wrr)))
encpair = asyncio.create_task(connectsockstr(args.clienttarget))
a = await NoiseForwarder('init',
encpair, lambda x: _makefut((rdr, wrr)),
a = await NoiseForwarder('init', encpair, connval,
priv_key=privkey, pub_key=pubkey)
# Setup client listener
@@ -539,14 +597,10 @@ def cmd_server(args):
pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]
async def runnf(rdr, wrr):
async def checkclientfun(clientkey):
if clientkey not in pubkeys:
raise RuntimeError('invalid key provided')
return await connectsockstr(args.servtarget)
connval = GenericConnValidator(pubkeys, lambda: connectsockstr(args.servtarget))
a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
checkclientfun , priv_key=privkey)
connval, priv_key=privkey)
# Setup server listener
ssock = listensockstr(args.servlisten, runnf)
@@ -950,12 +1004,9 @@ class TestNoiseFowarder(unittest.TestCase):
clssockapair, clssockbpair = _asyncsockpair()
reader, writer = await clssockapair
async def wrongkey(v):
raise ValueError('no key matches')
# create the server
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, wrongkey ,
clssockbpair, GenericConnValidator([], None),
priv_key=self.server_key_pair[1]))
# Create client
@@ -1022,10 +1073,12 @@ class TestNoiseFowarder(unittest.TestCase):
event = asyncio.Event()
async def runnf(rdr, wrr):
ptpairfun = asyncio.create_task(connectsockstr(pttarg))
connval = GenericConnValidator(
[ self.client_key_pair[0] ],
lambda: connectsockstr(pttarg))
a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), lambda x: ptpairfun ,
_makefut((rdr, wrr)), connval ,
priv_key=self.server_key_pair[1])
nfs.append(a)
@@ -1193,8 +1246,11 @@ class TestNoiseFowarder(unittest.TestCase):
ptpairfun = asyncio.create_task(connectsockstr(pttarg))
try:
connval = GenericConnValidator(
[ self.client_key_pair[0] ],
lambda: ptpairfun)
a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), lambda x: ptpairfun,
_makefut((rdr, wrr)), connval ,
priv_key=self.server_key_pair[1])
except RuntimeError as e:
nfs.append(e)
@@ -1283,17 +1339,17 @@ class TestNoiseFowarder(unittest.TestCase):
#ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair
async def validateclientkey(pubkey):
self.assertEqual(pubkey, self.client_key_pair[0] )
return await ptssockapair
validateclientside = GenericConnValidator(
[ self.server_key_pair[0] ], lambda: ptcsockbpair )
validateserverside = GenericConnValidator(
[ self.client_key_pair[0] ], lambda: ptssockapair)
clientnf = asyncio.create_task(NoiseForwarder('init',
clssockapair, lambda x: ptcsockbpair ,
clssockapair, validateclientside ,
priv_key=self.client_key_pair[1],
pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, validateclientkey ,
clssockbpair, validateserverside ,
priv_key=self.server_key_pair[1]))
# send a message