| @@ -241,6 +241,15 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): | |||||
| the pair that is returned by genkeypair. | the pair that is returned by genkeypair. | ||||
| ''' | ''' | ||||
| # Send a protocol version so that in the future we can change how | |||||
| # we interface, and possibly be able to send control messages, | |||||
| # allow the client to pass some misc data to the callback, or to | |||||
| # allow a reverse tunnel, were the client talks to the server, | |||||
| # and waits for the server to "connect" to the client w/ a | |||||
| # connection, e.g. reverse tunnel out behind a nat to allow | |||||
| # incoming connections. | |||||
| protocol_version = 0 | |||||
| rdr, wrr = await encrdrwrr | rdr, wrr = await encrdrwrr | ||||
| proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') | ||||
| @@ -294,6 +303,21 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None): | |||||
| enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') | enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp') | ||||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') | _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit') | ||||
| # protocol negotiation | |||||
| # send first, then wait for the response | |||||
| pvmsg = protocol_version.to_bytes(1, byteorder='big') | |||||
| encmsg = proto.encrypt(pvmsg) | |||||
| wrr.write(enclenfun(encmsg)) | |||||
| wrr.write(encmsg) | |||||
| # get the protocol version | |||||
| msg = await rdr.readexactly(2 + 16) | |||||
| tlen = declenfun(msg) | |||||
| rmsg = await rdr.readexactly(tlen - 16) | |||||
| tmsg = msg[2:] + rmsg | |||||
| rpv = proto.decrypt(tmsg) | |||||
| async def decses(): | async def decses(): | ||||
| try: | try: | ||||
| while True: | while True: | ||||
| @@ -1015,6 +1039,25 @@ class TestNoiseFowarder(unittest.TestCase): | |||||
| _, declenfun = _genciphfun(proto.get_handshake_hash(), | _, declenfun = _genciphfun(proto.get_handshake_hash(), | ||||
| b'toinit') | b'toinit') | ||||
| pversion = 0 | |||||
| # Send the protocol version string first | |||||
| encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big')) | |||||
| writer.write(enclenfun(encmsg)) | |||||
| writer.write(encmsg) | |||||
| # Read the peer's protocol version | |||||
| # find out how much we need to read | |||||
| encmsg = await reader.readexactly(2 + 16) | |||||
| tlen = declenfun(encmsg) | |||||
| # read the rest of the message | |||||
| rencmsg = await reader.readexactly(tlen - 16) | |||||
| tmsg = encmsg[2:] + rencmsg | |||||
| rptmsg = proto.decrypt(tmsg) | |||||
| self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion) | |||||
| # write a test message | # write a test message | ||||
| ptmsg = b'this is a test message that should be a little in length' | ptmsg = b'this is a test message that should be a little in length' | ||||
| encmsg = proto.encrypt(ptmsg) | encmsg = proto.encrypt(ptmsg) | ||||