An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

1329 lines
36 KiB

#import tracemalloc; tracemalloc.start(100)
import argparse
import asyncio
import base64
import errno
import functools
import os.path
import shutil
import socket
import sys
import tempfile
import threading
import time
import unittest

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import x448
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from noise.connection import NoiseConnection, Keypair
# Shared cryptography backend handed to the HKDF and Cipher objects
# created in _genciphfun.
_backend = default_backend()
  22. def loadprivkey(fname):
  23. with open(fname, encoding='ascii') as fp:
  24. data = fp.read().encode('ascii')
  25. key = load_pem_private_key(data, password=None, backend=default_backend())
  26. return key
  27. def loadprivkeyraw(fname):
  28. key = loadprivkey(fname)
  29. enc = serialization.Encoding.Raw
  30. privformat = serialization.PrivateFormat.Raw
  31. encalgo = serialization.NoEncryption()
  32. return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  33. def loadpubkeyraw(fname):
  34. with open(fname, encoding='ascii') as fp:
  35. lines = fp.readlines()
  36. # XXX
  37. #self.assertEqual(len(lines), 1)
  38. keytype, keyvalue = lines[0].split()
  39. if keytype != 'ntun-x448':
  40. raise RuntimeError
  41. return base64.urlsafe_b64decode(keyvalue)
  42. def genkeypair():
  43. '''Generates a keypair, and returns a tuple of (public, private).
  44. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  45. key = x448.X448PrivateKey.generate()
  46. enc = serialization.Encoding.Raw
  47. pubformat = serialization.PublicFormat.Raw
  48. privformat = serialization.PrivateFormat.Raw
  49. encalgo = serialization.NoEncryption()
  50. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  51. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  52. return pub, priv
  53. def _makefut(obj):
  54. loop = asyncio.get_running_loop()
  55. fut = loop.create_future()
  56. fut.set_result(obj)
  57. return fut
  58. def _makeunix(path):
  59. '''Make a properly formed unix path socket string.'''
  60. return 'unix:%s' % path
  61. # Make sure any additions are reflected by tests in test_parsesockstr
  62. _allowedparameters = {
  63. 'unix': {
  64. 'path': str,
  65. },
  66. 'tcp': {
  67. 'host': str,
  68. 'port': int,
  69. },
  70. }
  71. def parsesockstr(sockstr):
  72. '''Parse a socket string to its parts.
  73. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  74. If the proto has a default parameter, the value can be used
  75. directly, like: 'proto:value'. This is only allowed when the
  76. value can unambiguously be determined not to be a param. If
  77. there needs to be an equals '=', then you MUST use the extended
  78. version.
  79. The characters that define 'param' must be all lower case ascii
  80. characters and may contain an underscore. The first character
  81. must not be an underscore.
  82. Supported protocols:
  83. unix:
  84. Default parameter is path.
  85. The path parameter specifies the path to the
  86. unix domain socket. The path MUST start w/ a
  87. slash if it is used as a default parameter.
  88. tcp:
  89. Default parameter is host[:port].
  90. The host parameter specifies the host, and the
  91. port parameter specifies the port of the
  92. connection.
  93. '''
  94. proto, rem = sockstr.split(':', 1)
  95. if '=' not in rem:
  96. if proto == 'unix' and rem[0] != '/':
  97. raise ValueError('bare path MUST start w/ a slash (/).')
  98. if proto == 'unix':
  99. args = { 'path': rem }
  100. else:
  101. args = dict(i.split('=', 1) for i in rem.split(','))
  102. try:
  103. allowed = _allowedparameters[proto]
  104. except KeyError:
  105. raise ValueError('unsupported proto: %s' % repr(proto))
  106. extrakeys = args.keys() - allowed.keys()
  107. if extrakeys:
  108. raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))
  109. for i in args:
  110. args[i] = allowed[i](args[i])
  111. return proto, args
  112. async def connectsockstr(sockstr):
  113. '''Wrapper for asyncio.open_*_connection.'''
  114. proto, args = parsesockstr(sockstr)
  115. if proto == 'unix':
  116. fun = asyncio.open_unix_connection
  117. elif proto == 'tcp':
  118. fun = asyncio.open_connection
  119. reader, writer = await fun(**args)
  120. return reader, writer
  121. async def listensockstr(sockstr, cb):
  122. '''Wrapper for asyncio.start_x_server.
  123. For the format of sockstr, please see parsesockstr.
  124. The cb parameter is passed to asyncio's start_server or related
  125. calls. Per those docs, the cb parameter is calls or scheduled
  126. as a task when a client establishes a connection. It is called
  127. with two arguments, the reader and writer streams. For more
  128. information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
  129. '''
  130. proto, args = parsesockstr(sockstr)
  131. if proto == 'unix':
  132. fun = asyncio.start_unix_server
  133. elif proto == 'tcp':
  134. fun = asyncio.start_server
  135. return await fun(cb, **args)
# Expected byte lengths of the three Noise_XK handshake messages, in the
# order they are sent: initiator -> responder, responder -> initiator,
# initiator -> responder.  NoiseForwarder reads exactly this many bytes
# for each handshake message it receives.
# NOTE(review): the marker below looks like a vim "!!" filter command
# used to regenerate the list line; presumably that is why the list sits
# on its own line after a backslash continuation - keep that layout.
# !!python makemessagelengths.py
_handshakelens = \
[72, 72, 88]
  139. def _genciphfun(hash, ad):
  140. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  141. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  142. key = hkdf.derive(hash)
  143. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  144. backend=_backend)
  145. enctor = cipher.encryptor()
  146. def encfun(data):
  147. # Returns the two bytes for length
  148. val = len(data)
  149. encbytes = enctor.update(data[:16])
  150. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  151. return (val ^ mask).to_bytes(length=2, byteorder='big')
  152. def decfun(data):
  153. # takes off the data and returns the total
  154. # length
  155. val = int.from_bytes(data[:2], byteorder='big')
  156. encbytes = enctor.update(data[2:2 + 16])
  157. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  158. return val ^ mask
  159. return encfun, decfun
  160. async def NoiseForwarder(mode, encrdrwrr, ptpairfun, priv_key, pub_key=None):
  161. '''A function that forwards data between the plain text pair of
  162. streams to the encrypted session.
  163. The mode paramater must be one of 'init' or 'resp' for initiator
  164. and responder.
  165. The encrdrwrr is an await object that will return a tunle of the
  166. reader and writer streams for the encrypted side of the
  167. connection.
  168. The ptpairfun parameter is a function that will be passed the
  169. public key bytes for the remote client. This can be used to
  170. both validate that the correct client is connecting, and to
  171. pass back the correct plain text reader/writer objects that
  172. match the provided static key. The function must be an async
  173. function.
  174. In the case of the initiator, pub_key must be provided and will
  175. be used to authenticate the responder side of the connection.
  176. The priv_key parameter is used to authenticate this side of the
  177. session.
  178. Both priv_key and pub_key parameters must be 56 bytes. For example,
  179. the pair that is returned by genkeypair.
  180. '''
  181. # Send a protocol version so that in the future we can change how
  182. # we interface, and possibly be able to send control messages,
  183. # allow the client to pass some misc data to the callback, or to
  184. # allow a reverse tunnel, were the client talks to the server,
  185. # and waits for the server to "connect" to the client w/ a
  186. # connection, e.g. reverse tunnel out behind a nat to allow
  187. # incoming connections.
  188. protocol_version = 0
  189. rdr, wrr = await encrdrwrr
  190. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  191. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  192. if pub_key is not None:
  193. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  194. pub_key)
  195. if mode == 'resp':
  196. proto.set_as_responder()
  197. proto.start_handshake()
  198. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  199. wrr.write(proto.write_message())
  200. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  201. elif mode == 'init':
  202. proto.set_as_initiator()
  203. proto.start_handshake()
  204. wrr.write(proto.write_message())
  205. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  206. wrr.write(proto.write_message())
  207. if not proto.handshake_finished: # pragma: no cover
  208. raise RuntimeError('failed to finish handshake')
  209. try:
  210. reader, writer = await ptpairfun(getattr(proto.get_keypair(
  211. Keypair.REMOTE_STATIC), 'public_bytes', None))
  212. except:
  213. wrr.close()
  214. raise
  215. # generate the keys for lengths
  216. # XXX - get_handshake_hash is probably not the best option, but
  217. # this is only to obscure lengths, it is not required to be secure
  218. # as the underlying NoiseProtocol securely validates everything.
  219. # It is marginally useful as writing patterns likely expose the
  220. # true length. Adding padding could marginally help w/ this.
  221. if mode == 'resp':
  222. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  223. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  224. elif mode == 'init':
  225. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  226. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  227. # protocol negotiation
  228. # send first, then wait for the response
  229. pvmsg = protocol_version.to_bytes(1, byteorder='big')
  230. encmsg = proto.encrypt(pvmsg)
  231. wrr.write(enclenfun(encmsg))
  232. wrr.write(encmsg)
  233. # get the protocol version
  234. msg = await rdr.readexactly(2 + 16)
  235. tlen = declenfun(msg)
  236. rmsg = await rdr.readexactly(tlen - 16)
  237. tmsg = msg[2:] + rmsg
  238. rpv = proto.decrypt(tmsg)
  239. rempv = int.from_bytes(rpv, byteorder='big')
  240. if rempv != protocol_version:
  241. raise RuntimeError('unsupported protovol version received: %d' %
  242. rempv)
  243. async def decses():
  244. try:
  245. while True:
  246. try:
  247. msg = await rdr.readexactly(2 + 16)
  248. except asyncio.streams.IncompleteReadError:
  249. if rdr.at_eof():
  250. return 'dec'
  251. tlen = declenfun(msg)
  252. rmsg = await rdr.readexactly(tlen - 16)
  253. tmsg = msg[2:] + rmsg
  254. writer.write(proto.decrypt(tmsg))
  255. await writer.drain()
  256. #except:
  257. # import traceback
  258. # traceback.print_exc()
  259. # raise
  260. finally:
  261. try:
  262. writer.write_eof()
  263. except OSError as e:
  264. if e.errno != 57:
  265. raise
  266. async def encses():
  267. try:
  268. while True:
  269. # largest message
  270. ptmsg = await reader.read(65535 - 16)
  271. if not ptmsg:
  272. # eof
  273. return 'enc'
  274. encmsg = proto.encrypt(ptmsg)
  275. wrr.write(enclenfun(encmsg))
  276. wrr.write(encmsg)
  277. await wrr.drain()
  278. #except:
  279. # import traceback
  280. # traceback.print_exc()
  281. # raise
  282. finally:
  283. wrr.write_eof()
  284. res = await asyncio.gather(decses(), encses())
  285. await wrr.drain() # not sure if needed
  286. wrr.close()
  287. await writer.drain() # not sure if needed
  288. writer.close()
  289. return res
  290. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  291. # Slightly modified to timeout and to print trace back when canceled.
  292. # This makes it easier to figure out what "froze".
  293. def async_test(f):
  294. def wrapper(*args, **kwargs):
  295. async def tbcapture():
  296. try:
  297. return await f(*args, **kwargs)
  298. except asyncio.CancelledError as e:
  299. # if we are going to be cancelled, print out a tb
  300. import traceback
  301. traceback.print_exc()
  302. raise
  303. loop = asyncio.get_event_loop()
  304. # timeout after 4 seconds
  305. loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
  306. return wrapper
  307. class Tests_misc(unittest.TestCase):
  308. def setUp(self):
  309. # setup temporary directory
  310. d = os.path.realpath(tempfile.mkdtemp())
  311. self.basetempdir = d
  312. self.tempdir = os.path.join(d, 'subdir')
  313. os.mkdir(self.tempdir)
  314. os.chdir(self.tempdir)
  315. def tearDown(self):
  316. #print('td:', time.time())
  317. shutil.rmtree(self.basetempdir)
  318. self.tempdir = None
  319. def test_parsesockstr_bad(self):
  320. badstrs = [
  321. 'unix:ff',
  322. 'randomnocolon',
  323. 'unix:somethingelse=bogus',
  324. 'tcp:port=bogus',
  325. ]
  326. for i in badstrs:
  327. with self.assertRaises(ValueError,
  328. msg='Should have failed processing: %s' % repr(i)):
  329. parsesockstr(i)
  330. def test_parsesockstr(self):
  331. results = {
  332. # Not all of these are valid when passed to a *sockstr
  333. # function
  334. 'unix:/apath': ('unix', { 'path': '/apath' }),
  335. 'unix:path=apath': ('unix', { 'path': 'apath' }),
  336. 'tcp:host=apath': ('tcp', { 'host': 'apath' }),
  337. 'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
  338. 'port': 5 }),
  339. }
  340. for s, r in results.items():
  341. self.assertEqual(parsesockstr(s), r)
  342. @async_test
  343. async def test_listensockstr_bad(self):
  344. with self.assertRaises(ValueError):
  345. ls = await listensockstr('bogus:some=arg', None)
  346. with self.assertRaises(ValueError):
  347. ls = await connectsockstr('bogus:some=arg')
  348. @async_test
  349. async def test_listenconnectsockstr(self):
  350. msgsent = b'this is a test message'
  351. msgrcv = b'testing message for receive'
  352. # That when a connection is received and receives and sends
  353. async def servconfhandle(rdr, wrr):
  354. msg = await rdr.readexactly(len(msgsent))
  355. self.assertEqual(msg, msgsent)
  356. #print(repr(wrr.get_extra_info('sockname')))
  357. wrr.write(msgrcv)
  358. await wrr.drain()
  359. wrr.close()
  360. return True
  361. # Test listensockstr
  362. for sstr, confun in [
  363. ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
  364. ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
  365. ]:
  366. # that listensockstr will bind to the correct path, can call cb
  367. ls = await listensockstr(sstr, servconfhandle)
  368. # that we open a connection to the path
  369. rdr, wrr = await confun()
  370. # and send a message
  371. wrr.write(msgsent)
  372. # and receive the message
  373. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  374. self.assertEqual(rcv, msgrcv)
  375. wrr.close()
  376. # Now test that connectsockstr works similarly.
  377. rdr, wrr = await connectsockstr(sstr)
  378. # and send a message
  379. wrr.write(msgsent)
  380. # and receive the message
  381. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  382. self.assertEqual(rcv, msgrcv)
  383. wrr.close()
  384. ls.close()
  385. await ls.wait_closed()
  386. def test_genciphfun(self):
  387. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  388. msg = b'this is a bunch of data'
  389. tb = enc(msg)
  390. self.assertEqual(len(msg), dec(tb + msg))
  391. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  392. msg = os.urandom(i)
  393. self.assertEqual(len(msg), dec(enc(msg) + msg))
  394. def cmd_client(args):
  395. privkey = loadprivkeyraw(args.clientkey)
  396. pubkey = loadpubkeyraw(args.servkey)
  397. async def runnf(rdr, wrr):
  398. encpair = asyncio.create_task(connectsockstr(args.clienttarget))
  399. a = await NoiseForwarder('init',
  400. encpair, lambda x: _makefut((rdr, wrr)),
  401. priv_key=privkey, pub_key=pubkey)
  402. # Setup client listener
  403. ssock = listensockstr(args.clientlisten, runnf)
  404. loop = asyncio.get_event_loop()
  405. obj = loop.run_until_complete(ssock)
  406. loop.run_until_complete(obj.serve_forever())
  407. def cmd_server(args):
  408. privkey = loadprivkeyraw(args.servkey)
  409. pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]
  410. async def runnf(rdr, wrr):
  411. async def checkclientfun(clientkey):
  412. if clientkey not in pubkeys:
  413. raise RuntimeError('invalid key provided')
  414. return await connectsockstr(args.servtarget)
  415. a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
  416. checkclientfun, priv_key=privkey)
  417. # Setup server listener
  418. ssock = listensockstr(args.servlisten, runnf)
  419. loop = asyncio.get_event_loop()
  420. obj = loop.run_until_complete(ssock)
  421. loop.run_until_complete(obj.serve_forever())
  422. def cmd_genkey(args):
  423. keypair = genkeypair()
  424. key = x448.X448PrivateKey.generate()
  425. # public key part
  426. enc = serialization.Encoding.Raw
  427. pubformat = serialization.PublicFormat.Raw
  428. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  429. try:
  430. fname = args.fname + '.pub'
  431. with open(fname, 'x', encoding='ascii') as fp:
  432. print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
  433. except FileExistsError:
  434. print('failed to create %s, file exists.' % fname, file=sys.stderr)
  435. sys.exit(2)
  436. enc = serialization.Encoding.PEM
  437. format = serialization.PrivateFormat.PKCS8
  438. encalgo = serialization.NoEncryption()
  439. with open(args.fname, 'x', encoding='ascii') as fp:
  440. fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii'))
  441. def main():
  442. parser = argparse.ArgumentParser()
  443. subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')
  444. parser_gk = subparsers.add_parser('genkey', help='generate keys')
  445. parser_gk.add_argument('fname', type=str, help='file name for the key')
  446. parser_gk.set_defaults(func=cmd_genkey)
  447. parser_serv = subparsers.add_parser('server', help='run a server')
  448. parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
  449. parser_serv.add_argument('servkey', type=str, help='file name for the server key')
  450. parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
  451. parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
  452. parser_serv.set_defaults(func=cmd_server)
  453. parser_client = subparsers.add_parser('client', help='run a client')
  454. parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
  455. parser_client.add_argument('servkey', type=str, help='file name for the server public key')
  456. parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
  457. parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
  458. parser_client.set_defaults(func=cmd_client)
  459. args = parser.parse_args()
  460. try:
  461. fun = args.func
  462. except AttributeError:
  463. parser.print_usage()
  464. sys.exit(5)
  465. fun(args)
# Run the command line interface when executed as a script.  This runs
# before the helpers and test classes that follow are defined; main()
# and the cmd_* functions it dispatches to do not use them.
if __name__ == '__main__': # pragma: no cover
    main()
  468. def _asyncsockpair():
  469. '''Create a pair of sockets that are bound to each other.
  470. The function will return a tuple of two coroutine's, that
  471. each, when await'ed upon, will return the reader/writer pair.'''
  472. socka, sockb = socket.socketpair()
  473. return asyncio.open_connection(sock=socka), \
  474. asyncio.open_connection(sock=sockb)
  475. async def _awaitfile(fname):
  476. while not os.path.exists(fname):
  477. await asyncio.sleep(.01)
  478. return True
  479. class TestMain(unittest.TestCase):
  480. def setUp(self):
  481. # setup temporary directory
  482. d = os.path.realpath(tempfile.mkdtemp())
  483. self.basetempdir = d
  484. self.tempdir = os.path.join(d, 'subdir')
  485. os.mkdir(self.tempdir)
  486. # Generate key pairs
  487. self.server_key_pair = genkeypair()
  488. self.client_key_pair = genkeypair()
  489. os.chdir(self.tempdir)
  490. def tearDown(self):
  491. #print('td:', time.time())
  492. shutil.rmtree(self.basetempdir)
  493. self.tempdir = None
  494. @async_test
  495. async def test_noargs(self):
  496. proc = await self.run_with_args()
  497. await proc.wait()
  498. # XXX - not checking error message
  499. # And that it exited w/ the correct code
  500. self.assertEqual(proc.returncode, 5)
  501. def run_with_args(self, *args, pipes=True):
  502. kwargs = {}
  503. if pipes:
  504. kwargs.update(dict(
  505. stdout=asyncio.subprocess.PIPE,
  506. stderr=asyncio.subprocess.PIPE))
  507. return asyncio.create_subprocess_exec(sys.executable,
  508. # XXX - figure out how to add coverage data on these runs
  509. #'-m', 'coverage', 'run', '-p',
  510. __file__, *args, **kwargs)
  511. async def genkey(self, name):
  512. proc = await self.run_with_args('genkey', name, pipes=False)
  513. await proc.wait()
  514. self.assertEqual(proc.returncode, 0)
  515. @async_test
  516. async def test_loadpubkey(self):
  517. keypath = os.path.join(self.tempdir, 'loadpubkeytest')
  518. await self.genkey(keypath)
  519. privkey = loadprivkey(keypath)
  520. enc = serialization.Encoding.Raw
  521. pubformat = serialization.PublicFormat.Raw
  522. pubkeybytes = privkey.public_key().public_bytes(encoding=enc, format=pubformat)
  523. pubkey = loadpubkeyraw(keypath + '.pub')
  524. self.assertEqual(pubkeybytes, pubkey)
  525. privrawkey = loadprivkeyraw(keypath)
  526. enc = serialization.Encoding.Raw
  527. privformat = serialization.PrivateFormat.Raw
  528. encalgo = serialization.NoEncryption()
  529. rprivrawkey = privkey.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  530. self.assertEqual(rprivrawkey, privrawkey)
  531. @async_test
  532. async def test_clientkeymismatch(self):
  533. # make sure that if there's a client key mismatch, we
  534. # don't connect
  535. # Generate necessar keys
  536. servkeypath = os.path.join(self.tempdir, 'server_key')
  537. await self.genkey(servkeypath)
  538. clientkeypath = os.path.join(self.tempdir, 'client_key')
  539. await self.genkey(clientkeypath)
  540. badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
  541. await self.genkey(badclientkeypath)
  542. # forwards connectsion to this socket (created by client)
  543. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  544. ptclientstr = _makeunix(ptclientpath)
  545. # this is the socket server listen to
  546. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  547. incservstr = _makeunix(incservpath)
  548. # to this socket, opened by server
  549. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  550. servtargstr = _makeunix(servtargpath)
  551. # Setup server target listener
  552. ptsockevent = asyncio.Event()
  553. # Bind to pt listener
  554. lsock = await listensockstr(servtargstr, None)
  555. # Startup the server
  556. server = await self.run_with_args('server',
  557. '-c', clientkeypath + '.pub',
  558. servkeypath, incservstr, servtargstr)
  559. # Startup the client with the "bad" key
  560. client = await self.run_with_args('client',
  561. badclientkeypath, servkeypath + '.pub', ptclientstr, incservstr)
  562. # wait for server target to be created
  563. await _awaitfile(servtargpath)
  564. # wait for server to start
  565. await _awaitfile(incservpath)
  566. # wait for client to start
  567. await _awaitfile(ptclientpath)
  568. # Connect to the client
  569. reader, writer = await connectsockstr(ptclientstr)
  570. # XXX - this might not be the best test.
  571. with self.assertRaises(asyncio.futures.TimeoutError):
  572. # make sure that we don't get the conenction
  573. await asyncio.wait_for(ptsockevent.wait(), .5)
  574. writer.close()
  575. # Make sure that when the server is terminated
  576. server.terminate()
  577. # that it's stderr
  578. stdout, stderr = await server.communicate()
  579. #print('s:', repr((stdout, stderr)))
  580. # doesn't have an exceptions never retrieved
  581. # even the example echo server has this same leak
  582. #self.assertNotIn(b'Task exception was never retrieved', stderr)
  583. lsock.close()
  584. await lsock.wait_closed()
  585. # Kill off the client
  586. client.terminate()
  587. stdout, stderr = await client.communicate()
  588. #print('s:', repr((stdout, stderr)))
  589. # XXX - figure out how to clean up client properly
  590. @async_test
  591. async def test_end2end(self):
  592. # Generate necessar keys
  593. servkeypath = os.path.join(self.tempdir, 'server_key')
  594. await self.genkey(servkeypath)
  595. clientkeypath = os.path.join(self.tempdir, 'client_key')
  596. await self.genkey(clientkeypath)
  597. # forwards connectsion to this socket (created by client)
  598. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  599. ptclientstr = _makeunix(ptclientpath)
  600. # this is the socket server listen to
  601. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  602. incservstr = _makeunix(incservpath)
  603. # to this socket, opened by server
  604. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  605. servtargstr = _makeunix(servtargpath)
  606. # Setup server target listener
  607. ptsock = []
  608. ptsockevent = asyncio.Event()
  609. def ptsockaccept(reader, writer, ptsock=ptsock):
  610. ptsock.append((reader, writer))
  611. ptsockevent.set()
  612. # Bind to pt listener
  613. lsock = await listensockstr(servtargstr, ptsockaccept)
  614. # Startup the server
  615. server = await self.run_with_args('server',
  616. '-c', clientkeypath + '.pub',
  617. servkeypath, incservstr, servtargstr,
  618. pipes=False)
  619. # Startup the client
  620. client = await self.run_with_args('client',
  621. clientkeypath, servkeypath + '.pub', ptclientstr, incservstr,
  622. pipes=False)
  623. # wait for server target to be created
  624. await _awaitfile(servtargpath)
  625. # wait for server to start
  626. await _awaitfile(incservpath)
  627. # wait for client to start
  628. await _awaitfile(ptclientpath)
  629. # Connect to the client
  630. reader, writer = await connectsockstr(ptclientstr)
  631. # send a message
  632. ptmsg = b'this is a message for testing'
  633. writer.write(ptmsg)
  634. # make sure that we got the conenction
  635. await ptsockevent.wait()
  636. # get the connection
  637. endrdr, endwrr = ptsock[0]
  638. # make sure we can read back what we sent
  639. self.assertEqual(ptmsg, await endrdr.readexactly(len(ptmsg)))
  640. # test some additional messages
  641. for i in [ 129, 1287, 28792, 129872 ]:
  642. # in on direction
  643. msg = os.urandom(i)
  644. writer.write(msg)
  645. self.assertEqual(msg, await endrdr.readexactly(len(msg)))
  646. # and the other
  647. endwrr.write(msg)
  648. self.assertEqual(msg, await reader.readexactly(len(msg)))
  649. writer.close()
  650. endwrr.close()
  651. lsock.close()
  652. await lsock.wait_closed()
  653. server.terminate()
  654. client.terminate()
  655. # XXX - more clean up testing
  656. @async_test
  657. async def test_genkey(self):
  658. # that it can generate a key
  659. proc = await self.run_with_args('genkey', 'somefile')
  660. await proc.wait()
  661. #print(await proc.communicate())
  662. self.assertEqual(proc.returncode, 0)
  663. with open('somefile.pub', encoding='ascii') as fp:
  664. lines = fp.readlines()
  665. self.assertEqual(len(lines), 1)
  666. keytype, keyvalue = lines[0].split()
  667. self.assertEqual(keytype, 'ntun-x448')
  668. key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))
  669. key = loadprivkey('somefile')
  670. self.assertIsInstance(key, x448.X448PrivateKey)
  671. # that a second call fails
  672. proc = await self.run_with_args('genkey', 'somefile')
  673. await proc.wait()
  674. stdoutdata, stderrdata = await proc.communicate()
  675. self.assertFalse(stdoutdata)
  676. self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)
  677. # And that it exited w/ the correct code
  678. self.assertEqual(proc.returncode, 2)
  679. class TestNoiseFowarder(unittest.TestCase):
  680. def setUp(self):
  681. # setup temporary directory
  682. d = os.path.realpath(tempfile.mkdtemp())
  683. self.basetempdir = d
  684. self.tempdir = os.path.join(d, 'subdir')
  685. os.mkdir(self.tempdir)
  686. # Generate key pairs
  687. self.server_key_pair = genkeypair()
  688. self.client_key_pair = genkeypair()
  689. def tearDown(self):
  690. shutil.rmtree(self.basetempdir)
  691. self.tempdir = None
  692. @async_test
  693. async def test_clientkeymissmatch(self):
  694. # generate a key that is incorrect
  695. wrongclient_key_pair = genkeypair()
  696. # the secure socket
  697. clssockapair, clssockbpair = _asyncsockpair()
  698. reader, writer = await clssockapair
  699. async def wrongkey(v):
  700. raise ValueError('no key matches')
  701. # create the server
  702. servnf = asyncio.create_task(NoiseForwarder('resp',
  703. clssockbpair, wrongkey,
  704. priv_key=self.server_key_pair[1]))
  705. # Create client
  706. proto = NoiseConnection.from_name(
  707. b'Noise_XK_448_ChaChaPoly_SHA256')
  708. proto.set_as_initiator()
  709. # Setup wrong client key
  710. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  711. wrongclient_key_pair[1])
  712. # but the correct server key
  713. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  714. self.server_key_pair[0])
  715. proto.start_handshake()
  716. # Send first message
  717. message = proto.write_message()
  718. self.assertEqual(len(message), _handshakelens[0])
  719. writer.write(message)
  720. # Get response
  721. respmsg = await reader.readexactly(_handshakelens[1])
  722. proto.read_message(respmsg)
  723. # Send final reply
  724. message = proto.write_message()
  725. writer.write(message)
  726. # Make sure handshake has completed
  727. self.assertTrue(proto.handshake_finished)
  728. with self.assertRaises(ValueError):
  729. await servnf
  730. writer.close()
  731. @async_test
  732. async def test_server(self):
  733. # Test is plumbed:
  734. # (reader, writer) -> servsock ->
  735. # (rdr, wrr) NoiseForward (reader, writer) ->
  736. # servptsock -> (ptsock[0], ptsock[1])
  737. # Path that the server will sit on
  738. servsockpath = os.path.join(self.tempdir, 'servsock')
  739. servarg = _makeunix(servsockpath)
  740. # Path that the server will send pt data to
  741. servptpath = os.path.join(self.tempdir, 'servptsock')
  742. # Setup pt target listener
  743. pttarg = _makeunix(servptpath)
  744. ptsock = []
  745. ptsockevent = asyncio.Event()
  746. def ptsockaccept(reader, writer, ptsock=ptsock):
  747. ptsock.append((reader, writer))
  748. ptsockevent.set()
  749. # Bind to pt listener
  750. lsock = await listensockstr(pttarg, ptsockaccept)
  751. nfs = []
  752. event = asyncio.Event()
  753. async def runnf(rdr, wrr):
  754. ptpairfun = asyncio.create_task(connectsockstr(pttarg))
  755. a = await NoiseForwarder('resp',
  756. _makefut((rdr, wrr)), lambda x: ptpairfun,
  757. priv_key=self.server_key_pair[1])
  758. nfs.append(a)
  759. event.set()
  760. # Setup server listener
  761. ssock = await listensockstr(servarg, runnf)
  762. # Connect to server
  763. reader, writer = await connectsockstr(servarg)
  764. # Create client
  765. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  766. proto.set_as_initiator()
  767. # Setup required keys
  768. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  769. self.client_key_pair[1])
  770. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  771. self.server_key_pair[0])
  772. proto.start_handshake()
  773. # Send first message
  774. message = proto.write_message()
  775. self.assertEqual(len(message), _handshakelens[0])
  776. writer.write(message)
  777. # Get response
  778. respmsg = await reader.readexactly(_handshakelens[1])
  779. proto.read_message(respmsg)
  780. # Send final reply
  781. message = proto.write_message()
  782. writer.write(message)
  783. # Make sure handshake has completed
  784. self.assertTrue(proto.handshake_finished)
  785. # generate the keys for lengths
  786. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  787. b'toresp')
  788. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  789. b'toinit')
  790. pversion = 0
  791. # Send the protocol version string first
  792. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  793. writer.write(enclenfun(encmsg))
  794. writer.write(encmsg)
  795. # Read the peer's protocol version
  796. # find out how much we need to read
  797. encmsg = await reader.readexactly(2 + 16)
  798. tlen = declenfun(encmsg)
  799. # read the rest of the message
  800. rencmsg = await reader.readexactly(tlen - 16)
  801. tmsg = encmsg[2:] + rencmsg
  802. rptmsg = proto.decrypt(tmsg)
  803. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)
  804. # write a test message
  805. ptmsg = b'this is a test message that should be a little in length'
  806. encmsg = proto.encrypt(ptmsg)
  807. writer.write(enclenfun(encmsg))
  808. writer.write(encmsg)
  809. # wait for the connection to arrive
  810. await ptsockevent.wait()
  811. ptreader, ptwriter = ptsock[0]
  812. # read the test message
  813. rptmsg = await ptreader.readexactly(len(ptmsg))
  814. self.assertEqual(rptmsg, ptmsg)
  815. # write a different message
  816. ptmsg = os.urandom(2843)
  817. encmsg = proto.encrypt(ptmsg)
  818. writer.write(enclenfun(encmsg))
  819. writer.write(encmsg)
  820. # read the test message
  821. rptmsg = await ptreader.readexactly(len(ptmsg))
  822. self.assertEqual(rptmsg, ptmsg)
  823. # now try the other way
  824. ptmsg = os.urandom(912)
  825. ptwriter.write(ptmsg)
  826. # find out how much we need to read
  827. encmsg = await reader.readexactly(2 + 16)
  828. tlen = declenfun(encmsg)
  829. # read the rest of the message
  830. rencmsg = await reader.readexactly(tlen - 16)
  831. tmsg = encmsg[2:] + rencmsg
  832. rptmsg = proto.decrypt(tmsg)
  833. self.assertEqual(rptmsg, ptmsg)
  834. # shut down sending
  835. writer.write_eof()
  836. # so pt reader should be shut down
  837. self.assertEqual(b'', await ptreader.read(1))
  838. self.assertTrue(ptreader.at_eof())
  839. # shut down pt
  840. ptwriter.write_eof()
  841. # make sure the enc reader is eof
  842. self.assertEqual(b'', await reader.read(1))
  843. self.assertTrue(reader.at_eof())
  844. await event.wait()
  845. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  846. writer.close()
  847. ptwriter.close()
  848. lsock.close()
  849. ssock.close()
  850. await lsock.wait_closed()
  851. await ssock.wait_closed()
  852. @async_test
  853. async def test_protocolversionmismatch(self):
  854. # make sure that if we send a future version, that we
  855. # still get a protocol version, and that the connection
  856. # is closed w/o establishing a connection to the remote
  857. # side
  858. # Test is plumbed:
  859. # (reader, writer) -> servsock ->
  860. # (rdr, wrr) NoiseForward (reader, writer) ->
  861. # servptsock -> (ptsock[0], ptsock[1])
  862. # Path that the server will sit on
  863. servsockpath = os.path.join(self.tempdir, 'servsock')
  864. servarg = _makeunix(servsockpath)
  865. # Path that the server will send pt data to
  866. servptpath = os.path.join(self.tempdir, 'servptsock')
  867. # Setup pt target listener
  868. pttarg = _makeunix(servptpath)
  869. ptsock = []
  870. ptsockevent = asyncio.Event()
  871. def ptsockaccept(reader, writer, ptsock=ptsock):
  872. ptsock.append((reader, writer))
  873. ptsockevent.set()
  874. # Bind to pt listener
  875. lsock = await listensockstr(pttarg, ptsockaccept)
  876. nfs = []
  877. event = asyncio.Event()
  878. async def runnf(rdr, wrr):
  879. ptpairfun = asyncio.create_task(connectsockstr(pttarg))
  880. try:
  881. a = await NoiseForwarder('resp',
  882. _makefut((rdr, wrr)), lambda x: ptpairfun,
  883. priv_key=self.server_key_pair[1])
  884. except RuntimeError as e:
  885. nfs.append(e)
  886. event.set()
  887. return
  888. nfs.append(a)
  889. event.set()
  890. # Setup server listener
  891. ssock = await listensockstr(servarg, runnf)
  892. # Connect to server
  893. reader, writer = await connectsockstr(servarg)
  894. # Create client
  895. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  896. proto.set_as_initiator()
  897. # Setup required keys
  898. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  899. self.client_key_pair[1])
  900. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  901. self.server_key_pair[0])
  902. proto.start_handshake()
  903. # Send first message
  904. message = proto.write_message()
  905. self.assertEqual(len(message), _handshakelens[0])
  906. writer.write(message)
  907. # Get response
  908. respmsg = await reader.readexactly(_handshakelens[1])
  909. proto.read_message(respmsg)
  910. # Send final reply
  911. message = proto.write_message()
  912. writer.write(message)
  913. # Make sure handshake has completed
  914. self.assertTrue(proto.handshake_finished)
  915. # generate the keys for lengths
  916. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  917. b'toresp')
  918. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  919. b'toinit')
  920. pversion = 1
  921. # Send the protocol version string first
  922. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  923. writer.write(enclenfun(encmsg))
  924. writer.write(encmsg)
  925. # Read the peer's protocol version
  926. # find out how much we need to read
  927. encmsg = await reader.readexactly(2 + 16)
  928. tlen = declenfun(encmsg)
  929. # read the rest of the message
  930. rencmsg = await reader.readexactly(tlen - 16)
  931. tmsg = encmsg[2:] + rencmsg
  932. rptmsg = proto.decrypt(tmsg)
  933. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)
  934. await event.wait()
  935. self.assertIsInstance(nfs[0], RuntimeError)
    @async_test
    async def test_serverclient(self):
        '''Connect an initiator and a responder NoiseForwarder to each
        other and check that plain text written at either end arrives
        intact at the other, that EOF propagates, and that both
        forwarders return [ 'dec', 'enc' ].'''

        # plumbing:
        #
        # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
        #
        # _asyncsockpair() appears to return two awaitables, each
        # resolving to a (reader, writer) pair for one end of a
        # connected socket pair.

        ptcsockapair, ptcsockbpair = _asyncsockpair()
        ptcareader, ptcawriter = await ptcsockapair
        #ptcsockbpair passed directly
        clssockapair, clssockbpair = _asyncsockpair()
        #both passed directly
        ptssockapair, ptssockbpair = _asyncsockpair()
        #ptssockapair passed directly
        ptsbreader, ptsbwriter = await ptssockbpair

        # Responder's ptpairfun: presumably called by the forwarder with
        # the initiator's static public key; it checks the key and hands
        # back the ptsa end.
        # NOTE(review): a failed assertion here only fails the servnf
        # task; the reads below may then block instead of reporting it.
        async def validateclientkey(pubkey):
            self.assertEqual(pubkey, self.client_key_pair[0])
            return await ptssockapair

        # The initiator's ptpairfun ignores its argument and always
        # returns the ptcb end.
        clientnf = asyncio.create_task(NoiseForwarder('init',
            clssockapair, lambda x: ptcsockbpair,
            priv_key=self.client_key_pair[1],
            pub_key=self.server_key_pair[0]))
        servnf = asyncio.create_task(NoiseForwarder('resp',
            clssockbpair, validateclientkey,
            priv_key=self.server_key_pair[1]))

        # send a message
        msga = os.urandom(183)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a second message
        msga = os.urandom(2834)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message larger than the block size
        msga = os.urandom(103958)
        ptcawriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

        # send a message the other direction
        msga = os.urandom(103958)
        ptsbwriter.write(msga)

        # make sure we get the same message
        self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

        # close down the pt writers, the rest should follow
        ptsbwriter.write_eof()
        ptcawriter.write_eof()

        # make sure they are closed, and there is no more data
        self.assertEqual(b'', await ptsbreader.read(1))
        self.assertTrue(ptsbreader.at_eof())

        self.assertEqual(b'', await ptcareader.read(1))
        self.assertTrue(ptcareader.at_eof())

        # both forwarders finish once EOF has passed in both directions
        self.assertEqual([ 'dec', 'enc' ], await clientnf)
        self.assertEqual([ 'dec', 'enc' ], await servnf)

        await ptsbwriter.drain()
        await ptcawriter.drain()

        ptsbwriter.close()
        ptcawriter.close()