Browse Source

get the other side working... This was a lot easier to get functional

than I expected...
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
beb82c02e5
1 changed files with 105 additions and 16 deletions
  1. +105
    -16
      ntunnel.py

+ 105
- 16
ntunnel.py View File

@@ -90,31 +90,47 @@ def _genciphfun(hash, ad):


return encfun, decfun return encfun, decfun


async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
rdr, wrr = rdrwrr rdr, wrr = rdrwrr


proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')


proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key) proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
if pub_key is not None:
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key)


proto.set_as_responder()
if mode == 'resp':
proto.set_as_responder()
elif mode == 'init':
proto.set_as_initiator()


proto.start_handshake() proto.start_handshake()


proto.read_message(await rdr.readexactly(_handshakelens[0]))
if mode == 'resp':
proto.read_message(await rdr.readexactly(_handshakelens[0]))

wrr.write(proto.write_message())

proto.read_message(await rdr.readexactly(_handshakelens[2]))
elif mode == 'init':
wrr.write(proto.write_message())


wrr.write(proto.write_message())
proto.read_message(await rdr.readexactly(_handshakelens[1]))


proto.read_message(await rdr.readexactly(_handshakelens[2]))
wrr.write(proto.write_message())


if not proto.handshake_finished: # pragma: no cover if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake') raise RuntimeError('failed to finish handshake')


# generate the keys for lengths # generate the keys for lengths
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
if mode == 'resp':
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
elif mode == 'init':
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')


reader, writer = await connectsockstr(ptsockstr)
reader, writer = await ptpair


async def decses(): async def decses():
try: try:
@@ -130,6 +146,10 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
tmsg = msg[2:] + rmsg tmsg = msg[2:] + rmsg
writer.write(proto.decrypt(tmsg)) writer.write(proto.decrypt(tmsg))
await writer.drain() await writer.drain()
#except:
# import traceback
# traceback.print_exc()
# raise
finally: finally:
writer.write_eof() writer.write_eof()


@@ -145,17 +165,17 @@ async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
wrr.write(enclenfun(encmsg)) wrr.write(enclenfun(encmsg))
wrr.write(encmsg) wrr.write(encmsg)
await wrr.drain() await wrr.drain()
#except:
# import traceback
# traceback.print_exc()
# raise
finally: finally:
wrr.write_eof() wrr.write_eof()


return asyncio.gather(decses(), encses())

class TestListenSocket(unittest.TestCase):
def test_listensockstr(self):
# XXX write test
pass
return await asyncio.gather(decses(), encses())


# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
# Slightly modified to timeout
def async_test(f): def async_test(f):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
coro = asyncio.coroutine(f) coro = asyncio.coroutine(f)
@@ -167,6 +187,10 @@ def async_test(f):
return wrapper return wrapper


class Tests_misc(unittest.TestCase): class Tests_misc(unittest.TestCase):
def test_listensockstr(self):
# XXX write test
pass

def test_genciphfun(self): def test_genciphfun(self):
enc, dec = _genciphfun(b'0' * 32, b'foobar') enc, dec = _genciphfun(b'0' * 32, b'foobar')


@@ -180,6 +204,15 @@ class Tests_misc(unittest.TestCase):
msg = os.urandom(i) msg = os.urandom(i)
self.assertEqual(len(msg), dec(enc(msg) + msg)) self.assertEqual(len(msg), dec(enc(msg) + msg))


def _asyncsockpair():
'''Create a pair of sockets that are bound to each other.
The function will return a tuple of two coroutine's, that
each, when await'ed upon, will return the reader/writer pair.'''

socka, sockb = socket.socketpair()

return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)

class Tests(unittest.TestCase): class Tests(unittest.TestCase):
def setUp(self): def setUp(self):
# setup temporary directory # setup temporary directory
@@ -222,7 +255,8 @@ class Tests(unittest.TestCase):
event = asyncio.Event() event = asyncio.Event()


async def runnf(rdr, wrr): async def runnf(rdr, wrr):
a = await NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg)
ptpair = asyncio.create_task(connectsockstr(pttarg))
a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1])


nfs.append(a) nfs.append(a)
event.set() event.set()
@@ -323,4 +357,59 @@ class Tests(unittest.TestCase):


await event.wait() await event.wait()


self.assertEqual(await nfs[0], [ 'dec', 'enc' ])
self.assertEqual(nfs[0], [ 'dec', 'enc' ])

@async_test
async def test_serverclient(self):
# plumbing:
#
# ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
#

ptcsockapair, ptcsockbpair = _asyncsockpair()
ptcareader, ptcawriter = await ptcsockapair
#ptcsockbpair passed directly
clssockapair, clssockbpair = _asyncsockpair()
clsapair = await clssockapair
clsbpair = await clssockbpair
ptssockapair, ptssockbpair = _asyncsockpair()
#ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair

clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1]))

# send a message
msga = os.urandom(183)
ptcawriter.write(msga)

# make sure we get the same message
self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

# send a second message
msga = os.urandom(2834)
ptcawriter.write(msga)

# make sure we get the same message
self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

# send a message larger than the block size
msga = os.urandom(103958)
ptcawriter.write(msga)

# make sure we get the same message
self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

# send a message the other direction
msga = os.urandom(103958)
ptsbwriter.write(msga)

# make sure we get the same message
self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

# close down the pt writers, the rest should follow
ptsbwriter.write_eof()
ptcawriter.write_eof()

self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)

Loading…
Cancel
Save