|
@@ -317,6 +317,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): |
|
|
rmsg = await rdr.readexactly(tlen - 16) |
|
|
rmsg = await rdr.readexactly(tlen - 16) |
|
|
tmsg = msg[2:] + rmsg |
|
|
tmsg = msg[2:] + rmsg |
|
|
rpv = proto.decrypt(tmsg) |
|
|
rpv = proto.decrypt(tmsg) |
|
|
|
|
|
rempv = int.from_bytes(rpv, byteorder='big') |
|
|
|
|
|
|
|
|
|
|
|
if rempv != protocol_version: |
|
|
|
|
|
raise RuntimeError('unsupported protovol version received: %d' % |
|
|
|
|
|
rempv) |
|
|
|
|
|
|
|
|
async def decses(): |
|
|
async def decses(): |
|
|
try: |
|
|
try: |
|
@@ -1125,6 +1130,116 @@ class TestNoiseFowarder(unittest.TestCase): |
|
|
await lsock.wait_closed() |
|
|
await lsock.wait_closed() |
|
|
await ssock.wait_closed() |
|
|
await ssock.wait_closed() |
|
|
|
|
|
|
|
|
|
|
|
@async_test |
|
|
|
|
|
async def test_protocolversionmismatch(self): |
|
|
|
|
|
# make sure that if we send a future version, that we |
|
|
|
|
|
# still get a protocol version, and that the connection |
|
|
|
|
|
# is closed w/o establishing a connection to the remote |
|
|
|
|
|
# side |
|
|
|
|
|
|
|
|
|
|
|
# Test is plumbed: |
|
|
|
|
|
# (reader, writer) -> servsock -> |
|
|
|
|
|
# (rdr, wrr) NoiseForward (reader, writer) -> |
|
|
|
|
|
# servptsock -> (ptsock[0], ptsock[1]) |
|
|
|
|
|
# Path that the server will sit on |
|
|
|
|
|
servsockpath = os.path.join(self.tempdir, 'servsock') |
|
|
|
|
|
servarg = _makeunix(servsockpath) |
|
|
|
|
|
|
|
|
|
|
|
# Path that the server will send pt data to |
|
|
|
|
|
servptpath = os.path.join(self.tempdir, 'servptsock') |
|
|
|
|
|
|
|
|
|
|
|
# Setup pt target listener |
|
|
|
|
|
pttarg = _makeunix(servptpath) |
|
|
|
|
|
ptsock = [] |
|
|
|
|
|
ptsockevent = asyncio.Event() |
|
|
|
|
|
def ptsockaccept(reader, writer, ptsock=ptsock): |
|
|
|
|
|
ptsock.append((reader, writer)) |
|
|
|
|
|
ptsockevent.set() |
|
|
|
|
|
|
|
|
|
|
|
# Bind to pt listener |
|
|
|
|
|
lsock = await listensockstr(pttarg, ptsockaccept) |
|
|
|
|
|
|
|
|
|
|
|
nfs = [] |
|
|
|
|
|
event = asyncio.Event() |
|
|
|
|
|
|
|
|
|
|
|
async def runnf(rdr, wrr): |
|
|
|
|
|
ptpairfun = asyncio.create_task(connectsockstr(pttarg)) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
a = await NoiseForwarder('resp', |
|
|
|
|
|
_makefut((rdr, wrr)), lambda x: ptpairfun, |
|
|
|
|
|
priv_key=self.server_key_pair[1]) |
|
|
|
|
|
except RuntimeError as e: |
|
|
|
|
|
nfs.append(e) |
|
|
|
|
|
event.set() |
|
|
|
|
|
return |
|
|
|
|
|
|
|
|
|
|
|
nfs.append(a) |
|
|
|
|
|
event.set() |
|
|
|
|
|
|
|
|
|
|
|
# Setup server listener |
|
|
|
|
|
ssock = await listensockstr(servarg, runnf) |
|
|
|
|
|
|
|
|
|
|
|
# Connect to server |
|
|
|
|
|
reader, writer = await connectsockstr(servarg) |
|
|
|
|
|
|
|
|
|
|
|
# Create client |
|
|
|
|
|
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') |
|
|
|
|
|
proto.set_as_initiator() |
|
|
|
|
|
|
|
|
|
|
|
# Setup required keys |
|
|
|
|
|
proto.set_keypair_from_private_bytes(Keypair.STATIC, |
|
|
|
|
|
self.client_key_pair[1]) |
|
|
|
|
|
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, |
|
|
|
|
|
self.server_key_pair[0]) |
|
|
|
|
|
|
|
|
|
|
|
proto.start_handshake() |
|
|
|
|
|
|
|
|
|
|
|
# Send first message |
|
|
|
|
|
message = proto.write_message() |
|
|
|
|
|
self.assertEqual(len(message), _handshakelens[0]) |
|
|
|
|
|
writer.write(message) |
|
|
|
|
|
|
|
|
|
|
|
# Get response |
|
|
|
|
|
respmsg = await reader.readexactly(_handshakelens[1]) |
|
|
|
|
|
proto.read_message(respmsg) |
|
|
|
|
|
|
|
|
|
|
|
# Send final reply |
|
|
|
|
|
message = proto.write_message() |
|
|
|
|
|
writer.write(message) |
|
|
|
|
|
|
|
|
|
|
|
# Make sure handshake has completed |
|
|
|
|
|
self.assertTrue(proto.handshake_finished) |
|
|
|
|
|
|
|
|
|
|
|
# generate the keys for lengths |
|
|
|
|
|
enclenfun, _ = _genciphfun(proto.get_handshake_hash(), |
|
|
|
|
|
b'toresp') |
|
|
|
|
|
_, declenfun = _genciphfun(proto.get_handshake_hash(), |
|
|
|
|
|
b'toinit') |
|
|
|
|
|
|
|
|
|
|
|
pversion = 1 |
|
|
|
|
|
# Send the protocol version string first |
|
|
|
|
|
encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big')) |
|
|
|
|
|
writer.write(enclenfun(encmsg)) |
|
|
|
|
|
writer.write(encmsg) |
|
|
|
|
|
|
|
|
|
|
|
# Read the peer's protocol version |
|
|
|
|
|
|
|
|
|
|
|
# find out how much we need to read |
|
|
|
|
|
encmsg = await reader.readexactly(2 + 16) |
|
|
|
|
|
tlen = declenfun(encmsg) |
|
|
|
|
|
|
|
|
|
|
|
# read the rest of the message |
|
|
|
|
|
rencmsg = await reader.readexactly(tlen - 16) |
|
|
|
|
|
tmsg = encmsg[2:] + rencmsg |
|
|
|
|
|
rptmsg = proto.decrypt(tmsg) |
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0) |
|
|
|
|
|
|
|
|
|
|
|
await event.wait() |
|
|
|
|
|
|
|
|
|
|
|
self.assertIsInstance(nfs[0], RuntimeError) |
|
|
|
|
|
|
|
|
@async_test |
|
|
@async_test |
|
|
async def test_serverclient(self): |
|
|
async def test_serverclient(self): |
|
|
# plumbing: |
|
|
# plumbing: |
|
|