Browse Source

switch to using asyncio. It's so much cleaner than twisted.

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
fa9a3d20dc
1 changed files with 202 additions and 28 deletions
  1. +202
    -28
      ntunnel.py

+ 202
- 28
ntunnel.py View File

@@ -1,6 +1,11 @@
from noise.connection import NoiseConnection, Keypair
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import hashes
from twistednoise import genkeypair
from cryptography.hazmat.backends import default_backend

import asyncio
import os.path
import shutil
import socket
@@ -8,23 +13,27 @@ import tempfile
import threading
import unittest

_backend = default_backend()

def _makeunix(path):
'''Make a properly formed unix path socket string.'''

return 'unix:%s' % path

def _acceptfun(s, fun):
while True:
sock = s.accept()
def _parsesockstr(sockstr):
proto, rem = sockstr.split(':', 1)

return proto, rem

async def connectsockstr(sockstr):
proto, rem = _parsesockstr(sockstr)

fun(*sock)
reader, writer = await asyncio.open_unix_connection(rem)

def listensocket(sockstr, fun):
'''Listen for connections on sockstr. When ever a connection
is accepted, the parameter fun is called with the socket and
the from address. The return will be a Thread object. Note
that fun MUST NOT block, as if it does, it will stop accepting
other connections.
return reader, writer

async def listensockstr(sockstr, cb):
'''Wrapper for asyncio.start_x_server.

The format of sockstr is: 'proto:param=value[,param2=value2]'.
If the proto has a default parameter, the value can be used
@@ -43,29 +52,123 @@ def listensocket(sockstr, fun):
slash if it is used as a default parameter.
'''

proto, rem = sockstr.split(':', 1)
proto, rem = _parsesockstr(sockstr)

s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.bind(rem)
s.listen(-1)
server = await asyncio.start_unix_server(cb, path=rem)

thr = threading.Thread(target=_acceptfun, name='accept thread: %s' % repr(sockstr), args=(s, fun))
thr.setDaemon(True)
return server

thr.start()
# !!python makemessagelengths.py
_handshakelens = \
[72, 72, 88]

return thr
def _genciphfun(hash, ad):
hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)

class NoiseForwarder(object):
def __init__(self, mode, sock, ):
nf = NoiseForwarder('resp', self.server_key_pair[1], ssock, pttarg)
pass
key = hkdf.derive(hash)
cipher = Cipher(algorithms.AES(key), modes.ECB(),
backend=_backend)
enctor = cipher.encryptor()

def encfun(data):
# Returns the two bytes for length
val = len(data)
encbytes = enctor.update(data[:16])
mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

return (val ^ mask).to_bytes(length=2, byteorder='big')

def decfun(data):
# takes off the data and returns the total
# length
val = int.from_bytes(data[:2], byteorder='big')
encbytes = enctor.update(data[2:2 + 16])
mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff

return val ^ mask

return encfun, decfun

async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
rdr, wrr = rdrwrr

proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)

proto.set_as_responder()

proto.start_handshake()

proto.read_message(await rdr.readexactly(_handshakelens[0]))

wrr.write(proto.write_message())

proto.read_message(await rdr.readexactly(_handshakelens[2]))

if not proto.handshake_finished: # pragma: no cover
raise RuntimeError('failed to finish handshake')

# generate the keys for lengths
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')

reader, writer = await connectsockstr(ptsockstr)

async def decses():
while True:
msg = await rdr.readexactly(2 + 16)
tlen = declenfun(msg)
rmsg = await rdr.readexactly(tlen - 16)
tmsg = msg[2:] + rmsg
writer.write(proto.decrypt(tmsg))
await writer.drain()

async def encses():
while True:
ptmsg = await reader.read(65535 - 16) # largest message
encmsg = proto.encrypt(ptmsg)
wrr.write(enclenfun(encmsg))
wrr.write(encmsg)
await wrr.drain()

r = await asyncio.gather(decses(), encses(), return_exceptions=True)

print(repr(r))

return r

class TestListenSocket(unittest.TestCase):
def test_listensocket(self):
def test_listensockstr(self):
# XXX write test
pass

# https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
def async_test(f):
def wrapper(*args, **kwargs):
coro = asyncio.coroutine(f)
future = coro(*args, **kwargs)
loop = asyncio.get_event_loop()

# timeout after 2 seconds
loop.run_until_complete(asyncio.wait_for(future, 2))
return wrapper

class Tests_misc(unittest.TestCase):
def test_genciphfun(self):
enc, dec = _genciphfun(b'0' * 32, b'foobar')

msg = b'this is a bunch of data'

tb = enc(msg)

self.assertEqual(len(msg), dec(tb + msg))

for i in [ 20, 1384, 64000, 23839, 65535 ]:
msg = os.urandom(i)
self.assertEqual(len(msg), dec(enc(msg) + msg))

class Tests(unittest.TestCase):
def setUp(self):
# setup temporary directory
@@ -82,7 +185,8 @@ class Tests(unittest.TestCase):
shutil.rmtree(self.basetempdir)
self.tempdir = None

def test_server(self):
@async_test
async def test_server(self):
# Path that the server will sit on
servsockpath = os.path.join(self.tempdir, 'servsock')
servarg = _makeunix(servsockpath)
@@ -93,14 +197,17 @@ class Tests(unittest.TestCase):
# Setup pt target listener
pttarg = _makeunix(servsockpath)
ptsock = []
def ptsockaccept(sock, frm, ptsock=ptsock):
ptsock.append(sock)
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))

# Bind to pt listener
lsock = listensocket(pttarg, ptsockaccept)
lsock = await listensockstr(pttarg, ptsockaccept)

# Setup server listener
ssock = listensocket(servarg, lambda x, y: NoiseForwarder('resp', self.server_key_pair[1], x, pttarg))
ssock = await listensockstr(servarg, lambda rdr, wrr: NoiseForwarder('resp', self.server_key_pair[1], (rdr, wrr), pttarg))

# Connect to server
reader, writer = await connectsockstr(servarg)

# Create client
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
@@ -114,3 +221,70 @@ class Tests(unittest.TestCase):

# Send first message
message = proto.write_message()
self.assertEqual(len(message), _handshakelens[0])
writer.write(message)

# Get response
respmsg = await reader.readexactly(_handshakelens[1])
proto.read_message(respmsg)

# Send final reply
message = proto.write_message()
writer.write(message)

# Make sure handshake has completed
self.assertTrue(proto.handshake_finished)

# generate the keys for lengths
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

# write a test message
ptmsg = b'this is a test message that should be a little in length'
encmsg = proto.encrypt(ptmsg)
writer.write(enclenfun(encmsg))
writer.write(encmsg)

# XXX - how to sync?
await asyncio.sleep(.1)

# read the test message
rptmsg = await ptsock[0][0].readexactly(len(ptmsg))

self.assertEqual(rptmsg, ptmsg)

# write a different message
ptmsg = os.urandom(2843)
encmsg = proto.encrypt(ptmsg)
writer.write(enclenfun(encmsg))
writer.write(encmsg)

# XXX - how to sync?
await asyncio.sleep(.1)

# read the test message
rptmsg = await ptsock[0][0].readexactly(len(ptmsg))

self.assertEqual(rptmsg, ptmsg)

# now try the other way
ptmsg = os.urandom(912)
ptsock[0][1].write(ptmsg)

# find out how much we need to read
encmsg = await reader.readexactly(2 + 16)
tlen = declenfun(encmsg)

# read the rest of the message
rencmsg = await reader.readexactly(tlen - 16)
tmsg = encmsg[2:] + rencmsg
rptmsg = proto.decrypt(tmsg)

self.assertEqual(rptmsg, ptmsg)

# shut everything down
ptsock[0][1].write_eof()
writer.write_eof()

# XXX - how to sync?
await asyncio.sleep(1)

Loading…
Cancel
Save