A stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

1401 lines
38 KiB

  1. from contextlib import asynccontextmanager
  2. from cryptography.hazmat.backends import default_backend
  3. from cryptography.hazmat.primitives import hashes
  4. from cryptography.hazmat.primitives import serialization
  5. from cryptography.hazmat.primitives.asymmetric import x448
  6. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  7. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  8. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  9. from noise.connection import NoiseConnection, Keypair
  10. #import tracemalloc; tracemalloc.start(100)
  11. import argparse
  12. import asyncio
  13. import base64
  14. import os.path
  15. import shutil
  16. import socket
  17. import sys
  18. import tempfile
  19. import time
  20. import threading
  21. import unittest
  # Shared cryptography backend object, passed explicitly to the
  # primitives used in this file.
  _backend = default_backend()

  def loadprivkey(fname):
      '''Load the unencrypted PEM private key stored in fname and return
      it as a cryptography private key object.

      The file is read as ASCII text (PEM is ASCII).  Because password
      is None, an encrypted key is rejected by load_pem_private_key.'''
      with open(fname, encoding='ascii') as fp:
          data = fp.read().encode('ascii')

      # Reuse the module-level backend instead of creating another one.
      return load_pem_private_key(data, password=None, backend=_backend)
  def loadprivkeyraw(fname):
      '''Return the raw private key bytes (the unencoded form Noise
      expects, as also produced by genkeypair) of the PEM key in fname.'''
      raw = serialization.Encoding.Raw
      return loadprivkey(fname).private_bytes(
          encoding=raw,
          format=serialization.PrivateFormat.Raw,
          encryption_algorithm=serialization.NoEncryption(),
      )
  def loadpubkeyraw(fname):
      '''Load a public key file written by the genkey subcommand and
      return the raw key bytes.

      The key line has the form "ntun-x448 <urlsafe base64 key>".  Blank
      lines are skipped and only the first key line is used.

      Raises RuntimeError if the file contains no key line, if the key
      line does not have exactly two fields, or if the key type is not
      ntun-x448.  A malformed base64 value raises binascii.Error (a
      ValueError subclass).'''
      with open(fname, encoding='ascii') as fp:
          keylines = [line for line in fp if line.strip()]

      if not keylines:
          raise RuntimeError('no public key found in %s' % repr(fname))

      fields = keylines[0].split()
      if len(fields) != 2:
          raise RuntimeError('malformed public key line in %s' % repr(fname))

      keytype, keyvalue = fields
      if keytype != 'ntun-x448':
          raise RuntimeError('unsupported key type %s in %s' %
              (repr(keytype), repr(fname)))

      return base64.urlsafe_b64decode(keyvalue)
  class ConnectionValidator:
      '''Interface used by NoiseForwarder to validate the remote key of a
      session, and to obtain the connection the session is forwarded to.

      Implementations override both coroutine methods; the defaults here
      only raise NotImplementedError.'''

      async def validatekey(self, hash, key):
          '''Check that key may connect.  The session hash is supplied too,
          so the key's authorization can be re-checked later.  Must raise
          an exception if the key is not authorized.'''
          raise NotImplementedError

      async def getconnection(self, hash, key, **kwargs):
          '''Return the connection (a reader, writer pair) this client's
          session should be forwarded to.'''
          raise NotImplementedError
  class GenericConnValidator(object):
      '''A simple implementation of the ConnectionValidator interface that
      can be used w/ most cases.  It checks the key against a collection,
      and then calls the function provided, awaits its result and returns
      it.'''

      def __init__(self, keys, connfun):
          '''The parameter keys must be an object that supports the in
          operator (__contains__).  If the key is in the keys object, the
          connection will proceed.

          The parameter connfun must be a callable taking no arguments
          that returns an awaitable (e.g. an async function, or a function
          returning a future).  The awaited value is the (StreamReader,
          StreamWriter) pair of the connection that the session is
          supposed to use.'''
          self._keys = keys
          self._connfun = connfun

      async def validatekey(self, hash, key):
          '''Raise ValueError unless key is in the keys collection.'''
          if key not in self._keys:
              raise ValueError('key not authorized: %s' % repr(key))

      async def getconnection(self, hash, key, **kwargs):
          '''Return the result of awaiting connfun().

          Extra keyword arguments are accepted so this matches the
          ConnectionValidator.getconnection interface; they are currently
          ignored.'''
          return await self._connfun()
  def genkeypair():
      '''Create a fresh X448 keypair.

      Returns a (public, private) tuple of raw, unencoded byte strings,
      suitable for use w/ Noise.'''
      raw = serialization.Encoding.Raw
      privkey = x448.X448PrivateKey.generate()

      public = privkey.public_key().public_bytes(
          encoding=raw, format=serialization.PublicFormat.Raw)
      private = privkey.private_bytes(
          encoding=raw,
          format=serialization.PrivateFormat.Raw,
          encryption_algorithm=serialization.NoEncryption())

      return public, private
  84. def _makefut(obj):
  85. loop = asyncio.get_running_loop()
  86. fut = loop.create_future()
  87. fut.set_result(obj)
  88. return fut
  89. def _makeunix(path):
  90. '''Make a properly formed unix path socket string.'''
  91. return 'unix:%s' % path
  # Make sure any additions are reflected by tests in test_parsesockstr
  _allowedparameters = {
      'unix': {
          'path': str,
      },
      'tcp': {
          'host': str,
          'port': int,
      },
  }

  def parsesockstr(sockstr):
      '''Parse a socket string to its parts.

      The format of sockstr is: 'proto:param=value[,param2=value2]'.
      If the proto has a default parameter, the value can be used
      directly, like: 'proto:value'.  This is only allowed when the
      value can unambiguously be determined not to be a param.  If
      there needs to be an equals '=', then you MUST use the extended
      version.

      The characters that define 'param' must be all lower case ascii
      characters and may contain an underscore.  The first character
      must not be an underscore.

      Supported protocols:
          unix:
              Default parameter is path.
              The path parameter specifies the path to the
              unix domain socket.  The path MUST start w/ a
              slash if it is used as a default parameter.

          tcp:
              Default parameter is host[:port].
              The host parameter specifies the host, and the
              port parameter specifies the port of the
              connection.  A bare IPv6 address contains colons and
              is ambiguous, so it must use the host= form.

      Returns a (proto, args) tuple; args maps each parameter name to
      its value converted with the type in _allowedparameters.

      Raises ValueError for a missing ':' separator, an unsupported
      proto, a malformed parameter, a parameter not allowed for the
      proto, or a value that fails conversion (e.g. a non-integer port).
      '''
      proto, sep, rem = sockstr.partition(':')
      if not sep:
          raise ValueError('missing proto separator (:) in %s' %
              repr(sockstr))

      try:
          allowed = _allowedparameters[proto]
      except KeyError:
          raise ValueError('unsupported proto: %s' % repr(proto))

      if '=' not in rem:
          # default parameter form
          if proto == 'unix':
              if not rem.startswith('/'):
                  raise ValueError('bare path MUST start w/ a slash (/).')
              args = { 'path': rem }
          elif proto == 'tcp':
              # host[:port]; more than one colon can't be split reliably
              if rem.count(':') > 1:
                  raise ValueError('ambiguous host[:port] %s, use the '
                      'host=...,port=... form' % repr(rem))
              host, colon, port = rem.partition(':')
              if not host:
                  raise ValueError('missing host in %s' % repr(sockstr))
              args = { 'host': host }
              if colon:
                  args['port'] = port
          else: # pragma: no cover
              raise ValueError('proto %s has no default parameter' %
                  repr(proto))
      else:
          args = {}
          for item in rem.split(','):
              key, eq, value = item.partition('=')
              if not eq:
                  raise ValueError('parameter without a value in %s: %s' %
                      (repr(sockstr), repr(item)))
              args[key] = value

      extrakeys = args.keys() - allowed.keys()
      if extrakeys:
          raise ValueError('keys for proto %s not allowed: %s' %
              (repr(proto), extrakeys))

      return proto, { k: allowed[k](v) for k, v in args.items() }
  async def connectsockstr(sockstr):
      '''Open a client connection to the endpoint described by sockstr
      (see parsesockstr for the format), using asyncio.open_connection
      for tcp and asyncio.open_unix_connection for unix.

      Returns the (reader, writer) stream pair.'''
      proto, params = parsesockstr(sockstr)

      if proto == 'tcp':
          opener = asyncio.open_connection
      elif proto == 'unix':
          opener = asyncio.open_unix_connection

      rdr, wrr = await opener(**params)
      return rdr, wrr
  async def listensockstr(sockstr, cb):
      '''Start a server listening on the endpoint described by sockstr
      (see parsesockstr for the format) and return the asyncio server.

      tcp endpoints use asyncio.start_server, unix endpoints use
      asyncio.start_unix_server.  cb is handed to those functions
      unchanged: it is called (or scheduled as a task, if it is a
      coroutine function) for each accepted client with the reader and
      writer streams.  See:
      https://docs.python.org/3/library/asyncio-stream.html#asyncio.start_server
      '''
      proto, params = parsesockstr(sockstr)

      if proto == 'tcp':
          starter = asyncio.start_server
      elif proto == 'unix':
          starter = asyncio.start_unix_server

      server = await starter(cb, **params)
      return server
  167. # !!python makemessagelengths.py
  168. _handshakelens = \
  169. [72, 72, 88]
  def _genciphfun(hash, ad):
      '''Return an (encfun, decfun) pair used to obscure frame lengths.

      A 32-byte AES key is derived from hash with HKDF-SHA256 (fixed
      salt, ad as the info), so each direction gets its own key.  The
      mask for a frame comes from AES-ECB encryption of the first 16
      bytes of that frame's ciphertext.

      encfun(data): data is the ciphertext about to be sent; returns the
      2 big-endian length bytes to send in front of it.

      decfun(data): data is those 2 length bytes followed by (at least)
      the first 16 ciphertext bytes; returns the ciphertext length.

      Only the low byte of the length is masked (mask & 0xff), so the
      high byte is sent in the clear.  That is the existing wire format;
      changing it would break compatibility with existing peers.

      Both functions raise ValueError if fewer than 16 ciphertext bytes
      are supplied.  encfun raises OverflowError if data is longer than
      65535 bytes.'''
      hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
          salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
      key = hkdf.derive(hash)
      cipher = Cipher(algorithms.AES(key), modes.ECB(),
          backend=_backend)
      enctor = cipher.encryptor()

      def _mask(block):
          # ECB only emits whole 16-byte blocks and buffers the rest inside
          # enctor, which would shift every later mask.  Refuse short input
          # instead of silently corrupting that state.
          if len(block) < 16:
              raise ValueError('need at least 16 bytes of ciphertext, '
                  'got %d' % len(block))
          encbytes = enctor.update(block[:16])
          return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

      def encfun(data):
          # Returns the two bytes for length
          return (len(data) ^ _mask(data)).to_bytes(length=2,
              byteorder='big')

      def decfun(data):
          # takes off the data and returns the total
          # length
          val = int.from_bytes(data[:2], byteorder='big')
          return val ^ _mask(data[2:2 + 16])

      return encfun, decfun
  async def NoiseForwarder(mode, encrdrwrr, connvalid, priv_key, pub_key=None):
      '''A function that forwards data between the plain text pair of
      streams to the encrypted session.

      The mode parameter must be one of 'init' or 'resp' for initiator
      and responder; any other value raises ValueError.

      The encrdrwrr parameter is an awaitable that returns a tuple of
      the reader and writer streams for the encrypted side of the
      connection.

      The connvalid parameter is an instance of the ConnectionValidator
      class, or one that implements its methods.  The validatekey method
      will be passed the session hash, and the remote public key of the
      client or server.  If the key is not authorized, an exception
      must be raised.  Any non-exception return from the function means
      that the key is authorized, and that the session should continue.

      In the case of the initiator, the server's key will be passed,
      despite the fact that it was already validated by the XK Noise
      Protocol.  This is just to keep the calling convention the same,
      and it supports moving to an XX protocol possibly in the future
      with minimal changes.

      Then the getconnection method will be called.  Its expected
      return is the (reader, writer) pair of the connection to forward
      the data on to.  The kwargs may be used in future protocol
      versions to allow the client to request a specific resource, or
      something similar.

      In the case of the initiator, pub_key must be provided and will
      be used to authenticate the responder side of the connection.

      The priv_key parameter is used to authenticate this side of the
      session.

      Both priv_key and pub_key parameters must be 56 bytes.  For
      example, the pair that is returned by genkeypair.

      If anything fails before forwarding starts (handshake, key
      validation, protocol version negotiation or getconnection), the
      encrypted writer is closed and the exception is re-raised.

      Once both directions reach EOF, returns ['dec', 'enc'].  Both
      writers are closed when forwarding ends, including when it fails.
      '''
      # Local import: only needed for errno.ENOTCONN below.
      import errno

      # Send a protocol version so that in the future we can change how
      # we interface, and possibly be able to send control messages,
      # allow the client to pass some misc data to the callback, or to
      # allow a reverse tunnel, were the client talks to the server,
      # and waits for the server to "connect" to the client w/ a
      # connection, e.g. reverse tunnel out behind a nat to allow
      # incoming connections.
      protocol_version = 0

      rdr, wrr = await encrdrwrr

      try:
          if mode not in ('init', 'resp'):
              raise ValueError('mode must be init or resp, not %s' %
                  repr(mode))

          proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
          proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
          if pub_key is not None:
              proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                  pub_key)

          # Three handshake messages: init -> resp, resp -> init,
          # init -> resp.
          if mode == 'resp':
              proto.set_as_responder()
              proto.start_handshake()
              proto.read_message(await rdr.readexactly(_handshakelens[0]))
              wrr.write(proto.write_message())
              proto.read_message(await rdr.readexactly(_handshakelens[2]))
          else:
              proto.set_as_initiator()
              proto.start_handshake()
              wrr.write(proto.write_message())
              proto.read_message(await rdr.readexactly(_handshakelens[1]))
              wrr.write(proto.write_message())

          if not proto.handshake_finished: # pragma: no cover
              raise RuntimeError('failed to finish handshake')

          sesshash = proto.get_handshake_hash()
          clientkey = getattr(proto.get_keypair(Keypair.REMOTE_STATIC),
              'public_bytes', None)

          await connvalid.validatekey(sesshash, clientkey)

          # generate the keys for lengths
          # XXX - get_handshake_hash is probably not the best option, but
          # this is only to obscure lengths, it is not required to be secure
          # as the underlying NoiseProtocol securely validates everything.
          # It is marginally useful as writing patterns likely expose the
          # true length.  Adding padding could marginally help w/ this.
          if mode == 'resp':
              sendinfo, recvinfo = b'toinit', b'toresp'
          else:
              sendinfo, recvinfo = b'toresp', b'toinit'
          enclenfun = _genciphfun(sesshash, sendinfo)[0]
          declenfun = _genciphfun(sesshash, recvinfo)[1]

          def sendframe(data):
              # One frame on the wire: 2 obscured length bytes, then the
              # Noise ciphertext of data.  Callers drain wrr themselves.
              encmsg = proto.encrypt(data)
              wrr.write(enclenfun(encmsg))
              wrr.write(encmsg)

          async def recvbody(hdr):
              # hdr holds the 2 length bytes plus the first 16 ciphertext
              # bytes; read the rest of the frame and decrypt it.
              tlen = declenfun(hdr)
              rest = await rdr.readexactly(tlen - 16)
              return proto.decrypt(hdr[2:] + rest)

          # protocol negotiation
          # send first, then wait for the response
          sendframe(protocol_version.to_bytes(1, byteorder='big'))

          # get the protocol version
          rpv = await recvbody(await rdr.readexactly(2 + 16))
          rempv = int.from_bytes(rpv, byteorder='big')
          if rempv != protocol_version:
              raise RuntimeError('unsupported protocol version received: %d' %
                  rempv)

          reader, writer = await connvalid.getconnection(sesshash, clientkey)
      except BaseException:
          # Don't leak the encrypted connection when setup fails.
          wrr.close()
          raise

      def sendeof(stream):
          # Half-close stream.  If the peer is already gone, shutdown()
          # fails with ENOTCONN (57 on BSD/macOS, 107 on Linux), which is
          # expected here and ignored.
          try:
              stream.write_eof()
          except OSError as e:
              if e.errno != errno.ENOTCONN:
                  raise

      async def decses():
          # encrypted side -> plain text side
          try:
              while True:
                  try:
                      hdr = await rdr.readexactly(2 + 16)
                  except asyncio.IncompleteReadError:
                      # readexactly only raises this after reaching EOF
                      # (its buffer is then empty): the peer is done.
                      return 'dec'

                  writer.write(await recvbody(hdr))
                  await writer.drain()
          finally:
              sendeof(writer)

      async def encses():
          # plain text side -> encrypted side
          try:
              while True:
                  # largest message: the ciphertext is 16 bytes longer and
                  # its length must fit in the 2 length bytes
                  ptmsg = await reader.read(65535 - 16)
                  if not ptmsg:
                      # eof
                      return 'enc'

                  sendframe(ptmsg)
                  await wrr.drain()
          finally:
              sendeof(wrr)

      try:
          res = await asyncio.gather(decses(), encses())
          await wrr.drain() # not sure if needed
          await writer.drain() # not sure if needed
      finally:
          # Close both sides even if forwarding failed; this also ends a
          # direction that is still running after the other one raised.
          wrr.close()
          writer.close()

      return res
  # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  # Slightly modified to timeout and to print trace back when canceled.
  # This makes it easier to figure out what "froze".
  def async_test(f):
      '''Decorator that runs the coroutine function f to completion on a
      fresh event loop, so it can be used as a unittest test method.

      The run is limited to 4 seconds.  On timeout the traceback of the
      cancelled coroutine is printed, and asyncio.TimeoutError propagates.
      Exceptions raised by f propagate unchanged.  The coroutine's return
      value is discarded, since unittest test methods should return None.'''
      import functools
      import traceback

      @functools.wraps(f)
      def wrapper(*args, **kwargs):
          async def tbcapture():
              try:
                  return await f(*args, **kwargs)
              except asyncio.CancelledError:
                  # if we are going to be cancelled, print out a tb
                  traceback.print_exc()
                  raise

          # asyncio.run creates and closes a new loop per call, instead of
          # asyncio.get_event_loop(), which is deprecated when no loop is
          # running.
          # timeout after 4 seconds
          asyncio.run(asyncio.wait_for(tbcapture(), 4))

      return wrapper
  class Tests_misc(unittest.TestCase):
      '''Tests for the socket string helpers and _genciphfun.'''

      def setUp(self):
          # Remember the current directory so tearDown can restore it.  An
          # earlier test may have left us in a removed directory, in which
          # case fall back to the system temp directory.
          try:
              self.origcwd = os.getcwd()
          except FileNotFoundError:
              self.origcwd = tempfile.gettempdir()

          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          os.chdir(self.tempdir)

      def tearDown(self):
          #print('td:', time.time())
          # Leave the directory before deleting it, so later tests do not
          # run with a removed working directory.
          os.chdir(self.origcwd)
          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      def test_parsesockstr_bad(self):
          badstrs = [
              'unix:ff',
              'randomnocolon',
              'unix:somethingelse=bogus',
              'tcp:port=bogus',
          ]

          for i in badstrs:
              with self.assertRaises(ValueError,
                  msg='Should have failed processing: %s' % repr(i)):
                  parsesockstr(i)

      def test_parsesockstr(self):
          results = {
              # Not all of these are valid when passed to a *sockstr
              # function
              'unix:/apath': ('unix', { 'path': '/apath' }),
              'unix:path=apath': ('unix', { 'path': 'apath' }),
              'tcp:host=apath': ('tcp', { 'host': 'apath' }),
              'tcp:host=apath,port=5': ('tcp', { 'host': 'apath',
                  'port': 5 }),
          }

          for s, r in results.items():
              self.assertEqual(parsesockstr(s), r)

      @async_test
      async def test_listensockstr_bad(self):
          with self.assertRaises(ValueError):
              await listensockstr('bogus:some=arg', None)
          with self.assertRaises(ValueError):
              await connectsockstr('bogus:some=arg')

      @async_test
      async def test_listenconnectsockstr(self):
          msgsent = b'this is a test message'
          msgrcv = b'testing message for receive'

          # That when a connection is received and receives and sends
          async def servconfhandle(rdr, wrr):
              msg = await rdr.readexactly(len(msgsent))
              self.assertEqual(msg, msgsent)
              #print(repr(wrr.get_extra_info('sockname')))
              wrr.write(msgrcv)
              await wrr.drain()
              wrr.close()
              return True

          # Test listensockstr
          for sstr, confun in [
              ('unix:path=ff', lambda: asyncio.open_unix_connection(path='ff')),
              ('tcp:port=9384', lambda: asyncio.open_connection(port=9384))
              ]:
              # that listensockstr will bind to the correct path, can call cb
              ls = await listensockstr(sstr, servconfhandle)

              # that we open a connection to the path
              rdr, wrr = await confun()

              # and send a message
              wrr.write(msgsent)

              # and receive the message
              rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
              self.assertEqual(rcv, msgrcv)

              wrr.close()

              # Now test that connectsockstr works similarly.
              rdr, wrr = await connectsockstr(sstr)

              # and send a message
              wrr.write(msgsent)

              # and receive the message
              rcv = await asyncio.wait_for(rdr.readexactly(len(msgrcv)), .5)
              self.assertEqual(rcv, msgrcv)

              wrr.close()

              # close this listener before trying the next sockstr
              ls.close()
              await ls.wait_closed()

      def test_genciphfun(self):
          enc, dec = _genciphfun(b'0' * 32, b'foobar')

          msg = b'this is a bunch of data'

          tb = enc(msg)

          self.assertEqual(len(msg), dec(tb + msg))

          for i in [ 20, 1384, 64000, 23839, 65535 ]:
              msg = os.urandom(i)
              self.assertEqual(len(msg), dec(enc(msg) + msg))
  def cmd_client(args):
      '''Run the client subcommand.

      Listens on args.clientlisten.  For each accepted plain text
      connection it opens args.clienttarget and runs NoiseForwarder as
      the initiator, authenticating the server with the public key in
      args.servkey and itself with the private key in args.clientkey.
      Serves until the process is stopped.'''
      privkey = loadprivkeyraw(args.clientkey)
      pubkey = loadpubkeyraw(args.servkey)

      async def runnf(rdr, wrr):
          # The accepted plain text streams are the connection to forward to.
          connval = GenericConnValidator([ pubkey ],
              lambda: _makefut((rdr, wrr)))

          encpair = asyncio.create_task(connectsockstr(args.clienttarget))

          await NoiseForwarder('init', encpair, connval,
              priv_key=privkey, pub_key=pubkey)

      async def serve():
          # Setup client listener
          server = await listensockstr(args.clientlisten, runnf)
          await server.serve_forever()

      # asyncio.run owns the event loop; calling asyncio.get_event_loop()
      # with no running loop is deprecated.
      asyncio.run(serve())
  def cmd_server(args):
      '''Run the server subcommand.

      Listens on args.servlisten and runs NoiseForwarder as the responder
      for each accepted connection, using the private key in
      args.servkey.  Only clients whose public key is in one of the
      args.clientkey files are accepted; each authorized session is
      forwarded to a new connection to args.servtarget.  Serves until the
      process is stopped.'''
      privkey = loadprivkeyraw(args.servkey)
      pubkeys = [ loadpubkeyraw(x) for x in args.clientkey ]

      async def runnf(rdr, wrr):
          connval = GenericConnValidator(pubkeys,
              lambda: connectsockstr(args.servtarget))

          await NoiseForwarder('resp', _makefut((rdr, wrr)),
              connval, priv_key=privkey)

      async def serve():
          # Setup server listener
          server = await listensockstr(args.servlisten, runnf)
          await server.serve_forever()

      # asyncio.run owns the event loop; calling asyncio.get_event_loop()
      # with no running loop is deprecated.
      asyncio.run(serve())
  def cmd_genkey(args):
      '''Run the genkey subcommand: generate a new X448 keypair.

      The public key is written to args.fname + '.pub' as the single
      line 'ntun-x448 <urlsafe base64 of the raw key>'.  The private key
      is written PEM/PKCS8 encoded, unencrypted, to args.fname.

      Neither file is overwritten.  If either already exists, an error is
      printed to stderr and the process exits with status 2.  The private
      key file is created with mode 0600 (before umask), so other users
      cannot read it.'''
      key = x448.X448PrivateKey.generate()

      # public key part
      pub = key.public_key().public_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PublicFormat.Raw)

      # private key part
      privpem = key.private_bytes(
          encoding=serialization.Encoding.PEM,
          format=serialization.PrivateFormat.PKCS8,
          encryption_algorithm=serialization.NoEncryption())

      pubfname = args.fname + '.pub'
      try:
          with open(pubfname, 'x', encoding='ascii') as fp:
              print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
      except FileExistsError:
          print('failed to create %s, file exists.' % pubfname, file=sys.stderr)
          sys.exit(2)

      # O_EXCL keeps the no-overwrite behavior of open(..., 'x'), and mode
      # 0o600 keeps the unencrypted private key private.
      try:
          fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
      except FileExistsError:
          # Don't leave a .pub behind that doesn't match the existing key.
          os.unlink(pubfname)
          print('failed to create %s, file exists.' % args.fname, file=sys.stderr)
          sys.exit(2)

      with os.fdopen(fd, 'w', encoding='ascii') as fp:
          fp.write(privpem.decode('ascii'))
  def main():
      '''Command line entry point.  Parse the arguments and run the
      selected subcommand (genkey, server or client).

      Prints usage and exits with status 5 when no subcommand is given.'''
      parser = argparse.ArgumentParser()

      subparsers = parser.add_subparsers(title='subcommands', description='valid subcommands', help='additional help')

      parser_gk = subparsers.add_parser('genkey', help='generate keys')
      parser_gk.add_argument('fname', type=str, help='file name for the key')
      parser_gk.set_defaults(func=cmd_genkey)

      parser_serv = subparsers.add_parser('server', help='run a server')
      # required: cmd_server iterates over this list, and with no authorized
      # client keys no client could ever connect.
      parser_serv.add_argument('--clientkey', '-c', action='append', type=str, required=True, help='file of authorized client keys, or a .pub file')
      parser_serv.add_argument('servkey', type=str, help='file name for the server key')
      parser_serv.add_argument('servlisten', type=str, help='Connection that the server listens on')
      parser_serv.add_argument('servtarget', type=str, help='Connection that the server connects to')
      parser_serv.set_defaults(func=cmd_server)

      parser_client = subparsers.add_parser('client', help='run a client')
      parser_client.add_argument('clientkey', type=str, help='file name for the client private key')
      parser_client.add_argument('servkey', type=str, help='file name for the server public key')
      parser_client.add_argument('clientlisten', type=str, help='Connection that the client listens on')
      parser_client.add_argument('clienttarget', type=str, help='Connection that the client connects to')
      parser_client.set_defaults(func=cmd_client)

      args = parser.parse_args()

      try:
          fun = args.func
      except AttributeError:
          parser.print_usage()
          sys.exit(5)

      fun(args)

  if __name__ == '__main__': # pragma: no cover
      # NOTE: this guard sits before the test helpers and test classes
      # further down the file; main() and everything it dispatches to are
      # defined above this point.
      main()
  509. def _asyncsockpair():
  510. '''Create a pair of sockets that are bound to each other.
  511. The function will return a tuple of two coroutine's, that
  512. each, when await'ed upon, will return the reader/writer pair.'''
  513. socka, sockb = socket.socketpair()
  514. return asyncio.open_connection(sock=socka), \
  515. asyncio.open_connection(sock=sockb)
  516. async def _awaitfile(fname):
  517. while not os.path.exists(fname):
  518. await asyncio.sleep(.01)
  519. return True
  class TestMain(unittest.TestCase):
      '''Tests that run this file as a subprocess (genkey, server and
      client subcommands).'''

      def setUp(self):
          # Remember the current directory so tearDown can restore it.  An
          # earlier test may have left us in a removed directory, in which
          # case fall back to the system temp directory.
          try:
              self.origcwd = os.getcwd()
          except FileNotFoundError:
              self.origcwd = tempfile.gettempdir()

          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

          os.chdir(self.tempdir)

      def tearDown(self):
          #print('td:', time.time())
          # Leave the directory before deleting it, so later tests do not
          # run with a removed working directory.
          os.chdir(self.origcwd)
          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      @asynccontextmanager
      async def run_with_args(self, *args, pipes=True):
          '''Run this file in a subprocess with args and yield the process.

          When pipes is true, stdout and stderr are captured.  On exit the
          process is terminated if it is still running, and always waited
          for.'''
          kwargs = {}
          if pipes:
              kwargs.update(dict(
                  stdout=asyncio.subprocess.PIPE,
                  stderr=asyncio.subprocess.PIPE))

          # Spawn outside the try: if spawning fails there is no process to
          # clean up, and the finally clause would otherwise hide the real
          # error behind an UnboundLocalError for proc.
          proc = await asyncio.create_subprocess_exec(sys.executable,
              # XXX - figure out how to add coverage data on these runs
              #'-m', 'coverage', 'run', '-p',
              __file__, *args, **kwargs)

          try:
              yield proc
          finally:
              if proc.returncode is None:
                  proc.terminate()

              # Make sure that process exits before continuing
              await proc.wait()

      @async_test
      async def test_noargs(self):
          async with self.run_with_args() as proc:
              # communicate() reads the pipes while waiting, so the child
              # cannot block on a full pipe.
              await proc.communicate()

              # XXX - not checking error message

              # And that it exited w/ the correct code
              self.assertEqual(proc.returncode, 5)

      async def genkey(self, name):
          '''Run the genkey subcommand for name and check it succeeded.'''
          async with self.run_with_args('genkey', name, pipes=False) as proc:
              await proc.wait()
              self.assertEqual(proc.returncode, 0)

      @async_test
      async def test_loadpubkey(self):
          keypath = os.path.join(self.tempdir, 'loadpubkeytest')
          await self.genkey(keypath)

          privkey = loadprivkey(keypath)

          enc = serialization.Encoding.Raw
          pubformat = serialization.PublicFormat.Raw
          pubkeybytes = privkey.public_key().public_bytes(encoding=enc,
              format=pubformat)

          pubkey = loadpubkeyraw(keypath + '.pub')

          self.assertEqual(pubkeybytes, pubkey)

          privrawkey = loadprivkeyraw(keypath)

          enc = serialization.Encoding.Raw
          privformat = serialization.PrivateFormat.Raw
          encalgo = serialization.NoEncryption()
          rprivrawkey = privkey.private_bytes(encoding=enc,
              format=privformat, encryption_algorithm=encalgo)

          self.assertEqual(rprivrawkey, privrawkey)

      @async_test
      async def test_clientkeymismatch(self):
          # make sure that if there's a client key mismatch, we
          # don't connect

          # Generate necessary keys
          servkeypath = os.path.join(self.tempdir, 'server_key')
          await self.genkey(servkeypath)
          clientkeypath = os.path.join(self.tempdir, 'client_key')
          await self.genkey(clientkeypath)
          badclientkeypath = os.path.join(self.tempdir, 'badclient_key')
          await self.genkey(badclientkeypath)

          # forwards connectsion to this socket (created by client)
          ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
          ptclientstr = _makeunix(ptclientpath)

          # this is the socket server listen to
          incservpath = os.path.join(self.tempdir, 'incserv.sock')
          incservstr = _makeunix(incservpath)

          # to this socket, opened by server
          servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
          servtargstr = _makeunix(servtargpath)

          # Setup server target listener.  Any connection the server makes
          # to the target sets ptsockevent, which must not happen here.
          # (With no callback, such a connection would go unnoticed.)
          ptsockevent = asyncio.Event()

          def ptsockaccept(reader, writer):
              ptsockevent.set()
              writer.close()

          # Bind to pt listener
          lsock = await listensockstr(servtargstr, ptsockaccept)

          # Startup the server
          wserver = self.run_with_args('server',
              '-c', clientkeypath + '.pub',
              servkeypath, incservstr, servtargstr)

          # Startup the client with the "bad" key
          wclient = self.run_with_args('client', badclientkeypath,
              servkeypath + '.pub', ptclientstr, incservstr)

          async with wserver as server, wclient as client:
              # wait for server target to be created
              await _awaitfile(servtargpath)

              # wait for server to start
              await _awaitfile(incservpath)

              # wait for client to start
              await _awaitfile(ptclientpath)

              # Connect to the client
              reader, writer = await connectsockstr(ptclientstr)

              # XXX - this might not be the best test.
              with self.assertRaises(asyncio.TimeoutError):
                  # make sure that we don't get the connection
                  await asyncio.wait_for(ptsockevent.wait(), .5)

              writer.close()

              # Make sure that when the server is terminated
              server.terminate()

              # that it's stderr
              stdout, stderr = await server.communicate()
              #print('s:', repr((stdout, stderr)))

              # doesn't have an exceptions never retrieved
              # even the example echo server has this same leak
              #self.assertNotIn(b'Task exception was never retrieved', stderr)

              lsock.close()
              await lsock.wait_closed()

              # Kill off the client
              client.terminate()
              stdout, stderr = await client.communicate()
              #print('s:', repr((stdout, stderr)))
              # XXX - figure out how to clean up client properly

      @async_test
      async def test_end2end(self):
          # Generate necessary keys
          servkeypath = os.path.join(self.tempdir, 'server_key')
          await self.genkey(servkeypath)
          clientkeypath = os.path.join(self.tempdir, 'client_key')
          await self.genkey(clientkeypath)

          # forwards connectsion to this socket (created by client)
          ptclientpath = os.path.join(self.tempdir, 'incclient.sock')
          ptclientstr = _makeunix(ptclientpath)

          # this is the socket server listen to
          incservpath = os.path.join(self.tempdir, 'incserv.sock')
          incservstr = _makeunix(incservpath)

          # to this socket, opened by server
          servtargpath = os.path.join(self.tempdir, 'servtarget.sock')
          servtargstr = _makeunix(servtargpath)

          # Setup server target listener
          ptsock = []
          ptsockevent = asyncio.Event()

          def ptsockaccept(reader, writer, ptsock=ptsock):
              ptsock.append((reader, writer))
              ptsockevent.set()

          # Bind to pt listener
          lsock = await listensockstr(servtargstr, ptsockaccept)

          # Startup the server
          wserver = self.run_with_args('server',
              '-c', clientkeypath + '.pub',
              servkeypath, incservstr, servtargstr,
              pipes=False)

          # Startup the client
          wclient = self.run_with_args('client',
              clientkeypath, servkeypath + '.pub', ptclientstr,
              incservstr, pipes=False)

          async with wserver as server, wclient as client:
              # wait for server target to be created
              await _awaitfile(servtargpath)

              # wait for server to start
              await _awaitfile(incservpath)

              # wait for client to start
              await _awaitfile(ptclientpath)

              # Connect to the client
              reader, writer = await connectsockstr(ptclientstr)

              # send a message
              ptmsg = b'this is a message for testing'
              writer.write(ptmsg)

              # make sure that we got the connection
              await ptsockevent.wait()

              # get the connection
              endrdr, endwrr = ptsock[0]

              # make sure we can read back what we sent
              self.assertEqual(ptmsg,
                  await endrdr.readexactly(len(ptmsg)))

              # test some additional messages
              for i in [ 129, 1287, 28792, 129872 ]:
                  # in on direction
                  msg = os.urandom(i)
                  writer.write(msg)
                  self.assertEqual(msg,
                      await endrdr.readexactly(len(msg)))

                  # and the other
                  endwrr.write(msg)
                  self.assertEqual(msg,
                      await reader.readexactly(len(msg)))

              writer.close()
              endwrr.close()

              lsock.close()
              await lsock.wait_closed()

              # XXX - more testing that things exited properly

      @async_test
      async def test_genkey(self):
          # that it can generate a key
          async with self.run_with_args('genkey', 'somefile') as proc:
              # communicate() waits for exit while draining the pipes
              await proc.communicate()
              self.assertEqual(proc.returncode, 0)

          with open('somefile.pub', encoding='ascii') as fp:
              lines = fp.readlines()
              self.assertEqual(len(lines), 1)

          keytype, keyvalue = lines[0].split()

          self.assertEqual(keytype, 'ntun-x448')
          key = x448.X448PublicKey.from_public_bytes(
              base64.urlsafe_b64decode(keyvalue))

          key = loadprivkey('somefile')
          self.assertIsInstance(key, x448.X448PrivateKey)

          # that a second call fails
          async with self.run_with_args('genkey', 'somefile') as proc:
              # communicate() waits for exit; waiting first and reading the
              # pipes afterwards can deadlock if the child fills a pipe.
              stdoutdata, stderrdata = await proc.communicate()

              self.assertFalse(stdoutdata)
              self.assertEqual(
                  b'failed to create somefile.pub, file exists.\n',
                  stderrdata)

              # And that it exited w/ the correct code
              self.assertEqual(proc.returncode, 2)
  737. class TestNoiseFowarder(unittest.TestCase):
  def setUp(self):
      '''Create a scratch directory tree and fresh server/client
      keypairs, each a (public, private) tuple from genkeypair().'''
      basedir = os.path.realpath(tempfile.mkdtemp())
      self.basetempdir = basedir
      self.tempdir = os.path.join(basedir, 'subdir')
      os.mkdir(self.tempdir)

      # server first, then client, as (public, private) tuples
      self.server_key_pair, self.client_key_pair = genkeypair(), genkeypair()
  def tearDown(self):
      '''Delete the scratch directory tree created by setUp.'''
      scratch = self.basetempdir
      shutil.rmtree(scratch)
      self.tempdir = None
  @async_test
  async def test_clientkeymissmatch(self):
      '''A client whose static key is not in the responder's allowed
      list can finish the Noise handshake, but the responder's
      NoiseForwarder then raises ValueError from
      GenericConnValidator.validatekey.'''
      # generate a key that is incorrect
      wrongclient_key_pair = genkeypair()

      # the secure socket
      clssockapair, clssockbpair = _asyncsockpair()
      reader, writer = await clssockapair

      try:
          # create the server; no keys are authorized, and connfun is never
          # reached because validatekey raises first
          servnf = asyncio.create_task(NoiseForwarder('resp',
              clssockbpair, GenericConnValidator([], None),
              priv_key=self.server_key_pair[1]))

          # Create client
          proto = NoiseConnection.from_name(
              b'Noise_XK_448_ChaChaPoly_SHA256')

          proto.set_as_initiator()

          # Setup wrong client key
          proto.set_keypair_from_private_bytes(Keypair.STATIC,
              wrongclient_key_pair[1])

          # but the correct server key
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
              self.server_key_pair[0])

          proto.start_handshake()

          # Send first message
          message = proto.write_message()
          self.assertEqual(len(message), _handshakelens[0])
          writer.write(message)

          # Get response
          respmsg = await reader.readexactly(_handshakelens[1])
          proto.read_message(respmsg)

          # Send final reply
          message = proto.write_message()
          writer.write(message)

          # Make sure handshake has completed
          self.assertTrue(proto.handshake_finished)

          with self.assertRaises(ValueError):
              await servnf
      finally:
          # close the client end even when an assertion above fails
          writer.close()
  @async_test
  async def test_server(self):
      '''End-to-end check of a responder NoiseForwarder over Unix sockets.

      A hand-driven Noise_XK_448_ChaChaPoly_SHA256 initiator connects to
      the listener on ``servsock``.  The test then:

      * completes the handshake and exchanges protocol version 0;
      * sends data that must appear in plaintext on ``servptsock``;
      * sends plaintext back that must arrive encrypted;
      * checks that write_eof() on either side reaches the other; and
      * checks that the forwarder finally returns ``['dec', 'enc']``.

      Every encrypted message is framed as ``enclenfun(encmsg)`` followed
      by ``encmsg``.  The receiver reads ``2 + 16`` bytes, asks
      ``declenfun`` for the total ciphertext length ``tlen``, reads the
      remaining ``tlen - 16`` bytes, and decrypts ``encmsg[2:] + rest``.
      '''
      # Test is plumbed:
      # (reader, writer) -> servsock ->
      # (rdr, wrr) NoiseForward (reader, writer) ->
      # servptsock -> (ptsock[0], ptsock[1])
      # Path that the server will sit on
      servsockpath = os.path.join(self.tempdir, 'servsock')
      servarg = _makeunix(servsockpath)
      # Path that the server will send pt data to
      servptpath = os.path.join(self.tempdir, 'servptsock')
      # Setup pt target listener
      pttarg = _makeunix(servptpath)
      ptsock = []
      ptsockevent = asyncio.Event()
      # accept callback for the plaintext listener: records the stream
      # pair and wakes the test body waiting on ptsockevent
      def ptsockaccept(reader, writer, ptsock=ptsock):
          ptsock.append((reader, writer))
          ptsockevent.set()
      # Bind to pt listener
      lsock = await listensockstr(pttarg, ptsockaccept)
      # runnf appends the NoiseForwarder result here and sets event
      nfs = []
      event = asyncio.Event()
      # accept callback for the encrypted listener: runs a responder
      # forwarder that accepts only the test's client public key and,
      # when the validator calls the lambda, connects to the pt listener
      async def runnf(rdr, wrr):
          connval = GenericConnValidator(
              [ self.client_key_pair[0] ],
              lambda: connectsockstr(pttarg))
          a = await NoiseForwarder('resp',
              _makefut((rdr, wrr)), connval,
              priv_key=self.server_key_pair[1])
          nfs.append(a)
          event.set()
      # Setup server listener
      ssock = await listensockstr(servarg, runnf)
      # Connect to server
      reader, writer = await connectsockstr(servarg)
      # Create client
      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
      proto.set_as_initiator()
      # Setup required keys
      proto.set_keypair_from_private_bytes(Keypair.STATIC,
          self.client_key_pair[1])
      proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
          self.server_key_pair[0])
      proto.start_handshake()
      # Send first message
      message = proto.write_message()
      self.assertEqual(len(message), _handshakelens[0])
      writer.write(message)
      # Get response
      respmsg = await reader.readexactly(_handshakelens[1])
      proto.read_message(respmsg)
      # Send final reply
      message = proto.write_message()
      writer.write(message)
      # Make sure handshake has completed
      self.assertTrue(proto.handshake_finished)
      # generate the keys for lengths
      # both length functions are derived from the handshake hash: the
      # b'toresp' one frames initiator->responder traffic and the
      # b'toinit' one responder->initiator traffic
      enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
          b'toresp')
      _, declenfun = _genciphfun(proto.get_handshake_hash(),
          b'toinit')
      pversion = 0
      # Send the protocol version string first
      encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
      writer.write(enclenfun(encmsg))
      writer.write(encmsg)
      # Read the peer's protocol version
      # find out how much we need to read
      # (2 header bytes plus the first 16 ciphertext bytes, which
      # declenfun needs to recover the ciphertext length)
      encmsg = await reader.readexactly(2 + 16)
      tlen = declenfun(encmsg)
      # read the rest of the message
      rencmsg = await reader.readexactly(tlen - 16)
      tmsg = encmsg[2:] + rencmsg
      rptmsg = proto.decrypt(tmsg)
      self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), pversion)
      # write a test message
      ptmsg = b'this is a test message that should be a little in length'
      encmsg = proto.encrypt(ptmsg)
      writer.write(enclenfun(encmsg))
      writer.write(encmsg)
      # wait for the connection to arrive
      # (the forwarder opens the pt connection only after the version
      # exchange, so ptsock is populated asynchronously)
      await ptsockevent.wait()
      ptreader, ptwriter = ptsock[0]
      # read the test message
      rptmsg = await ptreader.readexactly(len(ptmsg))
      self.assertEqual(rptmsg, ptmsg)
      # write a different message
      ptmsg = os.urandom(2843)
      encmsg = proto.encrypt(ptmsg)
      writer.write(enclenfun(encmsg))
      writer.write(encmsg)
      # read the test message
      rptmsg = await ptreader.readexactly(len(ptmsg))
      self.assertEqual(rptmsg, ptmsg)
      # now try the other way
      ptmsg = os.urandom(912)
      ptwriter.write(ptmsg)
      # find out how much we need to read
      encmsg = await reader.readexactly(2 + 16)
      tlen = declenfun(encmsg)
      # read the rest of the message
      rencmsg = await reader.readexactly(tlen - 16)
      tmsg = encmsg[2:] + rencmsg
      rptmsg = proto.decrypt(tmsg)
      self.assertEqual(rptmsg, ptmsg)
      # shut down sending
      writer.write_eof()
      # so pt reader should be shut down
      self.assertEqual(b'', await ptreader.read(1))
      self.assertTrue(ptreader.at_eof())
      # shut down pt
      ptwriter.write_eof()
      # make sure the enc reader is eof
      self.assertEqual(b'', await reader.read(1))
      self.assertTrue(reader.at_eof())
      # with both directions shut down the forwarder should return
      await event.wait()
      self.assertEqual(nfs[0], [ 'dec', 'enc' ])
      writer.close()
      ptwriter.close()
      lsock.close()
      ssock.close()
      await lsock.wait_closed()
      await ssock.wait_closed()
  @async_test
  async def test_protocolversionmismatch(self):
      '''Check the responder's handling of an unsupported protocol version.

      A hand-driven Noise_XK_448_ChaChaPoly_SHA256 initiator completes the
      handshake and then announces protocol version 1.  The responder
      must still reply with its own version (0).  The server-side
      NoiseForwarder must then raise RuntimeError, which ``runnf`` stores
      in ``nfs``.  Sockets and listeners are closed at the end.
      '''
      # make sure that if we send a future version, that we
      # still get a protocol version, and that the connection
      # is closed w/o establishing a connection to the remote
      # side
      # Test is plumbed:
      # (reader, writer) -> servsock ->
      # (rdr, wrr) NoiseForward (reader, writer) ->
      # servptsock -> (ptsock[0], ptsock[1])
      # Path that the server will sit on
      servsockpath = os.path.join(self.tempdir, 'servsock')
      servarg = _makeunix(servsockpath)
      # Path that the server will send pt data to
      servptpath = os.path.join(self.tempdir, 'servptsock')
      # Setup pt target listener
      pttarg = _makeunix(servptpath)
      ptsock = []
      ptsockevent = asyncio.Event()
      def ptsockaccept(reader, writer, ptsock=ptsock):
          ptsock.append((reader, writer))
          ptsockevent.set()
      # Bind to pt listener
      lsock = await listensockstr(pttarg, ptsockaccept)
      nfs = []
      event = asyncio.Event()
      async def runnf(rdr, wrr):
          # NOTE(review): the pt connection is started eagerly here, so
          # this test does not (and cannot) assert that no connection to
          # the remote side was made.
          ptpairfun = asyncio.create_task(connectsockstr(pttarg))
          try:
              connval = GenericConnValidator(
                  [ self.client_key_pair[0] ],
                  lambda: ptpairfun)
              a = await NoiseForwarder('resp',
                  _makefut((rdr, wrr)), connval,
                  priv_key=self.server_key_pair[1])
          except RuntimeError as e:
              nfs.append(e)
              event.set()
              return
          nfs.append(a)
          event.set()
      # Setup server listener
      ssock = await listensockstr(servarg, runnf)
      # Connect to server
      reader, writer = await connectsockstr(servarg)
      # Create client
      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
      proto.set_as_initiator()
      # Setup required keys
      proto.set_keypair_from_private_bytes(Keypair.STATIC,
          self.client_key_pair[1])
      proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
          self.server_key_pair[0])
      proto.start_handshake()
      # Send first message
      message = proto.write_message()
      self.assertEqual(len(message), _handshakelens[0])
      writer.write(message)
      # Get response
      respmsg = await reader.readexactly(_handshakelens[1])
      proto.read_message(respmsg)
      # Send final reply
      message = proto.write_message()
      writer.write(message)
      # Make sure handshake has completed
      self.assertTrue(proto.handshake_finished)
      # generate the keys for lengths
      enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
          b'toresp')
      _, declenfun = _genciphfun(proto.get_handshake_hash(),
          b'toinit')
      # a version newer than the responder's
      pversion = 1
      # Send the protocol version string first
      encmsg = proto.encrypt(pversion.to_bytes(1, byteorder='big'))
      writer.write(enclenfun(encmsg))
      writer.write(encmsg)
      # Read the peer's protocol version
      # find out how much we need to read
      encmsg = await reader.readexactly(2 + 16)
      tlen = declenfun(encmsg)
      # read the rest of the message
      rencmsg = await reader.readexactly(tlen - 16)
      tmsg = encmsg[2:] + rencmsg
      rptmsg = proto.decrypt(tmsg)
      # the responder still announces its own version (0)
      self.assertEqual(int.from_bytes(rptmsg, byteorder='big'), 0)
      await event.wait()
      self.assertIsInstance(nfs[0], RuntimeError)
      # Release the client socket, any accepted pt-side connections and
      # both listeners (test_server does the same).  wait_closed() is not
      # awaited: the connection opened by ptpairfun inside runnf is not
      # reachable here, and on newer Python versions Server.wait_closed()
      # may wait for attached connections to go away.
      writer.close()
      for _, ptwriter in ptsock:
          ptwriter.close()
      lsock.close()
      ssock.close()
  997. @async_test
  998. async def test_serverclient(self):
  999. # plumbing:
  1000. #
  1001. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  1002. #
  1003. ptcsockapair, ptcsockbpair = _asyncsockpair()
  1004. ptcareader, ptcawriter = await ptcsockapair
  1005. #ptcsockbpair passed directly
  1006. clssockapair, clssockbpair = _asyncsockpair()
  1007. #both passed directly
  1008. ptssockapair, ptssockbpair = _asyncsockpair()
  1009. #ptssockapair passed directly
  1010. ptsbreader, ptsbwriter = await ptssockbpair
  1011. validateclientside = GenericConnValidator(
  1012. [ self.server_key_pair[0] ], lambda: ptcsockbpair)
  1013. validateserverside = GenericConnValidator(
  1014. [ self.client_key_pair[0] ], lambda: ptssockapair)
  1015. clientnf = asyncio.create_task(NoiseForwarder('init',
  1016. clssockapair, validateclientside,
  1017. priv_key=self.client_key_pair[1],
  1018. pub_key=self.server_key_pair[0]))
  1019. servnf = asyncio.create_task(NoiseForwarder('resp',
  1020. clssockbpair, validateserverside,
  1021. priv_key=self.server_key_pair[1]))
  1022. # send a message
  1023. msga = os.urandom(183)
  1024. ptcawriter.write(msga)
  1025. # make sure we get the same message
  1026. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1027. # send a second message
  1028. msga = os.urandom(2834)
  1029. ptcawriter.write(msga)
  1030. # make sure we get the same message
  1031. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1032. # send a message larger than the block size
  1033. msga = os.urandom(103958)
  1034. ptcawriter.write(msga)
  1035. # make sure we get the same message
  1036. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  1037. # send a message the other direction
  1038. msga = os.urandom(103958)
  1039. ptsbwriter.write(msga)
  1040. # make sure we get the same message
  1041. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  1042. # close down the pt writers, the rest should follow
  1043. ptsbwriter.write_eof()
  1044. ptcawriter.write_eof()
  1045. # make sure they are closed, and there is no more data
  1046. self.assertEqual(b'', await ptsbreader.read(1))
  1047. self.assertTrue(ptsbreader.at_eof())
  1048. self.assertEqual(b'', await ptcareader.read(1))
  1049. self.assertTrue(ptcareader.at_eof())
  1050. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  1051. self.assertEqual([ 'dec', 'enc' ], await servnf)
  1052. await ptsbwriter.drain()
  1053. await ptcawriter.drain()
  1054. ptsbwriter.close()
  1055. ptcawriter.close()