Browse Source

add protocol version checking, so we fail if the version mismatches...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
e79d8273f4
1 changed files with 115 additions and 0 deletions
  1. +115
    -0
      ntunnel.py

+ 115
- 0
ntunnel.py View File

@@ -317,6 +317,11 @@ async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
rmsg = await rdr.readexactly(tlen - 16) rmsg = await rdr.readexactly(tlen - 16)
tmsg = msg[2:] + rmsg tmsg = msg[2:] + rmsg
rpv = proto.decrypt(tmsg) rpv = proto.decrypt(tmsg)
rempv = int.from_bytes(rpv, byteorder='big')

if rempv != protocol_version:
raise RuntimeError('unsupported protovol version received: %d' %
rempv)


async def decses(): async def decses():
try: try:
@@ -1125,6 +1130,116 @@ class TestNoiseFowarder(unittest.TestCase):
await lsock.wait_closed() await lsock.wait_closed()
await ssock.wait_closed() await ssock.wait_closed()


@async_test
async def test_protocolversionmismatch(self):
# make sure that if we send a future version, that we
# still get a protocol version, and that the connection
# is closed w/o establishing a connection to the remote
# side

# Test is plumbed:
# (reader, writer) -> servsock ->
# (rdr, wrr) NoiseForward (reader, writer) ->
# servptsock -> (ptsock[0], ptsock[1])
# Path that the server will sit on
servsockpath = os.path.join(self.tempdir, 'servsock')
servarg = _makeunix(servsockpath)

# Path that the server will send pt data to
servptpath = os.path.join(self.tempdir, 'servptsock')

# Setup pt target listener
pttarg = _makeunix(servptpath)
ptsock = []
ptsockevent = asyncio.Event()
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))
ptsockevent.set()

# Bind to pt listener
lsock = await listensockstr(pttarg, ptsockaccept)

nfs = []
event = asyncio.Event()

async def runnf(rdr, wrr):
ptpairfun = asyncio.create_task(connectsockstr(pttarg))

try:
a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), lambda x: ptpairfun,
priv_key=self.server_key_pair[1])
except RuntimeError as e:
nfs.append(e)
event.set()
return

nfs.append(a)
event.set()

# Setup server listener
ssock = await listensockstr(servarg, runnf)

# Connect to server
reader, writer = await connectsockstr(servarg)

# Create client
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_as_initiator()

# Setup required keys
proto.set_keypair_from_private_bytes(Keypair.STATIC,
self.client_key_pair[1])
proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
self.server_key_pair[0])

proto.start_handshake()

# Send first message
message = proto.write_message()
self.assertEqual(len(message), _handshakelens[0])
writer.write(message)

# Get response
respmsg = await reader.readexactly(_handshakelens[1])
proto.read_message(respmsg)

# Send final reply
message = proto.write_message()
writer.write(message)

# Make sure handshake has completed
self.assertTrue(proto.handshake_finished)

# generate the keys for lengths
enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
b'toresp')
_, declenfun = _genciphfun(proto.get_handshake_hash(),
b'toinit')

pversion = 1
# Send the protocol version string first
encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
writer.write(enclenfun(encmsg))
writer.write(encmsg)

# Read the peer's protocol version

# find out how much we need to read
encmsg = await reader.readexactly(2 + 16)
tlen = declenfun(encmsg)

# read the rest of the message
rencmsg = await reader.readexactly(tlen - 16)
tmsg = encmsg[2:] + rencmsg
rptmsg = proto.decrypt(tmsg)

self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)

await event.wait()

self.assertIsInstance(nfs[0], RuntimeError)

@async_test @async_test
async def test_serverclient(self): async def test_serverclient(self):
# plumbing: # plumbing:


Loading…
Cancel
Save