A stunnel-like program that utilizes the Noise protocol.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

1404 lines
38 KiB

  1. from contextlib import asynccontextmanager
  2. from cryptography.hazmat.backends import default_backend
  3. from cryptography.hazmat.primitives import hashes
  4. from cryptography.hazmat.primitives import serialization
  5. from cryptography.hazmat.primitives.asymmetric import x448
  6. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  7. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  8. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  9. from noise.connection import NoiseConnection, Keypair
  10. #import tracemalloc; tracemalloc.start(100)
  11. import argparse
  12. import asyncio
  13. import base64
  14. import os.path
  15. import shutil
  16. import socket
  17. import sys
  18. import tempfile
  19. import time
  20. import threading
  21. import unittest
# Shared cryptography backend object; it is handed to the HKDF and
# Cipher constructors in _genciphfun.
_backend = default_backend()
  23. def loadprivkey(fname):
  24. with open(fname, encoding='ascii') as fp:
  25. data = fp.read().encode('ascii')
  26. key = load_pem_private_key(data, password=None, backend=default_backend())
  27. return key
  28. def loadprivkeyraw(fname):
  29. key = loadprivkey(fname)
  30. enc = serialization.Encoding.Raw
  31. privformat = serialization.PrivateFormat.Raw
  32. encalgo = serialization.NoEncryption()
  33. return key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  34. def loadpubkeyraw(fname):
  35. with open(fname, encoding='ascii') as fp:
  36. lines = fp.readlines()
  37. # XXX
  38. #self.assertEqual(len(lines), 1)
  39. keytype, keyvalue = lines[0].split()
  40. if keytype != 'ntun-x448':
  41. raise RuntimeError
  42. return base64.urlsafe_b64decode(keyvalue)
  43. class ConnectionValidator(object):
  44. '''This class is used to validate a connection, and initiate the
  45. connection that will be used.'''
  46. async def validatekey(self, hash, key):
  47. '''Validate that the key is authorized to connect. The
  48. connection hash is passed in, so that the authorizate of
  49. the key can be validated later.'''
  50. raise NotImplementedError
  51. async def getconnection(self, hash, key, **kwargs):
  52. '''Return the connection that should be used by this
  53. client.'''
  54. raise NotImplementedError
  55. class GenericConnValidator(object):
  56. '''This is a simple implementation of a ConnectionValidator that
  57. can be used w/ most cases. It checks against the list, and then
  58. calls/awaits the function provided, and returns it's value.'''
  59. def __init__(self, keys, connfun):
  60. '''The parameter keys must be an object that supports the
  61. in operators, aka contains. If the key is in the keys
  62. object, the connection will proceed.
  63. The parameter connfun must be an async function that
  64. returns a StreamReader, StreamWriter pair of the
  65. connection that they session is supposed to use.'''
  66. self._keys = keys
  67. self._connfun = connfun
  68. async def validatekey(self, hash, key):
  69. if key not in self._keys:
  70. raise ValueError('key not authorized: %s' % repr(key))
  71. async def getconnection(self, hash, key):
  72. return await self._connfun()
  73. def genkeypair():
  74. '''Generates a keypair, and returns a tuple of (public, private).
  75. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  76. key = x448.X448PrivateKey.generate()
  77. enc = serialization.Encoding.Raw
  78. pubformat = serialization.PublicFormat.Raw
  79. privformat = serialization.PrivateFormat.Raw
  80. encalgo = serialization.NoEncryption()
  81. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  82. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  83. return pub, priv
  84. def _makefut(obj):
  85. loop = asyncio.get_running_loop()
  86. fut = loop.create_future()
  87. fut.set_result(obj)
  88. return fut
  89. def _makeunix(path):
  90. '''Make a properly formed unix path socket string.'''
  91. return 'unix:%s' % path
  92. # Make sure any additions are reflected by tests in test_parsesockstr
  93. _allowedparameters = {
  94. 'unix': {
  95. 'path': str,
  96. },
  97. 'tcp': {
  98. 'host': str,
  99. 'port': int,
  100. },
  101. }
  102. def parsesockstr(sockstr):
  103. '''Parse a socket string to its parts.
  104. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  105. If the proto has a default parameter, the value can be used
  106. directly, like: 'proto:value'. This is only allowed when the
  107. value can unambiguously be determined not to be a param. If
  108. there needs to be an equals '=', then you MUST use the extended
  109. version.
  110. The characters that define 'param' must be all lower case ascii
  111. characters and may contain an underscore. The first character
  112. must not be an underscore.
  113. Supported protocols:
  114. unix:
  115. Default parameter is path.
  116. The path parameter specifies the path to the
  117. unix domain socket. The path MUST start w/ a
  118. slash if it is used as a default parameter.
  119. tcp:
  120. Default parameter is host[:port].
  121. The host parameter specifies the host, and the
  122. port parameter specifies the port of the
  123. connection.
  124. '''
  125. proto, rem = sockstr.split(':', 1)
  126. if '=' not in rem:
  127. if proto == 'unix' and rem[0] != '/':
  128. raise ValueError('bare path MUST start w/ a slash (/).')
  129. if proto == 'unix':
  130. args = { 'path': rem }
  131. else:
  132. args = dict(i.split('=', 1) for i in rem.split(','))
  133. try:
  134. allowed = _allowedparameters[proto]
  135. except KeyError:
  136. raise ValueError('unsupported proto: %s' % repr(proto))
  137. extrakeys = args.keys() - allowed.keys()
  138. if extrakeys:
  139. raise ValueError('keys for proto %s not allowed: %s' % (repr(proto), extrakeys))
  140. for i in args:
  141. args[i] = allowed[i](args[i])
  142. return proto, args
  143. async def connectsockstr(sockstr):
  144. '''Wrapper for asyncio.open_*_connection.'''
  145. proto, args = parsesockstr(sockstr)
  146. if proto == 'unix':
  147. fun = asyncio.open_unix_connection
  148. elif proto == 'tcp':
  149. fun = asyncio.open_connection
  150. reader, writer = await fun(**args)
  151. return reader, writer
  152. async def listensockstr(sockstr, cb):
  153. '''Wrapper for asyncio.start_x_server.
  154. For the format of sockstr, please see parsesockstr.
  155. The cb parameter is passed to asyncio's start_server or related
  156. calls. Per those docs, the cb parameter is calls or scheduled
  157. as a task when a client establishes a connection. It is called
  158. with two arguments, the reader and writer streams. For more
  159. information, see: https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
  160. '''
  161. proto, args = parsesockstr(sockstr)
  162. if proto == 'unix':
  163. fun = asyncio.start_unix_server
  164. elif proto == 'tcp':
  165. fun = asyncio.start_server
  166. return await fun(cb, **args)
# Number of bytes that NoiseForwarder reads (via readexactly) for the
# first, second and third handshake messages respectively.
# NOTE(review): presumably the list is regenerated by running the
# command on the next line -- confirm against makemessagelengths.py.
# !!python makemessagelengths.py
_handshakelens = \
[72, 72, 88]
  170. def _genciphfun(hash, ad):
  171. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  172. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  173. key = hkdf.derive(hash)
  174. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  175. backend=_backend)
  176. enctor = cipher.encryptor()
  177. def encfun(data):
  178. # Returns the two bytes for length
  179. val = len(data)
  180. encbytes = enctor.update(data[:16])
  181. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  182. return (val ^ mask).to_bytes(length=2, byteorder='big')
  183. def decfun(data):
  184. # takes off the data and returns the total
  185. # length
  186. val = int.from_bytes(data[:2], byteorder='big')
  187. encbytes = enctor.update(data[2:2 + 16])
  188. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  189. return val ^ mask
  190. return encfun, decfun
  191. async def NoiseForwarder(mode, encrdrwrr, connvalid, priv_key, pub_key=None):
  192. '''A function that forwards data between the plain text pair of
  193. streams to the encrypted session.
  194. The mode paramater must be one of 'init' or 'resp' for initiator
  195. and responder.
  196. The encrdrwrr is an await object that will return a tunle of the
  197. reader and writer streams for the encrypted side of the
  198. connection.
  199. The connvalid parameter is an instance of the ConnectionValidator
  200. class, or one that implements it's methods. The validatekey method
  201. will be passed the session hash, and the remote public key of the
  202. client or server. If the key is not authorized, an exception
  203. must be raised. Any non-exception return from the function means
  204. that the key is authorized, and that the session should continue.
  205. In the case of the initiator, the server's key will be passed,
  206. despite the fact that it was already validated by the XK Noise
  207. Protocol. This is just to keep the calling convention the same,
  208. and it supports moving to an XX protocol possibly in the future
  209. with minimal changes.
  210. Then the getconnection method will be called. It's expected
  211. return is the connection to forward the data on to. The kwargs
  212. may be used in future protocol versions to allow the client to
  213. request a specific resource, or something similar.
  214. In the case of the initiator, pub_key must be provided and will
  215. be used to authenticate the responder side of the connection.
  216. The priv_key parameter is used to authenticate this side of the
  217. session.
  218. Both priv_key and pub_key parameters must be 56 bytes. For example,
  219. the pair that is returned by genkeypair.
  220. '''
  221. # Send a protocol version so that in the future we can change how
  222. # we interface, and possibly be able to send control messages,
  223. # allow the client to pass some misc data to the callback, or to
  224. # allow a reverse tunnel, were the client talks to the server,
  225. # and waits for the server to "connect" to the client w/ a
  226. # connection, e.g. reverse tunnel out behind a nat to allow
  227. # incoming connections.
  228. #
  229. # Add protocol version to getconnection when it gets bumped
  230. #
  231. protocol_version = 0
  232. rdr, wrr = await encrdrwrr
  233. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  234. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  235. if pub_key is not None:
  236. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  237. pub_key)
  238. if mode == 'resp':
  239. proto.set_as_responder()
  240. proto.start_handshake()
  241. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  242. wrr.write(proto.write_message())
  243. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  244. elif mode == 'init':
  245. proto.set_as_initiator()
  246. proto.start_handshake()
  247. wrr.write(proto.write_message())
  248. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  249. wrr.write(proto.write_message())
  250. if not proto.handshake_finished: # pragma: no cover
  251. raise RuntimeError('failed to finish handshake')
  252. sesshash = proto.get_handshake_hash()
  253. clientkey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC),
  254. 'public_bytes', None)
  255. try:
  256. await connvalid.validatekey(sesshash, clientkey)
  257. except Exception:
  258. wrr.close()
  259. raise
  260. # generate the keys for lengths
  261. # XXX - get_handshake_hash is probably not the best option, but
  262. # this is only to obscure lengths, it is not required to be secure
  263. # as the underlying NoiseProtocol securely validates everything.
  264. # It is marginally useful as writing patterns likely expose the
  265. # true length. Adding padding could marginally help w/ this.
  266. if mode == 'resp':
  267. _, declenfun = _genciphfun(sesshash, b'toresp')
  268. enclenfun, _ = _genciphfun(sesshash, b'toinit')
  269. elif mode == 'init':
  270. enclenfun, _ = _genciphfun(sesshash, b'toresp')
  271. _, declenfun = _genciphfun(sesshash, b'toinit')
  272. # protocol negotiation
  273. # send first, then wait for the response
  274. pvmsg = protocol_version.to_bytes(1, byteorder='big')
  275. encmsg = proto.encrypt(pvmsg)
  276. wrr.write(enclenfun(encmsg))
  277. wrr.write(encmsg)
  278. # get the protocol version
  279. msg = await rdr.readexactly(2 + 16)
  280. tlen = declenfun(msg)
  281. rmsg = await rdr.readexactly(tlen - 16)
  282. tmsg = msg[2:] + rmsg
  283. rpv = proto.decrypt(tmsg)
  284. rempv = int.from_bytes(rpv, byteorder='big')
  285. if rempv != protocol_version:
  286. raise RuntimeError('unsupported protovol version received: %d' %
  287. rempv)
  288. reader, writer = await connvalid.getconnection(sesshash, clientkey)
  289. async def decses():
  290. try:
  291. while True:
  292. try:
  293. msg = await rdr.readexactly(2 + 16)
  294. except asyncio.streams.IncompleteReadError:
  295. if rdr.at_eof():
  296. return 'dec'
  297. tlen = declenfun(msg)
  298. rmsg = await rdr.readexactly(tlen - 16)
  299. tmsg = msg[2:] + rmsg
  300. writer.write(proto.decrypt(tmsg))
  301. await writer.drain()
  302. #except:
  303. # import traceback
  304. # traceback.print_exc()
  305. # raise
  306. finally:
  307. try:
  308. writer.write_eof()
  309. except OSError as e:
  310. if e.errno != 57:
  311. raise
  312. async def encses():
  313. try:
  314. while True:
  315. # largest message
  316. ptmsg = await reader.read(65535 - 16)
  317. if not ptmsg:
  318. # eof
  319. return 'enc'
  320. encmsg = proto.encrypt(ptmsg)
  321. wrr.write(enclenfun(encmsg))
  322. wrr.write(encmsg)
  323. await wrr.drain()
  324. #except:
  325. # import traceback
  326. # traceback.print_exc()
  327. # raise
  328. finally:
  329. wrr.write_eof()
  330. res = await asyncio.gather(decses(), encses())
  331. await wrr.drain() # not sure if needed
  332. wrr.close()
  333. await writer.drain() # not sure if needed
  334. writer.close()
  335. return res
  336. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  337. # Slightly modified to timeout and to print trace back when canceled.
  338. # This makes it easier to figure out what "froze".
  339. def async_test(f):
  340. def wrapper(*args, **kwargs):
  341. async def tbcapture():
  342. try:
  343. return await f(*args, **kwargs)
  344. except asyncio.CancelledError as e:
  345. # if we are going to be cancelled, print out a tb
  346. import traceback
  347. traceback.print_exc()
  348. raise
  349. loop = asyncio.get_event_loop()
  350. # timeout after 4 seconds
  351. loop.run_until_complete(asyncio.wait_for(tbcapture(), 4))
  352. return wrapper
  353. class Tests_misc(unittest.TestCase):
  354. def setUp(self):
  355. # setup temporary directory
  356. d = os.path.realpath(tempfile.mkdtemp())
  357. self.basetempdir = d
  358. self.tempdir = os.path.join(d, 'subdir')
  359. os.mkdir(self.tempdir)
  360. os.chdir(self.tempdir)
  361. def tearDown(self):
  362. #print('td:', time.time())
  363. shutil.rmtree(self.basetempdir)
  364. self.tempdir = None
  365. def test_parsesockstr_bad(self):
  366. badstrs = [
  367. 'unix:ff',
  368. 'randomnocolon',
  369. 'unix:somethingelse=bogus',
  370. 'tcp:port=bogus',
  371. ]
  372. for i in badstrs:
  373. with self.assertRaises(ValueError,
  374. msg='Should have failed processing: %s' % repr(i)):
  375. parsesockstr(i)
  376. def test_parsesockstr(self):
  377. results = {
  378. # Not all of these are valid when passed to a *sockstr
  379. # function
  380. 'unix:/apath': ('unix', { 'path': '/apath' }),
  381. 'unix:path=apath': ('unix', { 'path': 'apath' }),
  382. 'tcp:host=apath': ('tcp', { 'host': 'apath' }),
  383. 'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
  384. 'port': 5 }),
  385. }
  386. for s, r in results.items():
  387. self.assertEqual(parsesockstr(s), r)
  388. @async_test
  389. async def test_listensockstr_bad(self):
  390. with self.assertRaises(ValueError):
  391. ls = await listensockstr('bogus:some=arg', None)
  392. with self.assertRaises(ValueError):
  393. ls = await connectsockstr('bogus:some=arg')
  394. @async_test
  395. async def test_listenconnectsockstr(self):
  396. msgsent = b'this is a test message'
  397. msgrcv = b'testing message for receive'
  398. # That when a connection is received and receives and sends
  399. async def servconfhandle(rdr, wrr):
  400. msg = await rdr.readexactly(len(msgsent))
  401. self.assertEqual(msg, msgsent)
  402. #print(repr(wrr.get_extra_info('sockname')))
  403. wrr.write(msgrcv)
  404. await wrr.drain()
  405. wrr.close()
  406. return True
  407. # Test listensockstr
  408. for sstr, confun in [
  409. ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
  410. ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
  411. ]:
  412. # that listensockstr will bind to the correct path, can call cb
  413. ls = await listensockstr(sstr, servconfhandle)
  414. # that we open a connection to the path
  415. rdr, wrr = await confun()
  416. # and send a message
  417. wrr.write(msgsent)
  418. # and receive the message
  419. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  420. self.assertEqual(rcv, msgrcv)
  421. wrr.close()
  422. # Now test that connectsockstr works similarly.
  423. rdr, wrr = await connectsockstr(sstr)
  424. # and send a message
  425. wrr.write(msgsent)
  426. # and receive the message
  427. rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
  428. self.assertEqual(rcv, msgrcv)
  429. wrr.close()
  430. ls.close()
  431. await ls.wait_closed()
  432. def test_genciphfun(self):
  433. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  434. msg = b'this is a bunch of data'
  435. tb = enc(msg)
  436. self.assertEqual(len(msg), dec(tb + msg))
  437. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  438. msg = os.urandom(i)
  439. self.assertEqual(len(msg), dec(enc(msg) + msg))
  440. def cmd_client(args):
  441. privkey = loadprivkeyraw(args.clientkey)
  442. pubkey = loadpubkeyraw(args.servkey)
  443. async def runnf(rdr, wrr):
  444. connval = GenericConnValidator([ pubkey ],
  445. lambda: _makefut((rdr, wrr)))
  446. encpair = asyncio.create_task(connectsockstr(args.clienttarget))
  447. a = await NoiseForwarder('init', encpair, connval,
  448. priv_key=privkey, pub_key=pubkey)
  449. # Setup client listener
  450. ssock = listensockstr(args.clientlisten, runnf)
  451. loop = asyncio.get_event_loop()
  452. obj = loop.run_until_complete(ssock)
  453. loop.run_until_complete(obj.serve_forever())
  454. def cmd_server(args):
  455. privkey = loadprivkeyraw(args.servkey)
  456. pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]
  457. async def runnf(rdr, wrr):
  458. connval = GenericConnValidator(pubkeys, lambda: connectsockstr(args.servtarget))
  459. a = await NoiseForwarder('resp', _makefut((rdr, wrr)),
  460. connval, priv_key=privkey)
  461. # Setup server listener
  462. ssock = listensockstr(args.servlisten, runnf)
  463. loop = asyncio.get_event_loop()
  464. obj = loop.run_until_complete(ssock)
  465. loop.run_until_complete(obj.serve_forever())
  466. def cmd_genkey(args):
  467. keypair = genkeypair()
  468. key = x448.X448PrivateKey.generate()
  469. # public key part
  470. enc = serialization.Encoding.Raw
  471. pubformat = serialization.PublicFormat.Raw
  472. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  473. try:
  474. fname = args.fname + '.pub'
  475. with open(fname, 'x', encoding='ascii') as fp:
  476. print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
  477. except FileExistsError:
  478. print('failed to create %s, file exists.' % fname, file=sys.stderr)
  479. sys.exit(2)
  480. enc = serialization.Encoding.PEM
  481. format = serialization.PrivateFormat.PKCS8
  482. encalgo = serialization.NoEncryption()
  483. with open(args.fname, 'x', encoding='ascii') as fp:
  484. fp.write(key.private_bytes(encoding=enc, format=format, encryption_algorithm=encalgo).decode('ascii'))
  485. def main():
  486. parser = argparse.ArgumentParser()
  487. subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')
  488. parser_gk = subparsers.add_parser('genkey', help='generate keys')
  489. parser_gk.add_argument('fname', type=str, help='file name for the key')
  490. parser_gk.set_defaults(func=cmd_genkey)
  491. parser_serv = subparsers.add_parser('server', help='run a server')
  492. parser_serv.add_argument('--clientkey', '-c', action='append', type=str, help='file of authorized client keys, or a .pub file')
  493. parser_serv.add_argument('servkey', type=str, help='file name for the server key')
  494. parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
  495. parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
  496. parser_serv.set_defaults(func=cmd_server)
  497. parser_client = subparsers.add_parser('client', help='run a client')
  498. parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
  499. parser_client.add_argument('servkey', type=str, help='file name for the server public key')
  500. parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
  501. parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
  502. parser_client.set_defaults(func=cmd_client)
  503. args = parser.parse_args()
  504. try:
  505. fun = args.func
  506. except AttributeError:
  507. parser.print_usage()
  508. sys.exit(5)
  509. fun(args)
# Run the command line interface when this file is executed as a
# script, as opposed to being imported.
if __name__ == '__main__': # pragma: no cover
    main()
  512. def _asyncsockpair():
  513. '''Create a pair of sockets that are bound to each other.
  514. The function will return a tuple of two coroutine's, that
  515. each, when await'ed upon, will return the reader/writer pair.'''
  516. socka, sockb = socket.socketpair()
  517. return asyncio.open_connection(sock=socka), \
  518. asyncio.open_connection(sock=sockb)
  519. async def _awaitfile(fname):
  520. while not os.path.exists(fname):
  521. await asyncio.sleep(.01)
  522. return True
  523. class TestMain(unittest.TestCase):
  524. def setUp(self):
  525. # setup temporary directory
  526. d = os.path.realpath(tempfile.mkdtemp())
  527. self.basetempdir = d
  528. self.tempdir = os.path.join(d, 'subdir')
  529. os.mkdir(self.tempdir)
  530. # Generate key pairs
  531. self.server_key_pair = genkeypair()
  532. self.client_key_pair = genkeypair()
  533. os.chdir(self.tempdir)
  534. def tearDown(self):
  535. #print('td:', time.time())
  536. shutil.rmtree(self.basetempdir)
  537. self.tempdir = None
  538. @asynccontextmanager
  539. async def run_with_args(self, *args, pipes=True):
  540. kwargs = {}
  541. if pipes:
  542. kwargs.update(dict(
  543. stdout=asyncio.subprocess.PIPE,
  544. stderr=asyncio.subprocess.PIPE))
  545. aproc = asyncio.create_subprocess_exec(sys.executable,
  546. # XXX - figure out how to add coverage data on these runs
  547. #'-m', 'coverage', 'run', '-p',
  548. __file__, *args, **kwargs)
  549. try:
  550. proc = await aproc
  551. yield proc
  552. finally:
  553. if proc.returncode is None:
  554. proc.terminate()
  555. # Make sure that process exits before continuing
  556. await proc.wait()
  557. @async_test
  558. async def test_noargs(self):
  559. async with self.run_with_args() as proc:
  560. await proc.wait()
  561. # XXX - not checking error message
  562. # And that it exited w/ the correct code
  563. self.assertEqual(proc.returncode, 5)
  564. async def genkey(self, name):
  565. async with self.run_with_args('genkey', name, pipes=False) as proc:
  566. await proc.wait()
  567. self.assertEqual(proc.returncode, 0)
  568. @async_test
  569. async def test_loadpubkey(self):
  570. keypath = os.path.join(self.tempdir, 'loadpubkeytest')
  571. await self.genkey(keypath)
  572. privkey = loadprivkey(keypath)
  573. enc = serialization.Encoding.Raw
  574. pubformat = serialization.PublicFormat.Raw
  575. pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
  576. format=pubformat)
  577. pubkey = loadpubkeyraw(keypath + '.pub')
  578. self.assertEqual(pubkeybytes, pubkey)
  579. privrawkey = loadprivkeyraw(keypath)
  580. enc = serialization.Encoding.Raw
  581. privformat = serialization.PrivateFormat.Raw
  582. encalgo = serialization.NoEncryption()
  583. rprivrawkey = privkey.private_bytes(encoding=enc,
  584. format=privformat, encryption_algorithm=encalgo)
  585. self.assertEqual(rprivrawkey, privrawkey)
  586. @async_test
  587. async def test_clientkeymismatch(self):
  588. # make sure that if there's a client key mismatch, we
  589. # don't connect
  590. # Generate necessar keys
  591. servkeypath = os.path.join(self.tempdir, 'server_key')
  592. await self.genkey(servkeypath)
  593. clientkeypath = os.path.join(self.tempdir, 'client_key')
  594. await self.genkey(clientkeypath)
  595. badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
  596. await self.genkey(badclientkeypath)
  597. # forwards connectsion to this socket (created by client)
  598. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  599. ptclientstr = _makeunix(ptclientpath)
  600. # this is the socket server listen to
  601. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  602. incservstr = _makeunix(incservpath)
  603. # to this socket, opened by server
  604. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  605. servtargstr = _makeunix(servtargpath)
  606. # Setup server target listener
  607. ptsockevent = asyncio.Event()
  608. # Bind to pt listener
  609. lsock = await listensockstr(servtargstr, None)
  610. # Startup the server
  611. wserver = self.run_with_args('server',
  612. '-c', clientkeypath + '.pub',
  613. servkeypath, incservstr, servtargstr)
  614. # Startup the client with the "bad" key
  615. wclient = self.run_with_args('client', badclientkeypath,
  616. servkeypath + '.pub', ptclientstr, incservstr)
  617. async with wserver as server, wclient as client:
  618. # wait for server target to be created
  619. await _awaitfile(servtargpath)
  620. # wait for server to start
  621. await _awaitfile(incservpath)
  622. # wait for client to start
  623. await _awaitfile(ptclientpath)
  624. # Connect to the client
  625. reader, writer = await connectsockstr(ptclientstr)
  626. # XXX - this might not be the best test.
  627. with self.assertRaises(asyncio.futures.TimeoutError):
  628. # make sure that we don't get the conenction
  629. await asyncio.wait_for(ptsockevent.wait(), .5)
  630. writer.close()
  631. # Make sure that when the server is terminated
  632. server.terminate()
  633. # that it's stderr
  634. stdout, stderr = await server.communicate()
  635. #print('s:', repr((stdout, stderr)))
  636. # doesn't have an exceptions never retrieved
  637. # even the example echo server has this same leak
  638. #self.assertNotIn(b'Task exception was never retrieved', stderr)
  639. lsock.close()
  640. await lsock.wait_closed()
  641. # Kill off the client
  642. client.terminate()
  643. stdout, stderr = await client.communicate()
  644. #print('s:', repr((stdout, stderr)))
  645. # XXX - figure out how to clean up client properly
  646. @async_test
  647. async def test_end2end(self):
  648. # Generate necessar keys
  649. servkeypath = os.path.join(self.tempdir, 'server_key')
  650. await self.genkey(servkeypath)
  651. clientkeypath = os.path.join(self.tempdir, 'client_key')
  652. await self.genkey(clientkeypath)
  653. # forwards connectsion to this socket (created by client)
  654. ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
  655. ptclientstr = _makeunix(ptclientpath)
  656. # this is the socket server listen to
  657. incservpath = os.path.join(self.tempdir, 'incserv.sock')
  658. incservstr = _makeunix(incservpath)
  659. # to this socket, opened by server
  660. servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
  661. servtargstr = _makeunix(servtargpath)
  662. # Setup server target listener
  663. ptsock = []
  664. ptsockevent = asyncio.Event()
  665. def ptsockaccept(reader, writer, ptsock=ptsock):
  666. ptsock.append((reader, writer))
  667. ptsockevent.set()
  668. # Bind to pt listener
  669. lsock = await listensockstr(servtargstr, ptsockaccept)
  670. # Startup the server
  671. wserver = self.run_with_args('server',
  672. '-c', clientkeypath + '.pub',
  673. servkeypath, incservstr, servtargstr,
  674. pipes=False)
  675. # Startup the client
  676. wclient = self.run_with_args('client',
  677. clientkeypath, servkeypath + '.pub', ptclientstr,
  678. incservstr, pipes=False)
  679. async with wserver as server, wclient as client:
  680. # wait for server target to be created
  681. await _awaitfile(servtargpath)
  682. # wait for server to start
  683. await _awaitfile(incservpath)
  684. # wait for client to start
  685. await _awaitfile(ptclientpath)
  686. # Connect to the client
  687. reader, writer = await connectsockstr(ptclientstr)
  688. # send a message
  689. ptmsg = b'this is a message for testing'
  690. writer.write(ptmsg)
  691. # make sure that we got the conenction
  692. await ptsockevent.wait()
  693. # get the connection
  694. endrdr, endwrr = ptsock[0]
  695. # make sure we can read back what we sent
  696. self.assertEqual(ptmsg,
  697. await endrdr.readexactly(len(ptmsg)))
  698. # test some additional messages
  699. for i in [ 129, 1287, 28792, 129872 ]:
  700. # in on direction
  701. msg = os.urandom(i)
  702. writer.write(msg)
  703. self.assertEqual(msg,
  704. await endrdr.readexactly(len(msg)))
  705. # and the other
  706. endwrr.write(msg)
  707. self.assertEqual(msg,
  708. await reader.readexactly(len(msg)))
  709. writer.close()
  710. endwrr.close()
  711. lsock.close()
  712. await lsock.wait_closed()
  713. # XXX - more testing that things exited properly
  714. @async_test
  715. async def test_genkey(self):
  716. # that it can generate a key
  717. async with self.run_with_args('genkey', 'somefile') as proc:
  718. await proc.wait()
  719. #print(await proc.communicate())
  720. self.assertEqual(proc.returncode, 0)
  721. with open('somefile.pub', encoding='ascii') as fp:
  722. lines = fp.readlines()
  723. self.assertEqual(len(lines), 1)
  724. keytype, keyvalue = lines[0].split()
  725. self.assertEqual(keytype, 'ntun-x448')
  726. key = x448.X448PublicKey.from_public_bytes(
  727. base64.urlsafe_b64decode(keyvalue))
  728. key = loadprivkey('somefile')
  729. self.assertIsInstance(key, x448.X448PrivateKey)
  730. # that a second call fails
  731. async with self.run_with_args('genkey', 'somefile') as proc:
  732. await proc.wait()
  733. stdoutdata, stderrdata = await proc.communicate()
  734. self.assertFalse(stdoutdata)
  735. self.assertEqual(
  736. b'failed to create somefile.pub, file exists.\n',
  737. stderrdata)
  738. # And that it exited w/ the correct code
  739. self.assertEqual(proc.returncode, 2)
  740. class TestNoiseFowarder(unittest.TestCase):
  741. def setUp(self):
  742. # setup temporary directory
  743. d = os.path.realpath(tempfile.mkdtemp())
  744. self.basetempdir = d
  745. self.tempdir = os.path.join(d, 'subdir')
  746. os.mkdir(self.tempdir)
  747. # Generate key pairs
  748. self.server_key_pair = genkeypair()
  749. self.client_key_pair = genkeypair()
  750. def tearDown(self):
  751. shutil.rmtree(self.basetempdir)
  752. self.tempdir = None
  753. @async_test
  754. async def test_clientkeymissmatch(self):
  755. # generate a key that is incorrect
  756. wrongclient_key_pair = genkeypair()
  757. # the secure socket
  758. clssockapair, clssockbpair = _asyncsockpair()
  759. reader, writer = await clssockapair
  760. # create the server
  761. servnf = asyncio.create_task(NoiseForwarder('resp',
  762. clssockbpair, GenericConnValidator([], None),
  763. priv_key=self.server_key_pair[1]))
  764. # Create client
  765. proto = NoiseConnection.from_name(
  766. b'Noise_XK_448_ChaChaPoly_SHA256')
  767. proto.set_as_initiator()
  768. # Setup wrong client key
  769. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  770. wrongclient_key_pair[1])
  771. # but the correct server key
  772. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  773. self.server_key_pair[0])
  774. proto.start_handshake()
  775. # Send first message
  776. message = proto.write_message()
  777. self.assertEqual(len(message), _handshakelens[0])
  778. writer.write(message)
  779. # Get response
  780. respmsg = await reader.readexactly(_handshakelens[1])
  781. proto.read_message(respmsg)
  782. # Send final reply
  783. message = proto.write_message()
  784. writer.write(message)
  785. # Make sure handshake has completed
  786. self.assertTrue(proto.handshake_finished)
  787. with self.assertRaises(ValueError):
  788. await servnf
  789. writer.close()
  790. @async_test
  791. async def test_server(self):
  792. # Test is plumbed:
  793. # (reader, writer) -> servsock ->
  794. # (rdr, wrr) NoiseForward (reader, writer) ->
  795. # servptsock -> (ptsock[0], ptsock[1])
  796. # Path that the server will sit on
  797. servsockpath = os.path.join(self.tempdir, 'servsock')
  798. servarg = _makeunix(servsockpath)
  799. # Path that the server will send pt data to
  800. servptpath = os.path.join(self.tempdir, 'servptsock')
  801. # Setup pt target listener
  802. pttarg = _makeunix(servptpath)
  803. ptsock = []
  804. ptsockevent = asyncio.Event()
  805. def ptsockaccept(reader, writer, ptsock=ptsock):
  806. ptsock.append((reader, writer))
  807. ptsockevent.set()
  808. # Bind to pt listener
  809. lsock = await listensockstr(pttarg, ptsockaccept)
  810. nfs = []
  811. event = asyncio.Event()
  812. async def runnf(rdr, wrr):
  813. connval = GenericConnValidator(
  814. [ self.client_key_pair[0] ],
  815. lambda: connectsockstr(pttarg))
  816. a = await NoiseForwarder('resp',
  817. _makefut((rdr, wrr)), connval,
  818. priv_key=self.server_key_pair[1])
  819. nfs.append(a)
  820. event.set()
  821. # Setup server listener
  822. ssock = await listensockstr(servarg, runnf)
  823. # Connect to server
  824. reader, writer = await connectsockstr(servarg)
  825. # Create client
  826. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  827. proto.set_as_initiator()
  828. # Setup required keys
  829. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  830. self.client_key_pair[1])
  831. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  832. self.server_key_pair[0])
  833. proto.start_handshake()
  834. # Send first message
  835. message = proto.write_message()
  836. self.assertEqual(len(message), _handshakelens[0])
  837. writer.write(message)
  838. # Get response
  839. respmsg = await reader.readexactly(_handshakelens[1])
  840. proto.read_message(respmsg)
  841. # Send final reply
  842. message = proto.write_message()
  843. writer.write(message)
  844. # Make sure handshake has completed
  845. self.assertTrue(proto.handshake_finished)
  846. # generate the keys for lengths
  847. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  848. b'toresp')
  849. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  850. b'toinit')
  851. pversion = 0
  852. # Send the protocol version string first
  853. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  854. writer.write(enclenfun(encmsg))
  855. writer.write(encmsg)
  856. # Read the peer's protocol version
  857. # find out how much we need to read
  858. encmsg = await reader.readexactly(2 + 16)
  859. tlen = declenfun(encmsg)
  860. # read the rest of the message
  861. rencmsg = await reader.readexactly(tlen - 16)
  862. tmsg = encmsg[2:] + rencmsg
  863. rptmsg = proto.decrypt(tmsg)
  864. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)
  865. # write a test message
  866. ptmsg = b'this is a test message that should be a little in length'
  867. encmsg = proto.encrypt(ptmsg)
  868. writer.write(enclenfun(encmsg))
  869. writer.write(encmsg)
  870. # wait for the connection to arrive
  871. await ptsockevent.wait()
  872. ptreader, ptwriter = ptsock[0]
  873. # read the test message
  874. rptmsg = await ptreader.readexactly(len(ptmsg))
  875. self.assertEqual(rptmsg, ptmsg)
  876. # write a different message
  877. ptmsg = os.urandom(2843)
  878. encmsg = proto.encrypt(ptmsg)
  879. writer.write(enclenfun(encmsg))
  880. writer.write(encmsg)
  881. # read the test message
  882. rptmsg = await ptreader.readexactly(len(ptmsg))
  883. self.assertEqual(rptmsg, ptmsg)
  884. # now try the other way
  885. ptmsg = os.urandom(912)
  886. ptwriter.write(ptmsg)
  887. # find out how much we need to read
  888. encmsg = await reader.readexactly(2 + 16)
  889. tlen = declenfun(encmsg)
  890. # read the rest of the message
  891. rencmsg = await reader.readexactly(tlen - 16)
  892. tmsg = encmsg[2:] + rencmsg
  893. rptmsg = proto.decrypt(tmsg)
  894. self.assertEqual(rptmsg, ptmsg)
  895. # shut down sending
  896. writer.write_eof()
  897. # so pt reader should be shut down
  898. self.assertEqual(b'', await ptreader.read(1))
  899. self.assertTrue(ptreader.at_eof())
  900. # shut down pt
  901. ptwriter.write_eof()
  902. # make sure the enc reader is eof
  903. self.assertEqual(b'', await reader.read(1))
  904. self.assertTrue(reader.at_eof())
  905. await event.wait()
  906. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  907. writer.close()
  908. ptwriter.close()
  909. lsock.close()
  910. ssock.close()
  911. await lsock.wait_closed()
  912. await ssock.wait_closed()
  913. @async_test
  914. async def test_protocolversionmismatch(self):
  915. # make sure that if we send a future version, that we
  916. # still get a protocol version, and that the connection
  917. # is closed w/o establishing a connection to the remote
  918. # side
  919. # Test is plumbed:
  920. # (reader, writer) -> servsock ->
  921. # (rdr, wrr) NoiseForward (reader, writer) ->
  922. # servptsock -> (ptsock[0], ptsock[1])
  923. # Path that the server will sit on
  924. servsockpath = os.path.join(self.tempdir, 'servsock')
  925. servarg = _makeunix(servsockpath)
  926. # Path that the server will send pt data to
  927. servptpath = os.path.join(self.tempdir, 'servptsock')
  928. # Setup pt target listener
  929. pttarg = _makeunix(servptpath)
  930. ptsock = []
  931. ptsockevent = asyncio.Event()
  932. def ptsockaccept(reader, writer, ptsock=ptsock):
  933. ptsock.append((reader, writer))
  934. ptsockevent.set()
  935. # Bind to pt listener
  936. lsock = await listensockstr(pttarg, ptsockaccept)
  937. nfs = []
  938. event = asyncio.Event()
  939. async def runnf(rdr, wrr):
  940. ptpairfun = asyncio.create_task(connectsockstr(pttarg))
  941. try:
  942. connval = GenericConnValidator(
  943. [ self.client_key_pair[0] ],
  944. lambda: ptpairfun)
  945. a = await NoiseForwarder('resp',
  946. _makefut((rdr, wrr)), connval,
  947. priv_key=self.server_key_pair[1])
  948. except RuntimeError as e:
  949. nfs.append(e)
  950. event.set()
  951. return
  952. nfs.append(a)
  953. event.set()
  954. # Setup server listener
  955. ssock = await listensockstr(servarg, runnf)
  956. # Connect to server
  957. reader, writer = await connectsockstr(servarg)
  958. # Create client
  959. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  960. proto.set_as_initiator()
  961. # Setup required keys
  962. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  963. self.client_key_pair[1])
  964. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  965. self.server_key_pair[0])
  966. proto.start_handshake()
  967. # Send first message
  968. message = proto.write_message()
  969. self.assertEqual(len(message), _handshakelens[0])
  970. writer.write(message)
  971. # Get response
  972. respmsg = await reader.readexactly(_handshakelens[1])
  973. proto.read_message(respmsg)
  974. # Send final reply
  975. message = proto.write_message()
  976. writer.write(message)
  977. # Make sure handshake has completed
  978. self.assertTrue(proto.handshake_finished)
  979. # generate the keys for lengths
  980. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  981. b'toresp')
  982. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  983. b'toinit')
  984. pversion = 1
  985. # Send the protocol version string first
  986. encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
  987. writer.write(enclenfun(encmsg))
  988. writer.write(encmsg)
  989. # Read the peer's protocol version
  990. # find out how much we need to read
  991. encmsg = await reader.readexactly(2 + 16)
  992. tlen = declenfun(encmsg)
  993. # read the rest of the message
  994. rencmsg = await reader.readexactly(tlen - 16)
  995. tmsg = encmsg[2:] + rencmsg
  996. rptmsg = proto.decrypt(tmsg)
  997. self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)
  998. await event.wait()
  999. self.assertIsInstance(nfs[0], RuntimeError)
  1000. @async_test
  1001. async def test_serverclient(self):
  1002. # plumbing:
  1003. #
  1004. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  1005. #
  1006. ptcsockapair, ptcsockbpair = _asyncsockpair()
  1007. ptcareader, ptcawriter = await ptcsockapair
  1008. #ptcsockbpair passed directly
  1009. clssockapair, clssockbpair = _asyncsockpair()
  1010. #both passed directly
  1011. ptssockapair, ptssockbpair = _asyncsockpair()
  1012. #ptssockapair passed directly
  1013. ptsbreader, ptsbwriter = await ptssockbpair
  1014. validateclientside = GenericConnValidator(
  1015. [ self.server_key_pair[0] ], lambda: ptcsockbpair)
  1016. validateserverside = GenericConnValidator(
  1017. [ self.client_key_pair[0] ], lambda: ptssockapair)
  1018. clientnf = asyncio.create_task(NoiseForwarder('init',
  1019. clssockapair, validateclientside,
  1020. priv_key=self.client_key_pair[1],
  1021. pub_key=self.server_key_pair[0]))
  1022. servnf = asyncio.create_task(NoiseForwarder('resp',
  1023. clssockbpair, validateserverside,
  1024. priv_key=self.server_key_pair[1]))
  1025. # send a message
  1026. msga = os.urandom(183)
  1027. ptcawriter.write(msga)
  1028. # make sure we get the same message
  1029. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1030. # send a second message
  1031. msga = os.urandom(2834)
  1032. ptcawriter.write(msga)
  1033. # make sure we get the same message
  1034. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1035. # send a message larger than the block size
  1036. msga = os.urandom(103958)
  1037. ptcawriter.write(msga)
  1038. # make sure we get the same message
  1039. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1040. # send a message the other direction
  1041. msga = os.urandom(103958)
  1042. ptsbwriter.write(msga)
  1043. # make sure we get the same message
  1044. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  1045. # close down the pt writers, the rest should follow
  1046. ptsbwriter.write_eof()
  1047. ptcawriter.write_eof()
  1048. # make sure they are closed, and there is no more data
  1049. self.assertEqual(b'', await ptsbreader.read(1))
  1050. self.assertTrue(ptsbreader.at_eof())
  1051. self.assertEqual(b'', await ptcareader.read(1))
  1052. self.assertTrue(ptcareader.at_eof())
  1053. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  1054. self.assertEqual([ 'dec', 'enc' ], await servnf)
  1055. await ptsbwriter.drain()
  1056. await ptcawriter.drain()
  1057. ptsbwriter.close()
  1058. ptcawriter.close()