Browse Source

change ptpair to a function and pass in the remote static key (if

we're the responder) so that the responder can validate and possibly
change the route based upon it...
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
d33583e340
2 changed files with 97 additions and 12 deletions
  1. +96
    -11
      ntunnel.py
  2. +1
    -1
      requirements.txt

+ 96
- 11
ntunnel.py View File

@@ -151,8 +151,35 @@ def _genciphfun(hash, ad):


return encfun, decfun return encfun, decfun


async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
rdr, wrr = await rdrwrr
async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
'''A function that forwards data between the plain text pair of
streams to the encrypted session.

The mode paramater must be one of 'init' or 'resp' for initiator
and responder.

The encrdrwrr is an await object that will return a tunle of the
reader and writer streams for the encrypted side of the
connection.

The ptpairfun parameter is a function that will be passed the
public key bytes for the remote client. This can be used to
both validate that the correct client is connecting, and to
pass back the correct plain text reader/writer objects that
match the provided static key. The function must be an async
function.

In the case of the initiator, pub_key must be provided and will
be used to authenticate the responder side of the connection.

The priv_key parameter is used to authenticate this side of the
session.

Both priv_key and pub_key parameters must be 56 bytes. For example,
the pair that is returned by genkeypair.
'''

rdr, wrr = await encrdrwrr


proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')


@@ -185,6 +212,9 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
if not proto.handshake_finished: # pragma: no cover if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake') raise RuntimeError('failed to finish handshake')


reader, writer = await ptpairfun(getattr(proto.get_keypair(
Keypair.REMOTE_STATIC), 'public_bytes', None))

# generate the keys for lengths # generate the keys for lengths
# XXX - get_handshake_hash is probably not the best option, but # XXX - get_handshake_hash is probably not the best option, but
# this is only to obscure lengths, it is not required to be secure # this is only to obscure lengths, it is not required to be secure
@@ -198,8 +228,6 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')


reader, writer = await ptpair

async def decses(): async def decses():
try: try:
while True: while True:
@@ -292,7 +320,7 @@ def cmd_client(args):
encpair = asyncio.create_task(connectsockstr(args.clienttarget)) encpair = asyncio.create_task(connectsockstr(args.clienttarget))


a = await NoiseForwarder('init', a = await NoiseForwarder('init',
encpair, _makefut((rdr, wrr)),
encpair, lambda x: _makefut((rdr, wrr)),
priv_key=privkey, pub_key=pubkey) priv_key=privkey, pub_key=pubkey)


# Setup client listener # Setup client listener
@@ -308,10 +336,10 @@ def cmd_server(args):
privkey = loadprivkeyraw(args.servkey) privkey = loadprivkeyraw(args.servkey)


async def runnf(rdr, wrr): async def runnf(rdr, wrr):
ptpair = asyncio.create_task(connectsockstr(args.servtarget))
ptpairfun = asyncio.create_task(connectsockstr(args.servtarget))


a = await NoiseForwarder('resp', a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), ptpair,
_makefut((rdr, wrr)), lambda x: ptpairfun,
priv_key=privkey) priv_key=privkey)


# Setup server listener # Setup server listener
@@ -607,6 +635,57 @@ class TestNoiseFowarder(unittest.TestCase):
shutil.rmtree(self.basetempdir) shutil.rmtree(self.basetempdir)
self.tempdir = None self.tempdir = None


@async_test
async def test_clientkeymissmatch(self):
# generate a key that is incorrect
wrongclient_key_pair = genkeypair()

# the secure socket
clssockapair, clssockbpair = _asyncsockpair()
reader, writer = await clssockapair

async def wrongkey(v):
raise ValueError('no key matches')

# create the server
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, wrongkey,
priv_key=self.server_key_pair[1]))

# Create client
proto = NoiseConnection.from_name(
b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_as_initiator()

# Setup wrong client key
proto.set_keypair_from_private_bytes(Keypair.STATIC,
wrongclient_key_pair[1])

# but the correct server key
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
self.server_key_pair[0])

proto.start_handshake()

# Send first message
message = proto.write_message()
self.assertEqual(len(message), _handshakelens[0])
writer.write(message)

# Get response
respmsg = await reader.readexactly(_handshakelens[1])
proto.read_message(respmsg)

# Send final reply
message = proto.write_message()
writer.write(message)

# Make sure handshake has completed
self.assertTrue(proto.handshake_finished)

with self.assertRaises(ValueError):
await servnf

@async_test @async_test
async def test_server(self): async def test_server(self):
# Test is plumbed: # Test is plumbed:
@@ -633,10 +712,10 @@ class TestNoiseFowarder(unittest.TestCase):
event = asyncio.Event() event = asyncio.Event()


async def runnf(rdr, wrr): async def runnf(rdr, wrr):
ptpair = asyncio.create_task(connectsockstr(pttarg))
ptpairfun = asyncio.create_task(connectsockstr(pttarg))


a = await NoiseForwarder('resp', a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), ptpair,
_makefut((rdr, wrr)), lambda x: ptpairfun,
priv_key=self.server_key_pair[1]) priv_key=self.server_key_pair[1])


nfs.append(a) nfs.append(a)
@@ -760,12 +839,18 @@ class TestNoiseFowarder(unittest.TestCase):
#ptssockapair passed directly #ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair ptsbreader, ptsbwriter = await ptssockbpair


async def validateclientkey(pubkey):
if pubkey != self.client_key_pair[0]:
raise ValueError('invalid key')

return await ptssockapair

clientnf = asyncio.create_task(NoiseForwarder('init', clientnf = asyncio.create_task(NoiseForwarder('init',
clssockapair, ptcsockbpair,
clssockapair, lambda x: ptcsockbpair,
priv_key=self.client_key_pair[1], priv_key=self.client_key_pair[1],
pub_key=self.server_key_pair[0])) pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp', servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, ptssockapair,
clssockbpair, validateclientkey,
priv_key=self.server_key_pair[1])) priv_key=self.server_key_pair[1]))


# send a message # send a message


+ 1
- 1
requirements.txt View File

@@ -1,2 +1,2 @@
coverage coverage
-e git+https://github.com/jmgurney/noiseprotocol.git@ab6f8ebe0e28f5a4105928c13baddcfdc43b7e82#egg=noiseprotocol
-e git+https://github.com/jmgurney/noiseprotocol.git@f1c048242c807328724c8119505293975fe7c614#egg=noiseprotocol

Loading…
Cancel
Save