@@ -15,6 +15,13 @@ import unittest
_backend = default_backend()
_backend = default_backend()
def _makefut(obj):
loop = asyncio.get_running_loop()
fut = loop.create_future()
fut.set_result(obj)
return fut
def _makeunix(path):
def _makeunix(path):
'''Make a properly formed unix path socket string.'''
'''Make a properly formed unix path socket string.'''
@@ -91,28 +98,30 @@ def _genciphfun(hash, ad):
return encfun, decfun
return encfun, decfun
async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
rdr, wrr = rdrwrr
rdr, wrr = await rdrwrr
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
if pub_key is not None:
if pub_key is not None:
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key)
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
pub_key)
if mode == 'resp':
if mode == 'resp':
proto.set_as_responder()
proto.set_as_responder()
elif mode == 'init':
proto.set_as_initiator()
proto.start_handshake()
proto.start_handshake()
if mode == 'resp':
proto.read_message(await rdr.readexactly(_handshakelens[0]))
proto.read_message(await rdr.readexactly(_handshakelens[0]))
wrr.write(proto.write_message())
wrr.write(proto.write_message())
proto.read_message(await rdr.readexactly(_handshakelens[2]))
proto.read_message(await rdr.readexactly(_handshakelens[2]))
elif mode == 'init':
elif mode == 'init':
proto.set_as_initiator()
proto.start_handshake()
wrr.write(proto.write_message())
wrr.write(proto.write_message())
proto.read_message(await rdr.readexactly(_handshakelens[1]))
proto.read_message(await rdr.readexactly(_handshakelens[1]))
@@ -156,7 +165,8 @@ async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
async def encses():
async def encses():
try:
try:
while True:
while True:
ptmsg = await reader.read(65535 - 16) # largest message
# largest message
ptmsg = await reader.read(65535 - 16)
if not ptmsg:
if not ptmsg:
# eof
# eof
return 'enc'
return 'enc'
@@ -211,7 +221,8 @@ def _asyncsockpair():
socka, sockb = socket.socketpair()
socka, sockb = socket.socketpair()
return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)
return asyncio.open_connection(sock=socka), \
asyncio.open_connection(sock=sockb)
class Tests(unittest.TestCase):
class Tests(unittest.TestCase):
def setUp(self):
def setUp(self):
@@ -256,7 +267,10 @@ class Tests(unittest.TestCase):
async def runnf(rdr, wrr):
async def runnf(rdr, wrr):
ptpair = asyncio.create_task(connectsockstr(pttarg))
ptpair = asyncio.create_task(connectsockstr(pttarg))
a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1])
a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), ptpair,
priv_key=self.server_key_pair[1])
nfs.append(a)
nfs.append(a)
event.set()
event.set()
@@ -272,8 +286,10 @@ class Tests(unittest.TestCase):
proto.set_as_initiator()
proto.set_as_initiator()
# Setup required keys
# Setup required keys
proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
proto.set_keypair_from_private_bytes(Keypair.STATIC,
self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
self.server_key_pair[0])
proto.start_handshake()
proto.start_handshake()
@@ -294,8 +310,10 @@ class Tests(unittest.TestCase):
self.assertTrue(proto.handshake_finished)
self.assertTrue(proto.handshake_finished)
# generate the keys for lengths
# generate the keys for lengths
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(),
b'toinit')
# write a test message
# write a test message
ptmsg = b'this is a test message that should be a little in length'
ptmsg = b'this is a test message that should be a little in length'
@@ -345,14 +363,14 @@ class Tests(unittest.TestCase):
writer.write_eof()
writer.write_eof()
# so pt reader should be shut down
# so pt reader should be shut down
r = await ptreader.read(1 )
self.assertEqual(b'', await ptreader.read(1) )
self.assertTrue(ptreader.at_eof())
self.assertTrue(ptreader.at_eof())
# shut down pt
# shut down pt
ptwriter.write_eof()
ptwriter.write_eof()
# make sure the enc reader is eof
# make sure the enc reader is eof
r = await reader.read(1 )
self.assertEqual(b'', await reader.read(1) )
self.assertTrue(reader.at_eof())
self.assertTrue(reader.at_eof())
await event.wait()
await event.wait()
@@ -370,14 +388,18 @@ class Tests(unittest.TestCase):
ptcareader, ptcawriter = await ptcsockapair
ptcareader, ptcawriter = await ptcsockapair
#ptcsockbpair passed directly
#ptcsockbpair passed directly
clssockapair, clssockbpair = _asyncsockpair()
clssockapair, clssockbpair = _asyncsockpair()
clsapair = await clssockapair
clsbpair = await clssockbpair
#both passed directly
ptssockapair, ptssockbpair = _asyncsockpair()
ptssockapair, ptssockbpair = _asyncsockpair()
#ptssockapair passed directly
#ptssockapair passed directly
ptsbreader, ptsbwriter = await ptssockbpair
ptsbreader, ptsbwriter = await ptssockbpair
clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1]))
clientnf = asyncio.create_task(NoiseForwarder('init',
clssockapair, ptcsockbpair,
priv_key=self.client_key_pair[1],
pub_key=self.server_key_pair[0]))
servnf = asyncio.create_task(NoiseForwarder('resp',
clssockbpair, ptssockapair,
priv_key=self.server_key_pair[1]))
# send a message
# send a message
msga = os.urandom(183)
msga = os.urandom(183)
@@ -411,5 +433,11 @@ class Tests(unittest.TestCase):
ptsbwriter.write_eof()
ptsbwriter.write_eof()
ptcawriter.write_eof()
ptcawriter.write_eof()
# make sure they are closed, and there is no more data
self.assertEqual(b'', await ptsbreader.read(1))
self.assertTrue(ptsbreader.at_eof())
self.assertEqual(b'', await ptcareader.read(1))
self.assertTrue(ptcareader.at_eof())
self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await clientnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)
self.assertEqual([ 'dec', 'enc' ], await servnf)