An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

416 lines
11 KiB

  1. from noise.connection import NoiseConnection, Keypair
  2. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  3. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  4. from cryptography.hazmat.primitives import hashes
  5. from twistednoise import genkeypair
  6. from cryptography.hazmat.backends import default_backend
  7. import asyncio
  8. import os.path
  9. import shutil
  10. import socket
  11. import tempfile
  12. import threading
  13. import unittest
  # One cryptography backend object, shared by every HKDF and Cipher
  # constructed in this module.
  _backend = default_backend()
  15. def _makeunix(path):
  16. '''Make a properly formed unix path socket string.'''
  17. return 'unix:%s' % path
  18. def _parsesockstr(sockstr):
  19. proto, rem = sockstr.split(':', 1)
  20. return proto, rem
  async def connectsockstr(sockstr):
      '''Open a client connection to the endpoint named by sockstr.

      Only the unix protocol is supported; everything after the first
      ':' is used verbatim as the path of the unix domain socket.

      Returns the (reader, writer) pair produced by
      asyncio.open_unix_connection.

      Raises ValueError if sockstr is malformed or names a protocol
      other than unix.
      '''
      proto, rem = _parsesockstr(sockstr)
      if proto != 'unix':
          # Reject unknown protocols instead of silently treating the
          # remainder as a unix socket path.
          raise ValueError('unsupported protocol %r in sockstr %r' %
              (proto, sockstr))
      return await asyncio.open_unix_connection(rem)
  async def listensockstr(sockstr, cb):
      '''Wrapper for asyncio.start_x_server.

      The format of sockstr is: 'proto:param=value[,param2=value2]'.
      If the proto has a default parameter, the value can be used
      directly, like: 'proto:value'.  This is only allowed when the
      value can unambiguously be determined not to be a param.

      The characters that define 'param' must be all lower case ascii
      characters and may contain an underscore.  The first character
      must not be an underscore.

      Supported protocols:
          unix:
              Default parameter is path.
              The path parameter specifies the path to the
              unix domain socket.  The path MUST start w/ a
              slash if it is used as a default parameter.

      NOTE: the param=value form is not parsed yet; everything after
      the first ':' is currently used verbatim as the unix socket path.

      cb is handed to asyncio.start_unix_server as the client-connected
      callback and is called with (reader, writer) for every accepted
      connection.

      Returns the asyncio Server object.

      Raises ValueError if sockstr is malformed or names a protocol
      other than unix.
      '''
      proto, rem = _parsesockstr(sockstr)
      if proto != 'unix':
          # Reject unknown protocols instead of silently treating the
          # remainder as a unix socket path.
          raise ValueError('unsupported protocol %r in sockstr %r' %
              (proto, sockstr))
      return await asyncio.start_unix_server(cb, path=rem)
  # Byte lengths of the three Noise_XK_448_ChaChaPoly_SHA256 handshake
  # messages, indexed by message number (NoiseForwarder and the tests
  # read them with readexactly; no handshake payload is sent):
  #   [0] -> e, es : 56-byte ephemeral key + 16-byte tag on empty payload = 72
  #   [1] <- e, ee : 56-byte ephemeral key + 16-byte tag on empty payload = 72
  #   [2] -> s, se : 56-byte static key + 16-byte tag, + 16-byte payload tag = 88
  # The assignment below appears to be emitted by the script named in the
  # marker comment; regenerate it rather than editing by hand.
  # !!python makemessagelengths.py
  _handshakelens = \
  [72, 72, 88]
  def _genciphfun(hash, ad):
      '''Return an (encfun, decfun) pair that obfuscates frame lengths.

      A 32-byte AES key is derived from hash (the Noise handshake hash)
      using HKDF-SHA256 with ad as the HKDF info, so each direction can
      get its own key.  A frame's length is XORed with a 16-bit mask
      taken from AES-ECB of the frame's first 16 bytes.  The receiver
      reads those bytes along with the masked length and recomputes the
      same mask.

      encfun(data) -> bytes: the 2-byte masked length of data.
      decfun(data) -> int: given the 2-byte masked length followed by
          the first 16 bytes of the frame, return the frame length.

      Both raise ValueError when fewer than 16 bytes of frame data are
      available, since no mask can be derived from them.  encfun raises
      OverflowError for frames longer than 65535 bytes.
      '''
      hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
          salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
      key = hkdf.derive(hash)
      cipher = Cipher(algorithms.AES(key), modes.ECB(),
          backend=_backend)
      enctor = cipher.encryptor()

      def _mask(block):
          # ECB must be fed whole 16-byte blocks.  A short block would be
          # buffered inside enctor (update() returns b''), giving a zero
          # mask and corrupting every later call, so refuse it.
          if len(block) != 16:
              raise ValueError('need 16 bytes of frame data, got %d' %
                  len(block))
          # Use the full 16 bits so neither byte of the length is sent
          # in the clear.
          return int.from_bytes(enctor.update(block)[:2], byteorder='big')

      def encfun(data):
          # Returns the two bytes for length
          return (len(data) ^ _mask(data[:16])).to_bytes(length=2,
              byteorder='big')

      def decfun(data):
          # Takes the masked length plus the first 16 bytes of the frame
          # and returns the total frame length.
          val = int.from_bytes(data[:2], byteorder='big')
          return val ^ _mask(data[2:2 + 16])

      return encfun, decfun
  async def NoiseForwarder(mode, rdrwrr, ptpair, priv_key, pub_key=None):
      '''Relay data between a Noise-encrypted stream and a plaintext one.

      mode is 'init' (initiator) or 'resp' (responder).
      rdrwrr is the (reader, writer) pair of the encrypted side.
      ptpair is an awaitable that resolves to the plaintext
      (reader, writer) pair.  It is only awaited once the handshake
      has completed.
      priv_key is our static private key.  pub_key is the peer's static
      public key; in the XK pattern the initiator needs it.

      After the handshake, every encrypted frame on the wire is the
      2-byte masked length (see _genciphfun) followed by the Noise
      ciphertext.

      Returns ['dec', 'enc'] once both directions have reached EOF.

      Raises ValueError if mode is not 'init' or 'resp'.
      '''
      if mode not in ('init', 'resp'):
          raise ValueError("mode must be 'init' or 'resp', not %r" % (mode,))

      rdr, wrr = rdrwrr

      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
      proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
      if pub_key is not None:
          proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pub_key)

      if mode == 'resp':
          proto.set_as_responder()
      else:
          proto.set_as_initiator()

      proto.start_handshake()

      # XK: -> e, es / <- e, ee / -> s, se
      if mode == 'resp':
          proto.read_message(await rdr.readexactly(_handshakelens[0]))
          wrr.write(proto.write_message())
          proto.read_message(await rdr.readexactly(_handshakelens[2]))
      else:
          wrr.write(proto.write_message())
          proto.read_message(await rdr.readexactly(_handshakelens[1]))
          wrr.write(proto.write_message())

      if not proto.handshake_finished: # pragma: no cover
          raise RuntimeError('failed to finish handshake')

      # Derive one length-obfuscation key per direction; the names say
      # which peer the traffic is flowing to.
      hhash = proto.get_handshake_hash()
      if mode == 'resp':
          inkey, outkey = b'toresp', b'toinit'
      else:
          inkey, outkey = b'toinit', b'toresp'
      _, declenfun = _genciphfun(hhash, inkey)
      enclenfun, _ = _genciphfun(hhash, outkey)

      reader, writer = await ptpair

      async def decses():
          # encrypted side -> plaintext side
          try:
              while True:
                  try:
                      # masked length + first 16 bytes of the frame
                      msg = await rdr.readexactly(2 + 16)
                  except asyncio.IncompleteReadError:
                      # NOTE(review): a frame cut short by EOF is treated
                      # like a clean EOF here (its partial bytes are
                      # dropped).
                      if rdr.at_eof():
                          return 'dec'
                      # Never fall through with a stale or unbound msg.
                      raise

                  tlen = declenfun(msg)
                  rmsg = await rdr.readexactly(tlen - 16)
                  writer.write(proto.decrypt(msg[2:] + rmsg))
                  await writer.drain()
          finally:
              writer.write_eof()

      async def encses():
          # plaintext side -> encrypted side
          try:
              while True:
                  # Leave room for the 16-byte tag so the ciphertext fits
                  # the 65535-byte Noise message limit and the 2-byte
                  # length field.
                  ptmsg = await reader.read(65535 - 16)
                  if not ptmsg:
                      # eof
                      return 'enc'

                  encmsg = proto.encrypt(ptmsg)
                  wrr.write(enclenfun(encmsg))
                  wrr.write(encmsg)
                  await wrr.drain()
          finally:
              wrr.write_eof()

      return await asyncio.gather(decses(), encses())
  # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  # Slightly modified to timeout
  def async_test(f, timeout=2):
      '''Decorator that runs the coroutine function f to completion.

      Each call runs f(*args, **kwargs) in a fresh event loop via
      asyncio.run().  It fails with asyncio.TimeoutError if the call
      takes longer than timeout seconds (2 by default).  The coroutine's
      return value is discarded, because unittest expects test methods
      to return None.
      '''
      import functools

      @functools.wraps(f)
      def wrapper(*args, **kwargs):
          # f is an ``async def`` function, so calling it already yields
          # a coroutine.  The old asyncio.coroutine wrapper was removed in
          # Python 3.11, and get_event_loop() without a running loop is
          # deprecated.
          asyncio.run(asyncio.wait_for(f(*args, **kwargs), timeout))

      return wrapper
  class Tests_misc(unittest.TestCase):
      def test_listensockstr(self):
          # Round-trip a few bytes in each direction over a unix socket
          # created by listensockstr and reached with connectsockstr.
          async def run(sockpath):
              sockstr = _makeunix(sockpath)
              accepted = asyncio.get_running_loop().create_future()

              def onaccept(reader, writer):
                  if not accepted.done():
                      accepted.set_result((reader, writer))

              server = await listensockstr(sockstr, onaccept)
              try:
                  creader, cwriter = await connectsockstr(sockstr)
                  sreader, swriter = await accepted

                  cwriter.write(b'ping')
                  self.assertEqual(await sreader.readexactly(4), b'ping')

                  swriter.write(b'pong')
                  self.assertEqual(await creader.readexactly(4), b'pong')

                  cwriter.close()
                  swriter.close()
              finally:
                  server.close()

          with tempfile.TemporaryDirectory() as d:
              sockpath = os.path.join(os.path.realpath(d), 'sock')
              asyncio.run(asyncio.wait_for(run(sockpath), 2))

      def test_genciphfun(self):
          enc, dec = _genciphfun(b'0' * 32, b'foobar')

          msg = b'this is a bunch of data'

          tb = enc(msg)
          self.assertEqual(len(tb), 2)

          self.assertEqual(len(msg), dec(tb + msg))

          for i in [ 20, 1384, 64000, 23839, 65535 ]:
              msg = os.urandom(i)
              self.assertEqual(len(msg), dec(enc(msg) + msg))
  156. def _asyncsockpair():
  157. '''Create a pair of sockets that are bound to each other.
  158. The function will return a tuple of two coroutine's, that
  159. each, when await'ed upon, will return the reader/writer pair.'''
  160. socka, sockb = socket.socketpair()
  161. return asyncio.open_connection(sock=socka), asyncio.open_connection(sock=sockb)
  class Tests(unittest.TestCase):
      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

      def tearDown(self):
          shutil.rmtree(self.basetempdir)
          self.tempdir = None
          self.basetempdir = None

      @async_test
      async def test_server(self):
          # Test is plumbed:
          # (reader, writer) -> servsock ->
          #    (rdr, wrr) NoiseForward (reader, writer) ->
          #    servptsock -> (ptreader, ptwriter)

          # Path that the server will sit on
          servsockpath = os.path.join(self.tempdir, 'servsock')
          servarg = _makeunix(servsockpath)

          # Path that the server will send pt data to
          servptpath = os.path.join(self.tempdir, 'servptsock')

          # Setup pt target listener
          pttarg = _makeunix(servptpath)

          # Resolved with the (reader, writer) of the plaintext connection
          # the forwarder opens, so the test can wait for it instead of
          # sleeping.
          ptsockfut = asyncio.get_running_loop().create_future()

          def ptsockaccept(reader, writer):
              if not ptsockfut.done():
                  ptsockfut.set_result((reader, writer))

          # Bind to pt listener
          lsock = await listensockstr(pttarg, ptsockaccept)

          nfs = []
          event = asyncio.Event()

          async def runnf(rdr, wrr):
              ptpair = asyncio.create_task(connectsockstr(pttarg))

              a = await NoiseForwarder('resp', (rdr, wrr), ptpair, priv_key=self.server_key_pair[1])

              nfs.append(a)
              event.set()

          # Setup server listener
          ssock = await listensockstr(servarg, runnf)

          try:
              # Connect to server
              reader, writer = await connectsockstr(servarg)

              # Create client
              proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
              proto.set_as_initiator()

              # Setup required keys
              proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
              proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])

              proto.start_handshake()

              # Send first message
              message = proto.write_message()
              self.assertEqual(len(message), _handshakelens[0])
              writer.write(message)

              # Get response
              respmsg = await reader.readexactly(_handshakelens[1])
              proto.read_message(respmsg)

              # Send final reply
              message = proto.write_message()
              self.assertEqual(len(message), _handshakelens[2])
              writer.write(message)

              # Make sure handshake has completed
              self.assertTrue(proto.handshake_finished)

              # generate the keys for lengths
              enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toresp')
              _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toinit')

              def sendenc(ptmsg):
                  # Frame one message the same way NoiseForwarder does.
                  encmsg = proto.encrypt(ptmsg)
                  writer.write(enclenfun(encmsg))
                  writer.write(encmsg)

              # write a test message
              ptmsg = b'this is a test message that should be a little in length'
              sendenc(ptmsg)

              # The forwarder connects to the pt listener only after the
              # handshake; wait for that connection to be accepted.
              ptreader, ptwriter = await ptsockfut

              # read the test message
              rptmsg = await ptreader.readexactly(len(ptmsg))
              self.assertEqual(rptmsg, ptmsg)

              # write a different message; readexactly waits for all of
              # it, so no extra synchronisation is needed
              ptmsg = os.urandom(2843)
              sendenc(ptmsg)

              # read the test message
              rptmsg = await ptreader.readexactly(len(ptmsg))
              self.assertEqual(rptmsg, ptmsg)

              # now try the other way
              ptmsg = os.urandom(912)
              ptwriter.write(ptmsg)

              # find out how much we need to read
              encmsg = await reader.readexactly(2 + 16)
              tlen = declenfun(encmsg)

              # read the rest of the message
              rencmsg = await reader.readexactly(tlen - 16)
              tmsg = encmsg[2:] + rencmsg
              rptmsg = proto.decrypt(tmsg)
              self.assertEqual(rptmsg, ptmsg)

              # shut down sending
              writer.write_eof()

              # so pt reader should be shut down
              self.assertEqual(await ptreader.read(1), b'')
              self.assertTrue(ptreader.at_eof())

              # shut down pt
              ptwriter.write_eof()

              # make sure the enc reader is eof
              self.assertEqual(await reader.read(1), b'')
              self.assertTrue(reader.at_eof())

              await event.wait()

              self.assertEqual(nfs[0], [ 'dec', 'enc' ])

              writer.close()
          finally:
              # Stop both listeners so their sockets are not leaked.
              ssock.close()
              lsock.close()

      @async_test
      async def test_serverclient(self):
          # plumbing:
          #
          # ptca -> ptcb NF client clsa -> clsb NF server ptsa -> ptsb
          #

          ptcsockapair, ptcsockbpair = _asyncsockpair()
          ptcareader, ptcawriter = await ptcsockapair
          # ptcsockbpair passed directly

          clssockapair, clssockbpair = _asyncsockpair()
          clsapair = await clssockapair
          clsbpair = await clssockbpair

          ptssockapair, ptssockbpair = _asyncsockpair()
          # ptssockapair passed directly
          ptsbreader, ptsbwriter = await ptssockbpair

          clientnf = asyncio.create_task(NoiseForwarder('init', clsapair, ptcsockbpair, priv_key=self.client_key_pair[1], pub_key=self.server_key_pair[0]))
          servnf = asyncio.create_task(NoiseForwarder('resp', clsbpair, ptssockapair, priv_key=self.server_key_pair[1]))

          # Client -> server messages of growing size (the last one is
          # larger than a single Noise frame), then one server -> client.
          for wr, rd, size in [
                  (ptcawriter, ptsbreader, 183),
                  (ptcawriter, ptsbreader, 2834),
                  (ptcawriter, ptsbreader, 103958),
                  (ptsbwriter, ptcareader, 103958),
          ]:
              msga = os.urandom(size)
              wr.write(msga)

              # make sure we get the same message
              self.assertEqual(msga, await rd.readexactly(len(msga)))

          # close down the pt writers, the rest should follow
          ptsbwriter.write_eof()
          ptcawriter.write_eof()

          self.assertEqual([ 'dec', 'enc' ], await clientnf)
          self.assertEqual([ 'dec', 'enc' ], await servnf)
  304. ptsbwriter.write(msga)
  305. # make sure we get the same message
  306. self.assertEqual(msga, await ptcareader.readexactly(len(msga)))
  307. # close down the pt writers, the rest should follow
  308. ptsbwriter.write_eof()
  309. ptcawriter.write_eof()
  310. self.assertEqual([ 'dec', 'enc' ], await clientnf)
  311. self.assertEqual([ 'dec', 'enc' ], await servnf)