An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

444 lines
11 KiB

  1. from noise.connection import NoiseConnection, Keypair
  2. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  3. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  4. from cryptography.hazmat.primitives import hashes
  5. from twistednoise import genkeypair
  6. from cryptography.hazmat.backends import default_backend
  7. import asyncio
  8. import os.path
  9. import shutil
  10. import socket
  11. import tempfile
  12. import threading
  13. import unittest
# cryptography backend object shared by the HKDF and AES Cipher objects
# built in _genciphfun (passed there as backend=_backend).
_backend = default_backend()
  15. def _makefut(obj):
  16. loop = asyncio.get_running_loop()
  17. fut = loop.create_future()
  18. fut.set_result(obj)
  19. return fut
  20. def _makeunix(path):
  21. '''Make a properly formed unix path socket string.'''
  22. return 'unix:%s' % path
  23. def _parsesockstr(sockstr):
  24. proto, rem = sockstr.split(':', 1)
  25. return proto, rem
  26. async def connectsockstr(sockstr):
  27. proto, rem = _parsesockstr(sockstr)
  28. reader, writer = await asyncio.open_unix_connection(rem)
  29. return reader, writer
  30. async def listensockstr(sockstr, cb):
  31. '''Wrapper for asyncio.start_x_server.
  32. The format of sockstr is: 'proto:param=value[,param2=value2]'.
  33. If the proto has a default parameter, the value can be used
  34. directly, like: 'proto:value'. This is only allowed when the
  35. value can unambiguously be determined not to be a param.
  36. The characters that define 'param' must be all lower case ascii
  37. characters and may contain an underscore. The first character
  38. must not be and underscore.
  39. Supported protocols:
  40. unix:
  41. Default parameter is path.
  42. The path parameter specifies the path to the
  43. unix domain socket. The path MUST start w/ a
  44. slash if it is used as a default parameter.
  45. '''
  46. proto, rem = _parsesockstr(sockstr)
  47. server = await asyncio.start_unix_server(cb, path=rem)
  48. return server
# !!python makemessagelengths.py
# Byte lengths of the three Noise XK handshake messages, in order:
# initiator -> responder, responder -> initiator, initiator -> responder.
# NoiseForwarder and the tests read each handshake message with
# readexactly() using these sizes.  The marker line above presumably means
# the values are produced by makemessagelengths.py — regenerate rather
# than hand-edit if the protocol name changes (TODO confirm).
_handshakelens = \
    [72, 72, 88]
  52. def _genciphfun(hash, ad):
  53. hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
  54. salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
  55. key = hkdf.derive(hash)
  56. cipher = Cipher(algorithms.AES(key), modes.ECB(),
  57. backend=_backend)
  58. enctor = cipher.encryptor()
  59. def encfun(data):
  60. # Returns the two bytes for length
  61. val = len(data)
  62. encbytes = enctor.update(data[:16])
  63. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  64. return (val ^ mask).to_bytes(length=2, byteorder='big')
  65. def decfun(data):
  66. # takes off the data and returns the total
  67. # length
  68. val = int.from_bytes(data[:2], byteorder='big')
  69. encbytes = enctor.update(data[2:2 + 16])
  70. mask = int.from_bytes(encbytes[:2], byteorder='big') & 0xff
  71. return val ^ mask
  72. return encfun, decfun
  73. async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
  74. rdr, wrr = await rdrwrr
  75. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  76. proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
  77. if pub_key is not None:
  78. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  79. pub_key)
  80. if mode == 'resp':
  81. proto.set_as_responder()
  82. proto.start_handshake()
  83. proto.read_message(await rdr.readexactly(_handshakelens[0]))
  84. wrr.write(proto.write_message())
  85. proto.read_message(await rdr.readexactly(_handshakelens[2]))
  86. elif mode == 'init':
  87. proto.set_as_initiator()
  88. proto.start_handshake()
  89. wrr.write(proto.write_message())
  90. proto.read_message(await rdr.readexactly(_handshakelens[1]))
  91. wrr.write(proto.write_message())
  92. if not proto.handshake_finished: # pragma: no cover
  93. raise RuntimeError('failed to finish handshake')
  94. # generate the keys for lengths
  95. if mode == 'resp':
  96. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
  97. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')
  98. elif mode == 'init':
  99. enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
  100. _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')
  101. reader, writer = await ptpair
  102. async def decses():
  103. try:
  104. while True:
  105. try:
  106. msg = await rdr.readexactly(2 + 16)
  107. except asyncio.streams.IncompleteReadError:
  108. if rdr.at_eof():
  109. return 'dec'
  110. tlen = declenfun(msg)
  111. rmsg = await rdr.readexactly(tlen - 16)
  112. tmsg = msg[2:] + rmsg
  113. writer.write(proto.decrypt(tmsg))
  114. await writer.drain()
  115. #except:
  116. # import traceback
  117. # traceback.print_exc()
  118. # raise
  119. finally:
  120. writer.write_eof()
  121. async def encses():
  122. try:
  123. while True:
  124. # largest message
  125. ptmsg = await reader.read(65535 - 16)
  126. if not ptmsg:
  127. # eof
  128. return 'enc'
  129. encmsg = proto.encrypt(ptmsg)
  130. wrr.write(enclenfun(encmsg))
  131. wrr.write(encmsg)
  132. await wrr.drain()
  133. #except:
  134. # import traceback
  135. # traceback.print_exc()
  136. # raise
  137. finally:
  138. wrr.write_eof()
  139. return await asyncio.gather(decses(), encses())
  140. # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  141. # Slightly modified to timeout
  142. def async_test(f):
  143. def wrapper(*args, **kwargs):
  144. coro = asyncio.coroutine(f)
  145. future = coro(*args, **kwargs)
  146. loop = asyncio.get_event_loop()
  147. # timeout after 2 seconds
  148. loop.run_until_complete(asyncio.wait_for(future, 2))
  149. return wrapper
  150. class Tests_misc(unittest.TestCase):
  151. def test_listensockstr(self):
  152. # XXX write test
  153. pass
  154. def test_genciphfun(self):
  155. enc, dec = _genciphfun(b'0' * 32, b'foobar')
  156. msg = b'this is a bunch of data'
  157. tb = enc(msg)
  158. self.assertEqual(len(msg), dec(tb + msg))
  159. for i in [ 20, 1384, 64000, 23839, 65535 ]:
  160. msg = os.urandom(i)
  161. self.assertEqual(len(msg), dec(enc(msg) + msg))
  162. def _asyncsockpair():
  163. '''Create a pair of sockets that are bound to each other.
  164. The function will return a tuple of two coroutine's, that
  165. each, when await'ed upon, will return the reader/writer pair.'''
  166. socka, sockb = socket.socketpair()
  167. return asyncio.open_connection(sock=socka), \
  168. asyncio.open_connection(sock=sockb)
  169. class Tests(unittest.TestCase):
  170. def setUp(self):
  171. # setup temporary directory
  172. d = os.path.realpath(tempfile.mkdtemp())
  173. self.basetempdir = d
  174. self.tempdir = os.path.join(d, 'subdir')
  175. os.mkdir(self.tempdir)
  176. # Generate key pairs
  177. self.server_key_pair = genkeypair()
  178. self.client_key_pair = genkeypair()
  179. def tearDown(self):
  180. shutil.rmtree(self.basetempdir)
  181. self.tempdir = None
  182. @async_test
  183. async def test_server(self):
  184. # Test is plumbed:
  185. # (reader, writer) -> servsock ->
  186. # (rdr, wrr) NoiseForward (reader, writer) ->
  187. # servptsock -> (ptsock[0], ptsock[1])
  188. # Path that the server will sit on
  189. servsockpath = os.path.join(self.tempdir, 'servsock')
  190. servarg = _makeunix(servsockpath)
  191. # Path that the server will send pt data to
  192. servptpath = os.path.join(self.tempdir, 'servptsock')
  193. # Setup pt target listener
  194. pttarg = _makeunix(servptpath)
  195. ptsock = []
  196. def ptsockaccept(reader, writer, ptsock=ptsock):
  197. ptsock.append((reader, writer))
  198. # Bind to pt listener
  199. lsock = await listensockstr(pttarg, ptsockaccept)
  200. nfs = []
  201. event = asyncio.Event()
  202. async def runnf(rdr, wrr):
  203. ptpair = asyncio.create_task(connectsockstr(pttarg))
  204. a = await NoiseForwarder('resp',
  205. _makefut((rdr, wrr)), ptpair,
  206. priv_key=self.server_key_pair[1])
  207. nfs.append(a)
  208. event.set()
  209. # Setup server listener
  210. ssock = await listensockstr(servarg, runnf)
  211. # Connect to server
  212. reader, writer = await connectsockstr(servarg)
  213. # Create client
  214. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  215. proto.set_as_initiator()
  216. # Setup required keys
  217. proto.set_keypair_from_private_bytes(Keypair.STATIC,
  218. self.client_key_pair[1])
  219. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
  220. self.server_key_pair[0])
  221. proto.start_handshake()
  222. # Send first message
  223. message = proto.write_message()
  224. self.assertEqual(len(message), _handshakelens[0])
  225. writer.write(message)
  226. # Get response
  227. respmsg = await reader.readexactly(_handshakelens[1])
  228. proto.read_message(respmsg)
  229. # Send final reply
  230. message = proto.write_message()
  231. writer.write(message)
  232. # Make sure handshake has completed
  233. self.assertTrue(proto.handshake_finished)
  234. # generate the keys for lengths
  235. enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
  236. b'toresp')
  237. _, declenfun = _genciphfun(proto.get_handshake_hash(),
  238. b'toinit')
  239. # write a test message
  240. ptmsg = b'this is a test message that should be a little in length'
  241. encmsg = proto.encrypt(ptmsg)
  242. writer.write(enclenfun(encmsg))
  243. writer.write(encmsg)
  244. # XXX - how to sync?
  245. await asyncio.sleep(.1)
  246. ptreader, ptwriter = ptsock[0]
  247. # read the test message
  248. rptmsg = await ptreader.readexactly(len(ptmsg))
  249. self.assertEqual(rptmsg, ptmsg)
  250. # write a different message
  251. ptmsg = os.urandom(2843)
  252. encmsg = proto.encrypt(ptmsg)
  253. writer.write(enclenfun(encmsg))
  254. writer.write(encmsg)
  255. # XXX - how to sync?
  256. await asyncio.sleep(.1)
  257. # read the test message
  258. rptmsg = await ptreader.readexactly(len(ptmsg))
  259. self.assertEqual(rptmsg, ptmsg)
  260. # now try the other way
  261. ptmsg = os.urandom(912)
  262. ptwriter.write(ptmsg)
  263. # find out how much we need to read
  264. encmsg = await reader.readexactly(2 + 16)
  265. tlen = declenfun(encmsg)
  266. # read the rest of the message
  267. rencmsg = await reader.readexactly(tlen - 16)
  268. tmsg = encmsg[2:] + rencmsg
  269. rptmsg = proto.decrypt(tmsg)
  270. self.assertEqual(rptmsg, ptmsg)
  271. # shut down sending
  272. writer.write_eof()
  273. # so pt reader should be shut down
  274. self.assertEqual(b'', await ptreader.read(1))
  275. self.assertTrue(ptreader.at_eof())
  276. # shut down pt
  277. ptwriter.write_eof()
  278. # make sure the enc reader is eof
  279. self.assertEqual(b'', await reader.read(1))
  280. self.assertTrue(reader.at_eof())
  281. await event.wait()
  282. self.assertEqual(nfs[0], [ 'dec', 'enc' ])
  283. @async_test
  284. async def test_serverclient(self):
  285. # plumbing:
  286. #
  287. # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
  288. #
  289. ptcsockapair, ptcsockbpair = _asyncsockpair()
  290. ptcareader, ptcawriter = await ptcsockapair
  291. #ptcsockbpair passed directly
  292. clssockapair, clssockbpair = _asyncsockpair()
  293. #both passed directly
  294. ptssockapair, ptssockbpair = _asyncsockpair()
  295. #ptssockapair passed directly
  296. ptsbreader, ptsbwriter = await ptssockbpair
  297. clientnf = asyncio.create_task(NoiseForwarder('init',
  298. clssockapair, ptcsockbpair,
  299. priv_key=self.client_key_pair[1],
  300. pub_key=self.server_key_pair[0]))
  301. servnf = asyncio.create_task(NoiseForwarder('resp',
  302. clssockbpair, ptssockapair,
  303. priv_key=self.server_key_pair[1]))
  304. # send a message
  305. msga = os.urandom(183)
  306. ptcawriter.write(msga)
  307. # make sure we get the same message
  308. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  309. # send a second message
  310. msga = os.urandom(2834)
  311. ptcawriter.write(msga)
  312. # make sure we get the same message
  313. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  314. # send a message larger than the block size
  315. msga = os.urandom(103958)
  316. ptcawriter.write(msga)
  317. # make sure we get the same message
  318. self.assertEqual(msga, await ptsbreader.readexactly(len(msga)))
  319. # send a message the other direction
  320. msga = os.urandom(103958)
  321. ptsbwriter.write(msga)
  322. # make sure we get the same message
  323. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  324. # close down the pt writers, the rest should follow
  325. ptsbwriter.write_eof()
  326. ptcawriter.write_eof()
  327. # make sure they are closed, and there is no more data
  328. self.assertEqual(b'', await ptsbreader.read(1))
  329. self.assertTrue(ptsbreader.at_eof())
  330. self.assertEqual(b'', await ptcareader.read(1))
  331. self.assertTrue(ptcareader.at_eof())
  332. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  333. self.assertEqual([ 'dec', 'enc' ], await servnf)