diff --git a/ntunnel.py b/ntunnel.py index c967d68..6ef35e4 100644 --- a/ntunnel.py +++ b/ntunnel.py @@ -317,6 +317,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): rmsg = await rdr.readexactly(tlen - 16) tmsg = msg[2:] + rmsg rpv = proto.decrypt(tmsg) + rempv = int.from_bytes(rpv, byteorder='big') + + if rempv != protocol_version: + raise RuntimeError('unsupported protovol version received: %d' % + rempv) async def decses(): try: @@ -1125,6 +1130,116 @@ class TestNoiseFowarder(unittest.TestCase): await lsock.wait_closed() await ssock.wait_closed() + @async_test + async def test_protocolversionmismatch(self): + # make sure that if we send a future version, that we + # still get a protocol version, and that the connection + # is closed w/o establishing a connection to the remote + # side + + # Test is plumbed: + # (reader, writer) -> servsock -> + # (rdr, wrr) NoiseForward (reader, writer) -> + # servptsock -> (ptsock[0], ptsock[1]) + # Path that the server will sit on + servsockpath = os.path.join(self.tempdir, 'servsock') + servarg = _makeunix(servsockpath) + + # Path that the server will send pt data to + servptpath = os.path.join(self.tempdir, 'servptsock') + + # Setup pt target listener + pttarg = _makeunix(servptpath) + ptsock = [] + ptsockevent = asyncio.Event() + def ptsockaccept(reader, writer, ptsock=ptsock): + ptsock.append((reader, writer)) + ptsockevent.set() + + # Bind to pt listener + lsock = await listensockstr(pttarg, ptsockaccept) + + nfs = [] + event = asyncio.Event() + + async def runnf(rdr, wrr): + ptpairfun = asyncio.create_task(connectsockstr(pttarg)) + + try: + a = await NoiseForwarder('resp', + _makefut((rdr, wrr)), lambda x: ptpairfun, + priv_key=self.server_key_pair[1]) + except RuntimeError as e: + nfs.append(e) + event.set() + return + + nfs.append(a) + event.set() + + # Setup server listener + ssock = await listensockstr(servarg, runnf) + + # Connect to server + reader, writer = await connectsockstr(servarg) + + # Create client + proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') + proto.set_as_initiator() + + # Setup required keys + proto.set_keypair_from_private_bytes(Keypair.STATIC, + self.client_key_pair[1]) + proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, + self.server_key_pair[0]) + + proto.start_handshake() + + # Send first message + message = proto.write_message() + self.assertEqual(len(message), _handshakelens[0]) + writer.write(message) + + # Get response + respmsg = await reader.readexactly(_handshakelens[1]) + proto.read_message(respmsg) + + # Send final reply + message = proto.write_message() + writer.write(message) + + # Make sure handshake has completed + self.assertTrue(proto.handshake_finished) + + # generate the keys for lengths + enclenfun, _ = _genciphfun(proto.get_handshake_hash(), + b'toresp') + _, declenfun = _genciphfun(proto.get_handshake_hash(), + b'toinit') + + pversion = 1 + # Send the protocol version string first + encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big')) + writer.write(enclenfun(encmsg)) + writer.write(encmsg) + + # Read the peer's protocol version + + # find out how much we need to read + encmsg = await reader.readexactly(2 + 16) + tlen = declenfun(encmsg) + + # read the rest of the message + rencmsg = await reader.readexactly(tlen - 16) + tmsg = encmsg[2:] + rencmsg + rptmsg = proto.decrypt(tmsg) + + self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0) + + await event.wait() + + self.assertIsInstance(nfs[0], RuntimeError) + @async_test async def test_serverclient(self): # plumbing: