An stunnel-like program that utilizes the Noise protocol.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

656 lines
18 KiB

  1. from cryptography.hazmat.backends import default_backend
  2. from cryptography.hazmat.primitives import hashes
  3. from cryptography.hazmat.primitives import serialization
  4. from cryptography.hazmat.primitives.asymmetric import x448
  5. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  6. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  7. from cryptography.hazmat.primitives.serialization import load_pem_private_key
  8. from noise.connection import NoiseConnection, Keypair
  9. #import tracemalloc; tracemalloc.start()
  10. import argparse
  11. import asyncio
  12. import base64
  13. import os.path
  14. import shutil
  15. import socket
  16. import sys
  17. import tempfile
  18. import threading
  19. import unittest
  20. _backend = default_backend()
  def genkeypair():
      '''Create a fresh X448 keypair in raw byte form.

      Returns a tuple of (public, private), each a bytes object in
      the Raw encoding, suitable for use w/ Noise.'''

      private_key = x448.X448PrivateKey.generate()

      raw_public = private_key.public_key().public_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PublicFormat.Raw,
      )
      raw_private = private_key.private_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PrivateFormat.Raw,
          encryption_algorithm=serialization.NoEncryption(),
      )

      return raw_public, raw_private
  def _makefut(obj):
      '''Wrap obj in an already completed future.

      The future belongs to the currently running event loop, so this
      must be called while a loop is running.'''

      done = asyncio.get_running_loop().create_future()
      done.set_result(obj)

      return done
  def _makeunix(path):
      '''Return the socket string ('unix:' followed by path) for a
      unix domain socket at path.'''

      return f'unix:{path}'
  def _parsesockstr(sockstr):
      '''Split a socket string into its protocol and the remainder.

      The split happens at the first colon only, so the remainder may
      itself contain colons.

      Returns a tuple of (proto, rem).

      Raises ValueError, naming the offending string, if sockstr does
      not contain a colon.'''

      proto, sep, rem = sockstr.partition(':')
      if not sep:
          raise ValueError('invalid socket string, missing "proto:": %r' %
              (sockstr,))

      return proto, rem
  async def connectsockstr(sockstr):
      '''Open a client connection to the endpoint named by sockstr.

      Only the 'unix' protocol is supported.  Everything after the
      first colon is used as the path of the unix domain socket.

      Returns the (reader, writer) pair from
      asyncio.open_unix_connection.

      Raises ValueError if sockstr is malformed or names a protocol
      other than 'unix'.'''

      proto, rem = _parsesockstr(sockstr)

      # Reject other protocols instead of silently treating whatever
      # follows the colon as a unix socket path.
      if proto != 'unix':
          raise ValueError('unsupported protocol: %r' % (proto,))

      return await asyncio.open_unix_connection(rem)
  async def listensockstr(sockstr, cb):
      '''Wrapper for asyncio.start_x_server.

      The format of sockstr is: 'proto:param=value[,param2=value2]'.
      If the proto has a default parameter, the value can be used
      directly, like: 'proto:value'.  This is only allowed when the
      value can unambiguously be determined not to be a param.

      The characters that define 'param' must be all lower case ascii
      characters and may contain an underscore.  The first character
      must not be an underscore.

      Supported protocols:
          unix:
              Default parameter is path.
              The path parameter specifies the path to the
              unix domain socket.  The path MUST start w/ a
              slash if it is used as a default parameter.

      NOTE: the code below only handles the default parameter form.
      Everything after the first colon is passed through unchanged as
      the path; 'param=value' pairs are not parsed and the leading
      slash is not checked.

      cb is handed to asyncio.start_unix_server as the client
      connected callback.

      Returns the server object from asyncio.start_unix_server.

      Raises ValueError if sockstr is malformed or names a protocol
      other than 'unix'.
      '''

      proto, rem = _parsesockstr(sockstr)

      # Reject other protocols instead of silently treating whatever
      # follows the colon as a unix socket path.
      if proto != 'unix':
          raise ValueError('unsupported protocol: %r' % (proto,))

      return await asyncio.start_unix_server(cb, path=rem)
  # !!python makemessagelengths.py
  # Byte lengths of the three handshake messages, indexed in the order
  # they are exchanged (see the readexactly calls in NoiseForwarder).
  _handshakelens = [72, 72, 88]
  def _genciphfun(hash, ad):
      '''Build the pair of functions that mask and unmask frame lengths.

      A 32 byte AES key is derived from hash via HKDF-SHA256, w/ ad as
      the info parameter.  The first 16 bytes of a message are run
      through AES-ECB under that key, and part of the result is XOR'd
      into the two byte big endian length.

      Returns a tuple of (encfun, decfun):
          encfun(data) returns the two masked length bytes for data.
          decfun(data) takes the two masked length bytes followed by
              (at least) the first 16 bytes of the message, and
              returns the total length.

      Both raise ValueError if fewer than 16 bytes of message are
      provided.  Both share a single encryptor, so they must only ever
      be fed whole blocks.'''

      hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
          salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
      key = hkdf.derive(hash)

      cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=_backend)
      enctor = cipher.encryptor()

      blocklen = 16

      def getmask(block):
          # A short block would produce no output from update (giving
          # a mask of zero) and would stay buffered in the encryptor,
          # throwing off every later mask.  Refuse it up front.
          if len(block) != blocklen:
              raise ValueError('need %d bytes of message to mask length, got %d' %
                  (blocklen, len(block)))

          encbytes = enctor.update(block)

          # NOTE(review): the & 0xff keeps only the low 8 bits, so the
          # high byte of the length is never masked.  Widening it
          # would change the bytes on the wire, so it is left alone.
          return int.from_bytes(encbytes[:2], byteorder='big') & 0xff

      def encfun(data):
          # Returns the two bytes for length
          val = len(data)
          mask = getmask(data[:blocklen])

          return (val ^ mask).to_bytes(length=2, byteorder='big')

      def decfun(data):
          # takes off the data and returns the total
          # length
          val = int.from_bytes(data[:2], byteorder='big')
          mask = getmask(data[2:2 + blocklen])

          return val ^ mask

      return encfun, decfun
  async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
      '''Do a Noise handshake, then forward data in both directions.

      mode is 'resp' to act as the responder, or 'init' to act as the
      initiator.

      rdrwrr is an awaitable that results in the (reader, writer) pair
      for the encrypted side.

      ptpair is an awaitable that results in the (reader, writer) pair
      for the plain text side.  It is not awaited until after the
      handshake has completed.

      priv_key is the local static private key, as raw bytes.

      pub_key, when not None, is the remote static public key, as raw
      bytes.

      After the handshake, each message on the encrypted side is two
      masked length bytes (see _genciphfun) followed by the cipher
      text.

      Returns the list [ 'dec', 'enc' ] once both directions have
      reached EOF.

      Raises RuntimeError if the handshake did not finish.'''

      rdr, wrr = await rdrwrr

      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')

      proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
      if pub_key is not None:
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
              pub_key)

      if mode == 'resp':
          proto.set_as_responder()
          proto.start_handshake()

          proto.read_message(await rdr.readexactly(_handshakelens[0]))

          wrr.write(proto.write_message())

          proto.read_message(await rdr.readexactly(_handshakelens[2]))
      elif mode == 'init':
          proto.set_as_initiator()
          proto.start_handshake()

          wrr.write(proto.write_message())

          proto.read_message(await rdr.readexactly(_handshakelens[1]))

          wrr.write(proto.write_message())

      if not proto.handshake_finished: # pragma: no cover
          raise RuntimeError('failed to finish handshake')

      # generate the keys for lengths
      # XXX - get_handshake_hash is probably not the best option, but
      # this is only to obscure lengths, it is not required to be secure
      # as the underlying NoiseProtocol securely validates everything.
      # It is marginally useful as writing patterns likely expose the
      # true length.  Adding padding could marginally help w/ this.
      if mode == 'resp':
          _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
      elif mode == 'init':
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
          _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

      reader, writer = await ptpair

      async def decses():
          # encrypted side -> plain text side
          try:
              while True:
                  try:
                      # the length bytes plus the first block, which
                      # is needed to unmask the length
                      msg = await rdr.readexactly(2 + 16)
                  except asyncio.IncompleteReadError:
                      if rdr.at_eof():
                          return 'dec'

                      # Don't fall through and decode an unset (or
                      # the previous) msg.
                      raise

                  tlen = declenfun(msg)
                  rmsg = await rdr.readexactly(tlen - 16)
                  tmsg = msg[2:] + rmsg
                  writer.write(proto.decrypt(tmsg))
                  await writer.drain()
          finally:
              # pass the EOF (or failure) along to the plain text side
              writer.write_eof()

      async def encses():
          # plain text side -> encrypted side
          try:
              while True:
                  # largest message
                  ptmsg = await reader.read(65535 - 16)
                  if not ptmsg:
                      # eof
                      return 'enc'

                  encmsg = proto.encrypt(ptmsg)
                  wrr.write(enclenfun(encmsg))
                  wrr.write(encmsg)
                  await wrr.drain()
          finally:
              # pass the EOF (or failure) along to the encrypted side
              wrr.write_eof()

      return await asyncio.gather(decses(), encses())
  162. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  163. # Slightly modified to timeout
  def async_test(f):
      '''Decorator to run an async def test function to completion.

      Each call of the wrapped function runs the coroutine from f on
      its own event loop via asyncio.run.  If it has not finished in
      2 seconds, asyncio.TimeoutError is raised.

      f must be a coroutine function (async def).  The wrapper returns
      None.'''

      import functools

      @functools.wraps(f)
      def wrapper(*args, **kwargs):
          # f is already a coroutine function, so it is called
          # directly.  asyncio.coroutine was removed in Python 3.11.
          future = f(*args, **kwargs)

          # timeout after 2 seconds
          asyncio.run(asyncio.wait_for(future, 2))

      return wrapper
  class Tests_misc(unittest.TestCase):
      '''Tests for the small helper functions.'''

      def test_listensockstr(self):
          # XXX write test
          pass

      def test_genciphfun(self):
          encfun, decfun = _genciphfun(b'0' * 32, b'foobar')

          # a fixed message first
          data = b'this is a bunch of data'
          lenbytes = encfun(data)
          self.assertEqual(len(data), decfun(lenbytes + data))

          # then random messages of assorted sizes
          for size in (20, 1384, 64000, 23839, 65535):
              data = os.urandom(size)
              self.assertEqual(len(data), decfun(encfun(data) + data))
  def cmd_client(args):
      '''Handler for the 'client' subcommand.

      Currently a stub: args is ignored and nothing is done.'''

      return None
  def cmd_server(args):
      '''Handler for the 'server' subcommand.

      Currently a stub: args is ignored and nothing is done.'''

      return None
  def cmd_genkey(args):
      '''Handler for the 'genkey' subcommand.

      Generates a new X448 key and writes two files:
          args.fname + '.pub': one line, 'ntun-x448' followed by the
              url safe base64 of the raw public key.
          args.fname: the private key, as unencrypted PEM encoded
              PKCS8, readable and writable by the owner only.

      Existing files are never overwritten.  If either file already
      exists, a message is printed to stderr and the process exits w/
      status 2.'''

      key = x448.X448PrivateKey.generate()

      # public key part
      pub = key.public_key().public_bytes(
          encoding=serialization.Encoding.Raw,
          format=serialization.PublicFormat.Raw)

      pubfname = args.fname + '.pub'
      try:
          with open(pubfname, 'x', encoding='ascii') as fp:
              print('ntun-x448', base64.urlsafe_b64encode(pub).decode('ascii'), file=fp)
      except FileExistsError:
          print('failed to create %s, file exists.' % pubfname, file=sys.stderr)
          sys.exit(2)

      # private key part
      pem = key.private_bytes(
          encoding=serialization.Encoding.PEM,
          format=serialization.PrivateFormat.PKCS8,
          encryption_algorithm=serialization.NoEncryption()).decode('ascii')

      try:
          # O_EXCL keeps the refuse to overwrite behavior of mode 'x',
          # and the 0600 mode keeps the key from being readable by
          # other users.
          fd = os.open(args.fname, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
      except FileExistsError:
          # The .pub file was created just above by this run, don't
          # leave it around w/o its private key.
          os.unlink(pubfname)
          print('failed to create %s, file exists.' % args.fname, file=sys.stderr)
          sys.exit(2)

      with os.fdopen(fd, 'w', encoding='ascii') as fp:
          fp.write(pem)
  def main():
      '''Parse the command line and run the selected subcommand.

      If no subcommand was given, the usage is printed and the process
      exits w/ status 5.'''

      parser = argparse.ArgumentParser()
      subparsers = parser.add_subparsers(title='subcommands',
          description='valid subcommands', help='additional help')

      # genkey
      genkey_parser = subparsers.add_parser('genkey', help='generate keys')
      genkey_parser.add_argument('fname', type=str, help='file name for the key')
      genkey_parser.set_defaults(func=cmd_genkey)

      # server
      server_parser = subparsers.add_parser('server', help='run a server')
      server_parser.add_argument('-c', action='append', type=str,
          help='file of authorized client keys, or a .pub file')
      for argname, helptext in (
              ('servkey', 'file name for the server key'),
              ('servlisten', 'Connection that the server listens on'),
              ('servtarget', 'Connection that the server connects to'),
          ):
          server_parser.add_argument(argname, type=str, help=helptext)
      server_parser.set_defaults(func=cmd_server)

      # client
      client_parser = subparsers.add_parser('client', help='run a client')
      for argname, helptext in (
              ('clientkey', 'file name for the client private key'),
              ('servkey', 'file name for the server public key'),
              ('clientlisten', 'Connection that the client listens on'),
              ('clienttarget', 'Connection that the client connects to'),
          ):
          client_parser.add_argument(argname, type=str, help=helptext)
      client_parser.set_defaults(func=cmd_client)

      args = parser.parse_args()

      # func is only present when a subcommand was selected
      handler = getattr(args, 'func', None)
      if handler is None:
          parser.print_usage()
          sys.exit(5)

      handler(args)

  if __name__ == '__main__': # pragma: no cover
      main()
  def _asyncsockpair():
      '''Create two connected sockets, wrapped for asyncio.

      Returns a tuple of two coroutines.  Awaiting each one results in
      the (reader, writer) pair for its end of the socket pair.'''

      ends = socket.socketpair()

      return tuple(asyncio.open_connection(sock=end) for end in ends)
  class TestMain(unittest.TestCase):
      '''Tests that run this file as a command line program.

      Each test runs w/ the current directory set to a fresh temporary
      directory, so relative file names land there.'''

      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

          # Remember where we started so that tearDown can go back,
          # otherwise the process is left in a deleted directory.
          self.origcwd = os.getcwd()
          os.chdir(self.tempdir)

      def tearDown(self):
          # leave the temporary directory before removing it
          os.chdir(self.origcwd)

          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      @async_test
      async def test_noargs(self):
          proc = await self.run_with_args()

          # communicate instead of wait, as the output is piped
          await proc.communicate()

          # XXX - not checking error message

          # And that it exited w/ the correct code
          self.assertEqual(proc.returncode, 5)

      def run_with_args(self, *args, pipes=True):
          '''Start this file as a subprocess w/ args as its arguments.

          When pipes is true, stdout and stderr are captured via pipes.

          Returns the coroutine from asyncio.create_subprocess_exec,
          which must be awaited to get the process.'''

          kwargs = {}
          if pipes:
              kwargs.update(dict(
                  stdout=asyncio.subprocess.PIPE,
                  stderr=asyncio.subprocess.PIPE))

          return asyncio.create_subprocess_exec(sys.executable,
              # XXX - figure out how to add coverage data on these runs
              #'-m', 'coverage', 'run', '-p',
              __file__, *args, **kwargs)

      async def genkey(self, name):
          '''Run the genkey subcommand for name, and check that it
          succeeded.'''

          proc = await self.run_with_args('genkey', name)

          # communicate instead of wait, as the output is piped
          await proc.communicate()

          self.assertEqual(proc.returncode, 0)

      @async_test
      async def test_both(self):
          # Generate necessary keys
          await self.genkey('server_key')
          await self.genkey('client_key')

          ptclientstr = _makeunix(os.path.join(self.tempdir, 'incclient.sock'))
          incservstr = _makeunix(os.path.join(self.tempdir, 'incserv.sock'))
          servtargstr = _makeunix(os.path.join(self.tempdir, 'servtarget.sock'))

          # Setup pt target listener
          pttarg = _makeunix('servtarget.sock')
          ptsock = []
          def ptsockaccept(reader, writer, ptsock=ptsock):
              ptsock.append((reader, writer))

          # Bind to pt listener
          lsock = await listensockstr(pttarg, ptsockaccept)

          # Startup the server
          server = await self.run_with_args('server',
              '-c', 'client_key.pub',
              'server_key', incservstr, servtargstr,
              pipes=False)

          # Startup the client
          # (kept under its own name so the server process handle is
          # not lost)
          client = await self.run_with_args('client',
              'client_key', 'server_key.pub', ptclientstr, incservstr,
              pipes=False)

          # Connect to the client
          reader, writer = await connectsockstr(ptclientstr)

      @async_test
      async def test_genkey(self):
          # that it can generate a key
          proc = await self.run_with_args('genkey', 'somefile')
          await proc.communicate()

          self.assertEqual(proc.returncode, 0)

          with open('somefile.pub', encoding='ascii') as fp:
              lines = fp.readlines()
              self.assertEqual(len(lines), 1)

              keytype, keyvalue = lines[0].split()

              self.assertEqual(keytype, 'ntun-x448')
              key = x448.X448PublicKey.from_public_bytes(base64.urlsafe_b64decode(keyvalue))

          with open('somefile', encoding='ascii') as fp:
              data = fp.read().encode('ascii')
              key = load_pem_private_key(data, password=None, backend=default_backend())
              self.assertIsInstance(key, x448.X448PrivateKey)

          # that a second call fails
          proc = await self.run_with_args('genkey', 'somefile')

          # communicate also waits for the process to exit
          stdoutdata, stderrdata = await proc.communicate()

          self.assertFalse(stdoutdata)
          self.assertEqual(b'failed to create somefile.pub, file exists.\n', stderrdata)

          # And that it exited w/ the correct code
          self.assertEqual(proc.returncode, 2)
  class TestNoiseFowarder(unittest.TestCase):
      '''Tests of NoiseForwarder over unix sockets and socket pairs.'''

      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

      def tearDown(self):
          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      @async_test
      async def test_server(self):
          # Test is plumbed:
          #  (reader, writer) -> servsock ->
          #  (rdr, wrr) NoiseForward (reader, writer) ->
          #  servptsock -> (ptsock[0], ptsock[1])
          # Path that the server will sit on
          servsockpath = os.path.join(self.tempdir, 'servsock')
          servarg = _makeunix(servsockpath)

          # Path that the server will send pt data to
          servptpath = os.path.join(self.tempdir, 'servptsock')

          # Setup pt target listener
          pttarg = _makeunix(servptpath)
          ptsock = []

          # Set once the pt connection has been accepted, so the test
          # can wait for it instead of sleeping and hoping.
          ptaccepted = asyncio.Event()

          def ptsockaccept(reader, writer):
              ptsock.append((reader, writer))
              ptaccepted.set()

          # Bind to pt listener
          lsock = await listensockstr(pttarg, ptsockaccept)

          nfs = []
          event = asyncio.Event()

          async def runnf(rdr, wrr):
              ptpair = asyncio.create_task(connectsockstr(pttarg))

              a = await NoiseForwarder('resp',
                  _makefut((rdr, wrr)), ptpair,
                  priv_key=self.server_key_pair[1])

              nfs.append(a)
              event.set()

          # Setup server listener
          ssock = await listensockstr(servarg, runnf)

          # Connect to server
          reader, writer = await connectsockstr(servarg)

          # Create client
          proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
          proto.set_as_initiator()

          # Setup required keys
          proto.set_keypair_from_private_bytes(Keypair.STATIC,
              self.client_key_pair[1])
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
              self.server_key_pair[0])

          proto.start_handshake()

          # Send first message
          message = proto.write_message()
          self.assertEqual(len(message), _handshakelens[0])
          writer.write(message)

          # Get response
          respmsg = await reader.readexactly(_handshakelens[1])
          proto.read_message(respmsg)

          # Send final reply
          message = proto.write_message()
          writer.write(message)

          # Make sure handshake has completed
          self.assertTrue(proto.handshake_finished)

          # generate the keys for lengths
          enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
              b'toresp')
          _, declenfun = _genciphfun(proto.get_handshake_hash(),
              b'toinit')

          # write a test message
          ptmsg = b'this is a test message that should be a little in length'
          encmsg = proto.encrypt(ptmsg)
          writer.write(enclenfun(encmsg))
          writer.write(encmsg)

          # wait for the pt side of the forwarder to be connected
          await ptaccepted.wait()

          ptreader, ptwriter = ptsock[0]

          # read the test message
          rptmsg = await ptreader.readexactly(len(ptmsg))

          self.assertEqual(rptmsg, ptmsg)

          # write a different message
          ptmsg = os.urandom(2843)
          encmsg = proto.encrypt(ptmsg)
          writer.write(enclenfun(encmsg))
          writer.write(encmsg)

          # read the test message
          # (readexactly waits for the data, no other sync is needed)
          rptmsg = await ptreader.readexactly(len(ptmsg))

          self.assertEqual(rptmsg, ptmsg)

          # now try the other way
          ptmsg = os.urandom(912)
          ptwriter.write(ptmsg)

          # find out how much we need to read
          encmsg = await reader.readexactly(2 + 16)
          tlen = declenfun(encmsg)

          # read the rest of the message
          rencmsg = await reader.readexactly(tlen - 16)
          tmsg = encmsg[2:] + rencmsg
          rptmsg = proto.decrypt(tmsg)

          self.assertEqual(rptmsg, ptmsg)

          # shut down sending
          writer.write_eof()

          # so pt reader should be shut down
          self.assertEqual(b'', await ptreader.read(1))
          self.assertTrue(ptreader.at_eof())

          # shut down pt
          ptwriter.write_eof()

          # make sure the enc reader is eof
          self.assertEqual(b'', await reader.read(1))
          self.assertTrue(reader.at_eof())

          await event.wait()

          self.assertEqual(nfs[0], [ 'dec', 'enc' ])

      @async_test
      async def test_serverclient(self):
          # plumbing:
          #
          # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
          #
          ptcsockapair, ptcsockbpair = _asyncsockpair()
          ptcareader, ptcawriter = await ptcsockapair
          #ptcsockbpair passed directly

          clssockapair, clssockbpair = _asyncsockpair()
          #both passed directly

          ptssockapair, ptssockbpair = _asyncsockpair()
          #ptssockapair passed directly
          ptsbreader, ptsbwriter = await ptssockbpair

          clientnf = asyncio.create_task(NoiseForwarder('init',
              clssockapair, ptcsockbpair,
              priv_key=self.client_key_pair[1],
              pub_key=self.server_key_pair[0]))

          servnf = asyncio.create_task(NoiseForwarder('resp',
              clssockbpair, ptssockapair,
              priv_key=self.server_key_pair[1]))

          # send a message
          msga = os.urandom(183)
          ptcawriter.write(msga)

          # make sure we get the same message
          self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

          # send a second message
          msga = os.urandom(2834)
          ptcawriter.write(msga)

          # make sure we get the same message
          self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

          # send a message larger than the block size
          msga = os.urandom(103958)
          ptcawriter.write(msga)

          # make sure we get the same message
          self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))

          # send a message the other direction
          msga = os.urandom(103958)
          ptsbwriter.write(msga)

          # make sure we get the same message
          self.assertEqual(msga, await ptcareader.readexactly(len(msga)))

          # close down the pt writers, the rest should follow
          ptsbwriter.write_eof()
          ptcawriter.write_eof()

          # make sure they are closed, and there is no more data
          self.assertEqual(b'', await ptsbreader.read(1))
          self.assertTrue(ptsbreader.at_eof())

          self.assertEqual(b'', await ptcareader.read(1))
          self.assertTrue(ptcareader.at_eof())

          self.assertEqual([ 'dec', 'enc' ], await clientnf)
          self.assertEqual([ 'dec', 'enc' ], await servnf)