| @@ -317,6 +317,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): | |||||
| rmsg = await rdr.readexactly(tlen - 16) | rmsg = await rdr.readexactly(tlen - 16) | ||||
| tmsg = msg[2:] + rmsg | tmsg = msg[2:] + rmsg | ||||
| rpv = proto.decrypt(tmsg) | rpv = proto.decrypt(tmsg) | ||||
| rempv = int.from_bytes(rpv, byteorder='big') | |||||
| if rempv != protocol_version: | |||||
| raise RuntimeError('unsupported protovol version received: %d' % | |||||
| rempv) | |||||
| async def decses(): | async def decses(): | ||||
| try: | try: | ||||
| @@ -1125,6 +1130,116 @@ class TestNoiseFowarder(unittest.TestCase): | |||||
| await lsock.wait_closed() | await lsock.wait_closed() | ||||
| await ssock.wait_closed() | await ssock.wait_closed() | ||||
| @async_test | |||||
| async def test_protocolversionmismatch(self): | |||||
| # make sure that if we send a future version, that we | |||||
| # still get a protocol version, and that the connection | |||||
| # is closed w/o establishing a connection to the remote | |||||
| # side | |||||
| # Test is plumbed: | |||||
| # (reader, writer) -> servsock -> | |||||
| # (rdr, wrr) NoiseForward (reader, writer) -> | |||||
| # servptsock -> (ptsock[0], ptsock[1]) | |||||
| # Path that the server will sit on | |||||
| servsockpath = os.path.join(self.tempdir, 'servsock') | |||||
| servarg = _makeunix(servsockpath) | |||||
| # Path that the server will send pt data to | |||||
| servptpath = os.path.join(self.tempdir, 'servptsock') | |||||
| # Setup pt target listener | |||||
| pttarg = _makeunix(servptpath) | |||||
| ptsock = [] | |||||
| ptsockevent = asyncio.Event() | |||||
| def ptsockaccept(reader, writer, ptsock=ptsock): | |||||
| ptsock.append((reader, writer)) | |||||
| ptsockevent.set() | |||||
| # Bind to pt listener | |||||
| lsock = await listensockstr(pttarg, ptsockaccept) | |||||
| nfs = [] | |||||
| event = asyncio.Event() | |||||
| async def runnf(rdr, wrr): | |||||
| ptpairfun = asyncio.create_task(connectsockstr(pttarg)) | |||||
| try: | |||||
| a = await NoiseForwarder('resp', | |||||
| _makefut((rdr, wrr)), lambda x: ptpairfun, | |||||
| priv_key=self.server_key_pair[1]) | |||||
| except RuntimeError as e: | |||||
| nfs.append(e) | |||||
| event.set() | |||||
| return | |||||
| nfs.append(a) | |||||
| event.set() | |||||
| # Setup server listener | |||||
| ssock = await listensockstr(servarg, runnf) | |||||
| # Connect to server | |||||
| reader, writer = await connectsockstr(servarg) | |||||
| # Create client | |||||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | |||||
| proto.set_as_initiator() | |||||
| # Setup required keys | |||||
| proto.set_keypair_from_private_bytes(Keypair.STATIC, | |||||
| self.client_key_pair[1]) | |||||
| proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, | |||||
| self.server_key_pair[0]) | |||||
| proto.start_handshake() | |||||
| # Send first message | |||||
| message = proto.write_message() | |||||
| self.assertEqual(len(message), _handshakelens[0]) | |||||
| writer.write(message) | |||||
| # Get response | |||||
| respmsg = await reader.readexactly(_handshakelens[1]) | |||||
| proto.read_message(respmsg) | |||||
| # Send final reply | |||||
| message = proto.write_message() | |||||
| writer.write(message) | |||||
| # Make sure handshake has completed | |||||
| self.assertTrue(proto.handshake_finished) | |||||
| # generate the keys for lengths | |||||
| enclenfun, _ = _genciphfun(proto.get_handshake_hash(), | |||||
| b'toresp') | |||||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), | |||||
| b'toinit') | |||||
| pversion = 1 | |||||
| # Send the protocol version string first | |||||
| encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big')) | |||||
| writer.write(enclenfun(encmsg)) | |||||
| writer.write(encmsg) | |||||
| # Read the peer's protocol version | |||||
| # find out how much we need to read | |||||
| encmsg = await reader.readexactly(2 + 16) | |||||
| tlen = declenfun(encmsg) | |||||
| # read the rest of the message | |||||
| rencmsg = await reader.readexactly(tlen - 16) | |||||
| tmsg = encmsg[2:] + rencmsg | |||||
| rptmsg = proto.decrypt(tmsg) | |||||
| self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0) | |||||
| await event.wait() | |||||
| self.assertIsInstance(nfs[0], RuntimeError) | |||||
| @async_test | @async_test | ||||
| async def test_serverclient(self): | async def test_serverclient(self): | ||||
| # plumbing: | # plumbing: | ||||