Browse Source

add test to make sure that the server checks the client's key and

only allows the key that was specified
tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
ef95a6f276
1 changed files with 84 additions and 5 deletions
  1. +84
    -5
      ntunnel.py

+ 84
- 5
ntunnel.py View File

@@ -334,13 +334,17 @@ def cmd_client(args):

def cmd_server(args):
privkey = loadprivkeyraw(args.servkey)
pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]

async def runnf(rdr, wrr):
ptpairfun = asyncio.create_task(connectsockstr(args.servtarget))
async def checkclientfun(clientkey):
if clientkey not in pubkeys:
raise RuntimeError('invalid key provided')

a = await NoiseForwarder('resp',
_makefut((rdr, wrr)), lambda x: ptpairfun,
priv_key=privkey)
return await connectsockstr(args.servtarget)

a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
checkclientfun, priv_key=privkey)

# Setup server listener
ssock = listensockstr(args.servlisten, runnf)
@@ -385,7 +389,7 @@ def main():
parser_gk.set_defaults(func=cmd_genkey)

parser_serv = subparsers.add_parser('server', help='run a server')
parser_serv.add_argument('-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
parser_serv.add_argument('servkey', type=str, help='file name for the server key')
parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
@@ -501,6 +505,81 @@ class TestMain(unittest.TestCase):

self.assertEqual(rprivrawkey, privrawkey)

@async_test
async def test_clientkeymismatch(self):
# make sure that if there's a client key mismatch, we
# don't connect

# Generate necessar keys
servkeypath = os.path.join(self.tempdir, 'server_key')
await self.genkey(servkeypath)
clientkeypath = os.path.join(self.tempdir, 'client_key')
await self.genkey(clientkeypath)
badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
await self.genkey(badclientkeypath)

await asyncio.sleep(.1)

# forwards connectsion to this socket (created by client)
ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
ptclientstr = _makeunix(ptclientpath)

# this is the socket server listen to
incservpath = os.path.join(self.tempdir, 'incserv.sock')
incservstr = _makeunix(incservpath)

# to this socket, opened by server
servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
servtargstr = _makeunix(servtargpath)

# Setup server target listener
ptsock = []
ptsockevent = asyncio.Event()
def ptsockaccept(reader, writer, ptsock=ptsock):
ptsock.append((reader, writer))
ptsockevent.set()

# Bind to pt listener
lsock = await listensockstr(servtargstr, ptsockaccept)

# Startup the server
server = await self.run_with_args('server',
'-c', clientkeypath + '.pub',
servkeypath, incservstr, servtargstr)

# Startup the client with the "bad" key
client = await self.run_with_args('client',
badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)

# wait for server target to be created
await _awaitfile(servtargpath)

# wait for server to start
await _awaitfile(incservpath)

# wait for client to start
await _awaitfile(ptclientpath)

await asyncio.sleep(.1)

# Connect to the client
reader, writer = await connectsockstr(ptclientstr)

with self.assertRaises(asyncio.futures.TimeoutError):
# make sure that we don't get the conenction
await asyncio.wait_for(ptsockevent.wait(), 1)

# Make sure that when the server is terminated
server.terminate()

# that it's stderr
stdout, stderr = await server.communicate()
#print('s:', repr((stdout, stderr)))

# doesn't have an exceptions never retrieved
# even the example echo server has this same leak
#self.assertNotIn(b'Task exception was never retrieved', stderr)

@async_test
async def test_end2end(self):
# Generate necessar keys


Loading…
Cancel
Save