Browse Source

minor reorg work..

make both pairs come from an awaitable object, this will be needed for
future work...
keep lines under 80 chars...
Minor reorg of the main loop to make it clearer...
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
be2e5fd6da
1 changed files with 47 additions and 19 deletions
  1. +47
    -19
      ntunnel.py

+ 47
- 19
ntunnel.py View File

@@ -15,6 +15,13 @@ import unittest

_backend = default_backend()

def _makefut(obj):
loop = asyncio.get_running_loop()
fut = loop.create_future()
fut.set_result(obj)

return fut

def _makeunix(path):
'''Make a properly formed unix path socket string.'''

@@ -91,28 +98,30 @@ def _genciphfun(hash, ad):
return encfun, decfun

async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
rdr, wrr = rdrwrr
rdr, wrr = await rdrwrr

proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
if pub_key is not None:
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key)
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
pub_key)

if mode == 'resp':
proto.set_as_responder()
elif mode == 'init':
proto.set_as_initiator()

proto.start_handshake()
proto.start_handshake()

if mode == 'resp':
proto.read_message(await rdr.readexactly(_handshakelens[0]))

wrr.write(proto.write_message())

proto.read_message(await rdr.readexactly(_handshakelens[2]))
elif mode == 'init':
proto.set_as_initiator()

proto.start_handshake()

wrr.write(proto.write_message())

proto.read_message(await rdr.readexactly(_handshakelens[1]))
@@ -156,7 +165,8 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
async def encses():
try:
while True:
ptmsg = await reader.read(65535 - 16) # largest message
# largest message
ptmsg = await reader.read(65535 - 16)
if not ptmsg:
# eof
return 'enc'
@@ -211,7 +221,8 @@ def _asyncsockpair():

socka, sockb = socket.socketpair()

return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)
return asyncio.open_connection(sock=socka), \
asyncio.open_connection(sock=sockb)

class Tests(unittest.TestCase):
def setUp(self):
@@ -256,7 +267,10 @@ class Tests(unittest.TestCase):

async def runnf(rdr, wrr):
ptpair = asyncio.create_task(connectsockstr(pttarg))
a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1])

a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), ptpair,
priv_key=self.server_key_pair[1])

nfs.append(a)
event.set()
@@ -272,8 +286,10 @@ class Tests(unittest.TestCase):
proto.set_as_initiator()

# Setup required keys
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
proto.set_keypair_from_private_bytes(Keypair.STATIC,
self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
self.server_key_pair[0])

proto.start_handshake()

@@ -294,8 +310,10 @@ class Tests(unittest.TestCase):
self.assertTrue(proto.handshake_finished)

# generate the keys for lengths
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(),
b'toinit')

# write a test message
ptmsg = b'this is a test message that should be a little in length'
@@ -345,14 +363,14 @@ class Tests(unittest.TestCase):
writer.write_eof()

# so pt reader should be shut down
r = await ptreader.read(1)
self.assertEqual(b'', await ptreader.read(1))
self.assertTrue(ptreader.at_eof())

# shut down pt
ptwriter.write_eof()

# make sure the enc reader is eof
r = await reader.read(1)
self.assertEqual(b'', await reader.read(1))
self.assertTrue(reader.at_eof())

await event.wait()
@@ -370,14 +388,18 @@ class Tests(unittest.TestCase):
ptcareader, ptcawriter = await ptcsockapair
#ptcsockbpair passed directly
clssockapair, clssockbpair = _asyncsockpair()
clsapair = await clssockapair
clsbpair = await clssockbpair
#both passed directly
ptssockapair, ptssockbpair = _asyncsockpair()
#ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair

clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1]))
clientnf = asyncio.create_task(NoiseForwarder('init',
clssockapair, ptcsockbpair,
priv_key=self.client_key_pair[1],
pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, ptssockapair,
priv_key=self.server_key_pair[1]))

# send a message
msga = os.urandom(183)
@@ -411,5 +433,11 @@ class Tests(unittest.TestCase):
ptsbwriter.write_eof()
ptcawriter.write_eof()

# make sure they are closed, and there is no more data
self.assertEqual(b'', await ptsbreader.read(1))
self.assertTrue(ptsbreader.at_eof())
self.assertEqual(b'', await ptcareader.read(1))
self.assertTrue(ptcareader.at_eof())

self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)

Loading…
Cancel
Save