Browse Source

minor reorg work..

make both pairs come from an awaitable object, this will be needed for
future work...
keep lines under 80 chars...
Minor reorg of the main loop to make it clearer...
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
be2e5fd6da
1 changed files with 47 additions and 19 deletions
  1. +47
    -19
      ntunnel.py

+ 47
- 19
ntunnel.py View File

@@ -15,6 +15,13 @@ import unittest


_backend = default_backend() _backend = default_backend()


def _makefut(obj):
loop = asyncio.get_running_loop()
fut = loop.create_future()
fut.set_result(obj)

return fut

def _makeunix(path): def _makeunix(path):
'''Make a properly formed unix path socket string.''' '''Make a properly formed unix path socket string.'''


@@ -91,28 +98,30 @@ def _genciphfun(hash, ad):
return encfun, decfun return encfun, decfun


async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None): async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
rdr, wrr = rdrwrr
rdr, wrr = await rdrwrr


proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')


proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
if pub_key is not None: if pub_key is not None:
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key)
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
pub_key)


if mode == 'resp': if mode == 'resp':
proto.set_as_responder() proto.set_as_responder()
elif mode == 'init':
proto.set_as_initiator()


proto.start_handshake()
proto.start_handshake()


if mode == 'resp':
proto.read_message(await rdr.readexactly(_handshakelens[0])) proto.read_message(await rdr.readexactly(_handshakelens[0]))


wrr.write(proto.write_message()) wrr.write(proto.write_message())


proto.read_message(await rdr.readexactly(_handshakelens[2])) proto.read_message(await rdr.readexactly(_handshakelens[2]))
elif mode == 'init': elif mode == 'init':
proto.set_as_initiator()

proto.start_handshake()

wrr.write(proto.write_message()) wrr.write(proto.write_message())


proto.read_message(await rdr.readexactly(_handshakelens[1])) proto.read_message(await rdr.readexactly(_handshakelens[1]))
@@ -156,7 +165,8 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
async def encses(): async def encses():
try: try:
while True: while True:
ptmsg = await reader.read(65535 - 16) # largest message
# largest message
ptmsg = await reader.read(65535 - 16)
if not ptmsg: if not ptmsg:
# eof # eof
return 'enc' return 'enc'
@@ -211,7 +221,8 @@ def _asyncsockpair():


socka, sockb = socket.socketpair() socka, sockb = socket.socketpair()


return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)
return asyncio.open_connection(sock=socka), \
asyncio.open_connection(sock=sockb)


class Tests(unittest.TestCase): class Tests(unittest.TestCase):
def setUp(self): def setUp(self):
@@ -256,7 +267,10 @@ class Tests(unittest.TestCase):


async def runnf(rdr, wrr): async def runnf(rdr, wrr):
ptpair = asyncio.create_task(connectsockstr(pttarg)) ptpair = asyncio.create_task(connectsockstr(pttarg))
a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1])

a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), ptpair,
priv_key=self.server_key_pair[1])


nfs.append(a) nfs.append(a)
event.set() event.set()
@@ -272,8 +286,10 @@ class Tests(unittest.TestCase):
proto.set_as_initiator() proto.set_as_initiator()


# Setup required keys # Setup required keys
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
proto.set_keypair_from_private_bytes(Keypair.STATIC,
self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
self.server_key_pair[0])


proto.start_handshake() proto.start_handshake()


@@ -294,8 +310,10 @@ class Tests(unittest.TestCase):
self.assertTrue(proto.handshake_finished) self.assertTrue(proto.handshake_finished)


# generate the keys for lengths # generate the keys for lengths
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(),
b'toinit')


# write a test message # write a test message
ptmsg = b'this is a test message that should be a little in length' ptmsg = b'this is a test message that should be a little in length'
@@ -345,14 +363,14 @@ class Tests(unittest.TestCase):
writer.write_eof() writer.write_eof()


# so pt reader should be shut down # so pt reader should be shut down
r = await ptreader.read(1)
self.assertEqual(b'', await ptreader.read(1))
self.assertTrue(ptreader.at_eof()) self.assertTrue(ptreader.at_eof())


# shut down pt # shut down pt
ptwriter.write_eof() ptwriter.write_eof()


# make sure the enc reader is eof # make sure the enc reader is eof
r = await reader.read(1)
self.assertEqual(b'', await reader.read(1))
self.assertTrue(reader.at_eof()) self.assertTrue(reader.at_eof())


await event.wait() await event.wait()
@@ -370,14 +388,18 @@ class Tests(unittest.TestCase):
ptcareader, ptcawriter = await ptcsockapair ptcareader, ptcawriter = await ptcsockapair
#ptcsockbpair passed directly #ptcsockbpair passed directly
clssockapair, clssockbpair = _asyncsockpair() clssockapair, clssockbpair = _asyncsockpair()
clsapair = await clssockapair
clsbpair = await clssockbpair
#both passed directly
ptssockapair, ptssockbpair = _asyncsockpair() ptssockapair, ptssockbpair = _asyncsockpair()
#ptssockapair passed directly #ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair ptsbreader, ptsbwriter = await ptssockbpair


clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1]))
clientnf = asyncio.create_task(NoiseForwarder('init',
clssockapair, ptcsockbpair,
priv_key=self.client_key_pair[1],
pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, ptssockapair,
priv_key=self.server_key_pair[1]))


# send a message # send a message
msga = os.urandom(183) msga = os.urandom(183)
@@ -411,5 +433,11 @@ class Tests(unittest.TestCase):
ptsbwriter.write_eof() ptsbwriter.write_eof()
ptcawriter.write_eof() ptcawriter.write_eof()


# make sure they are closed, and there is no more data
self.assertEqual(b'', await ptsbreader.read(1))
self.assertTrue(ptsbreader.at_eof())
self.assertEqual(b'', await ptcareader.read(1))
self.assertTrue(ptcareader.at_eof())

self.assertEqual([ 'dec', 'enc' ], await clientnf) self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf) self.assertEqual([ 'dec', 'enc' ], await servnf)

Loading…
Cancel
Save