|
|
@@ -334,13 +334,17 @@ def cmd_client(args): |
|
|
|
|
|
|
|
def cmd_server(args): |
|
|
|
privkey = loadprivkeyraw(args.servkey) |
|
|
|
pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ] |
|
|
|
|
|
|
|
async def runnf(rdr, wrr): |
|
|
|
ptpairfun = asyncio.create_task(connectsockstr(args.servtarget)) |
|
|
|
async def checkclientfun(clientkey): |
|
|
|
if clientkey not in pubkeys: |
|
|
|
raise RuntimeError('invalid key provided') |
|
|
|
|
|
|
|
a = await NoiseForwarder('resp', |
|
|
|
_makefut((rdr, wrr)), lambda x: ptpairfun, |
|
|
|
priv_key=privkey) |
|
|
|
return await connectsockstr(args.servtarget) |
|
|
|
|
|
|
|
a = await NoiseForwarder('resp', _makefut((rdr, wrr)), |
|
|
|
checkclientfun, priv_key=privkey) |
|
|
|
|
|
|
|
# Setup server listener |
|
|
|
ssock = listensockstr(args.servlisten, runnf) |
|
|
@@ -385,7 +389,7 @@ def main(): |
|
|
|
parser_gk.set_defaults(func=cmd_genkey) |
|
|
|
|
|
|
|
parser_serv = subparsers.add_parser('server', help='run a server') |
|
|
|
parser_serv.add_argument('-c', action='append', type=str, help='file of authorized client keys, or a .pub file') |
|
|
|
parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file') |
|
|
|
parser_serv.add_argument('servkey', type=str, help='file name for the server key') |
|
|
|
parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on') |
|
|
|
parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to') |
|
|
@@ -501,6 +505,81 @@ class TestMain(unittest.TestCase): |
|
|
|
|
|
|
|
self.assertEqual(rprivrawkey, privrawkey) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_clientkeymismatch(self): |
|
|
|
# make sure that if there's a client key mismatch, we |
|
|
|
# don't connect |
|
|
|
|
|
|
|
# Generate necessar keys |
|
|
|
servkeypath = os.path.join(self.tempdir, 'server_key') |
|
|
|
await self.genkey(servkeypath) |
|
|
|
clientkeypath = os.path.join(self.tempdir, 'client_key') |
|
|
|
await self.genkey(clientkeypath) |
|
|
|
badclientkeypath = os.path.join(self.tempdir, 'badclient_key') |
|
|
|
await self.genkey(badclientkeypath) |
|
|
|
|
|
|
|
await asyncio.sleep(.1) |
|
|
|
|
|
|
|
# forwards connectsion to this socket (created by client) |
|
|
|
ptclientpath = os.path.join(self.tempdir, 'incclient.sock') |
|
|
|
ptclientstr = _makeunix(ptclientpath) |
|
|
|
|
|
|
|
# this is the socket server listen to |
|
|
|
incservpath = os.path.join(self.tempdir, 'incserv.sock') |
|
|
|
incservstr = _makeunix(incservpath) |
|
|
|
|
|
|
|
# to this socket, opened by server |
|
|
|
servtargpath = os.path.join(self.tempdir, 'servtarget.sock') |
|
|
|
servtargstr = _makeunix(servtargpath) |
|
|
|
|
|
|
|
# Setup server target listener |
|
|
|
ptsock = [] |
|
|
|
ptsockevent = asyncio.Event() |
|
|
|
def ptsockaccept(reader, writer, ptsock=ptsock): |
|
|
|
ptsock.append((reader, writer)) |
|
|
|
ptsockevent.set() |
|
|
|
|
|
|
|
# Bind to pt listener |
|
|
|
lsock = await listensockstr(servtargstr, ptsockaccept) |
|
|
|
|
|
|
|
# Startup the server |
|
|
|
server = await self.run_with_args('server', |
|
|
|
'-c', clientkeypath + '.pub', |
|
|
|
servkeypath, incservstr, servtargstr) |
|
|
|
|
|
|
|
# Startup the client with the "bad" key |
|
|
|
client = await self.run_with_args('client', |
|
|
|
badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr) |
|
|
|
|
|
|
|
# wait for server target to be created |
|
|
|
await _awaitfile(servtargpath) |
|
|
|
|
|
|
|
# wait for server to start |
|
|
|
await _awaitfile(incservpath) |
|
|
|
|
|
|
|
# wait for client to start |
|
|
|
await _awaitfile(ptclientpath) |
|
|
|
|
|
|
|
await asyncio.sleep(.1) |
|
|
|
|
|
|
|
# Connect to the client |
|
|
|
reader, writer = await connectsockstr(ptclientstr) |
|
|
|
|
|
|
|
with self.assertRaises(asyncio.futures.TimeoutError): |
|
|
|
# make sure that we don't get the conenction |
|
|
|
await asyncio.wait_for(ptsockevent.wait(), 1) |
|
|
|
|
|
|
|
# Make sure that when the server is terminated |
|
|
|
server.terminate() |
|
|
|
|
|
|
|
# that it's stderr |
|
|
|
stdout, stderr = await server.communicate() |
|
|
|
#print('s:', repr((stdout, stderr))) |
|
|
|
|
|
|
|
# doesn't have an exceptions never retrieved |
|
|
|
# even the example echo server has this same leak |
|
|
|
#self.assertNotIn(b'Task exception was never retrieved', stderr) |
|
|
|
|
|
|
|
@async_test |
|
|
|
async def test_end2end(self): |
|
|
|
# Generate necessar keys |
|
|
|