An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

321 lines
8.1 KiB

  1. from noise.connection import NoiseConnection, Keypair
  2. from cryptography.hazmat.primitives.kdf.hkdf import HKDF
  3. from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
  4. from cryptography.hazmat.primitives import hashes
  5. from twistednoise import genkeypair
  6. from cryptography.hazmat.backends import default_backend
  7. import asyncio
  8. import os.path
  9. import shutil
  10. import socket
  11. import tempfile
  12. import threading
  13. import unittest
  # One cryptography backend object, shared by the HKDF and Cipher
  # objects that _genciphfun builds.
  _backend = default_backend()
  def _makeunix(path):
      '''Return the sockstr naming the unix domain socket at *path*.

      The result has the form 'unix:<path>', which is what
      connectsockstr and listensockstr expect.
      '''
      scheme = 'unix'
      return '{}:{}'.format(scheme, path)
  def _parsesockstr(sockstr):
      '''Split *sockstr* at its first colon into (proto, rest).

      Everything after the first ':' is returned untouched, so the
      rest may itself contain colons.  Raises ValueError when
      *sockstr* contains no colon at all.
      '''
      sep = sockstr.index(':')
      return sockstr[:sep], sockstr[sep + 1:]
  async def connectsockstr(sockstr):
      '''Open a stream connection to the endpoint described by *sockstr*.

      Only the 'unix' protocol is supported.  The text after 'unix:' is
      used verbatim as the path of the unix domain socket.

      Returns the (reader, writer) pair from asyncio.open_unix_connection.
      Raises ValueError for any other protocol, rather than silently
      treating that protocol's parameters as a filesystem path.
      '''
      proto, rem = _parsesockstr(sockstr)
      if proto != 'unix':
          raise ValueError('unsupported protocol %r in sockstr %r' %
              (proto, sockstr))
      return await asyncio.open_unix_connection(rem)
  async def listensockstr(sockstr, cb):
      '''Wrapper for asyncio.start_x_server.

      The format of sockstr is: 'proto:param=value[,param2=value2]'.
      If the proto has a default parameter, the value can be used
      directly, like: 'proto:value'.  This is only allowed when the
      value can unambiguously be determined not to be a param.

      The characters that define 'param' must be all lower case ascii
      characters and may contain an underscore.  The first character
      must not be an underscore.

      Supported protocols:
          unix:
              Default parameter is path.
              The path parameter specifies the path to the
              unix domain socket.  The path MUST start w/ a
              slash if it is used as a default parameter.

      NOTE: only the default-parameter form is implemented so far.
      Everything after 'unix:' is passed verbatim as the socket path,
      and 'param=value' lists are not parsed.

      cb is the client-connected callback handed to
      asyncio.start_unix_server.  Returns the server object that
      asyncio.start_unix_server produces.  Raises ValueError if proto
      is not 'unix'.
      '''
      proto, rem = _parsesockstr(sockstr)
      if proto != 'unix':
          raise ValueError('unsupported protocol %r in sockstr %r' %
              (proto, sockstr))
      return await asyncio.start_unix_server(cb, path=rem)
  # Byte lengths of the three Noise_XK_448_ChaChaPoly_SHA256 handshake
  # messages, in order:
  #   message 1 (initiator -> responder),
  #   message 2 (responder -> initiator),
  #   message 3 (initiator -> responder).
  # The marker below presumably regenerates this list via
  # makemessagelengths.py -- update it there rather than by hand.
  # !!python makemessagelengths.py
  _handshakelens = \
      [72, 72, 88]
  def _genciphfun(hash, ad):
      '''Build the length-prefix masking functions for one direction.

      A 32-byte AES key is derived from *hash* using HKDF-SHA256, with a
      fixed salt and *ad* as the HKDF info.  Different *ad* values
      (this file uses b'toresp' and b'toinit') therefore yield different
      keys.  At the call sites in this file *hash* is the Noise
      handshake hash.  The parameter name shadows the builtin but is
      kept for compatibility with callers.

      Each frame's 2-byte big-endian length is XORed with a 16-bit mask.
      The mask is taken from the AES-ECB encryption of the first 16
      bytes of the encrypted message that follows the length.

      Returns (encfun, decfun):

      encfun(data)
          *data* is the encrypted message.  Returns the two masked
          length bytes to send in front of it.  Raises OverflowError if
          len(data) does not fit in two bytes.
      decfun(data)
          *data* is the two masked length bytes followed by (at least)
          the first 16 bytes of the message.  Returns the total length
          of the encrypted message, including those 16 bytes.

      Both raise ValueError if fewer than 16 message bytes are supplied.
      '''
      hkdf = HKDF(algorithm=hashes.SHA256(), length=32,
          salt=b'asdoifjsldkjdsf', info=ad, backend=_backend)
      key = hkdf.derive(hash)
      cipher = Cipher(algorithms.AES(key), modes.ECB(),
          backend=_backend)
      enctor = cipher.encryptor()

      def _mask(block):
          # Feed exactly one AES block, so ECB returns a full 16 bytes
          # and nothing is left buffered in the shared encryptor.  A
          # shorter input used to give an empty result (mask 0) and
          # could leave a partial block behind to corrupt later calls.
          if len(block) != 16:
              raise ValueError('need 16 message bytes to mask the '
                  'length, got %d' % len(block))
          encbytes = enctor.update(block)
          # Use all 16 bits.  The former "& 0xff" masked only the low
          # byte, leaving the high byte of every length in the clear.
          return int.from_bytes(encbytes[:2], byteorder='big')

      def encfun(data):
          # Returns the two bytes for length
          val = len(data)
          return (val ^ _mask(data[:16])).to_bytes(length=2,
              byteorder='big')

      def decfun(data):
          # Takes the masked length plus the first 16 message bytes and
          # returns the total length of the encrypted message.
          val = int.from_bytes(data[:2], byteorder='big')
          return val ^ _mask(data[2:2 + 16])

      return encfun, decfun
  async def NoiseForwarder(mode, priv_key, rdrwrr, ptsockstr):
      '''Terminate a Noise connection and forward its plaintext.

      Performs the responder side of a Noise_XK_448_ChaChaPoly_SHA256
      handshake on the (reader, writer) pair *rdrwrr*, using the static
      private key *priv_key*.  It then connects to *ptsockstr* (see
      connectsockstr) and copies data in both directions:

      * decses: reads length-prefixed encrypted frames from the Noise
        side, decrypts them and writes the plaintext to *ptsockstr*.
      * encses: reads plaintext from *ptsockstr*, encrypts it and writes
        it to the Noise side as a length-prefixed frame.

      End-of-stream in either direction is passed on as a half-close
      (write_eof) to the other side.  *mode* is currently ignored; this
      function always acts as the responder.

      Returns the set of finished direction tasks from the final wait.
      Exceptions from either direction propagate to the caller and the
      other direction is cancelled.  Examples are a failed decryption
      or a frame truncated by EOF.  The plaintext connection is closed
      on exit.
      '''
      rdr, wrr = rdrwrr

      proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
      proto.set_keypair_from_private_bytes(Keypair.STATIC, priv_key)
      proto.set_as_responder()
      proto.start_handshake()

      # Responder side: read message 1, send message 2, read message 3.
      proto.read_message(await rdr.readexactly(_handshakelens[0]))
      wrr.write(proto.write_message())
      await wrr.drain()
      proto.read_message(await rdr.readexactly(_handshakelens[2]))

      if not proto.handshake_finished: # pragma: no cover
          raise RuntimeError('failed to finish handshake')

      # generate the keys for lengths
      _, declenfun = _genciphfun(proto.get_handshake_hash(), b'toresp')
      enclenfun, _ = _genciphfun(proto.get_handshake_hash(), b'toinit')

      reader, writer = await connectsockstr(ptsockstr)

      async def decses():
          '''Noise side -> plaintext side.'''
          while True:
              try:
                  # 2 masked length bytes plus the first 16 bytes of the
                  # frame, which declenfun needs to unmask the length.
                  msg = await rdr.readexactly(2 + 16)
              except asyncio.IncompleteReadError as exc:
                  if exc.partial:
                      # EOF in the middle of a frame header.
                      raise
                  # Clean EOF between frames: pass the half-close on.
                  if writer.can_write_eof():
                      writer.write_eof()
                  return 'dec'
              tlen = declenfun(msg)
              if tlen < 16:
                  raise ValueError('invalid frame length: %d' % tlen)
              rmsg = await rdr.readexactly(tlen - 16)
              writer.write(proto.decrypt(msg[2:] + rmsg))
              await writer.drain()

      async def encses():
          '''Plaintext side -> Noise side.'''
          while True:
              # 65535 - 16: the largest plaintext whose encrypted form
              # (16 bytes of overhead) still fits the 2-byte length.
              ptmsg = await reader.read(65535 - 16)
              if not ptmsg:
                  # EOF from the plaintext side.  Stop instead of
                  # looping forever on empty reads, and pass the
                  # half-close on to the Noise side.
                  if wrr.can_write_eof():
                      wrr.write_eof()
                  return 'enc'
              encmsg = proto.encrypt(ptmsg)
              wrr.write(enclenfun(encmsg))
              wrr.write(encmsg)
              await wrr.drain()

      # Wrap in tasks explicitly: asyncio.wait() no longer accepts bare
      # coroutines (deprecated in 3.8, removed in 3.11).
      tasks = [asyncio.ensure_future(decses()),
          asyncio.ensure_future(encses())]
      try:
          done, pending = await asyncio.wait(tasks,
              return_when=asyncio.FIRST_COMPLETED)
          for task in done:
              task.result()       # re-raise a failure from that side
          if pending:
              # asyncio.wait() rejects an empty set, so only wait when
              # the other direction is still running.
              done, _ = await asyncio.wait(pending)
              for task in done:
                  task.result()
          return done
      finally:
          # Cancelling is a no-op for tasks that have already finished.
          for task in tasks:
              task.cancel()
          writer.close()
  class TestListenSocket(unittest.TestCase):
      '''Round-trip test of listensockstr/connectsockstr on a unix socket.'''

      def test_listensockstr(self):
          tmpdir = tempfile.mkdtemp()
          try:
              sockstr = _makeunix(os.path.join(tmpdir, 'sock'))
              # A private loop, so the test does not depend on or leak
              # the default event loop.
              loop = asyncio.new_event_loop()
              try:
                  loop.run_until_complete(
                      asyncio.wait_for(self._echo_reversed(sockstr), 2))
              finally:
                  loop.close()
          finally:
              shutil.rmtree(tmpdir)

      async def _echo_reversed(self, sockstr):
          '''Listen on *sockstr*, connect to it and check one exchange.'''
          async def handler(reader, writer):
              data = await reader.readexactly(5)
              writer.write(data[::-1])
              await writer.drain()
              writer.close()

          server = await listensockstr(sockstr, handler)
          try:
              reader, writer = await connectsockstr(sockstr)
              writer.write(b'hello')
              self.assertEqual(await reader.readexactly(5), b'olleh')
              writer.close()
          finally:
              server.close()
              await server.wait_closed()
  # https://stackoverflow.com/questions/23033939/how-to-test-python-3-4-asyncio-code
  def async_test(f, timeout=2):
      '''Decorator that runs the coroutine function *f* synchronously.

      Each call creates a fresh event loop and runs f(*args, **kwargs)
      on it, bounded by *timeout* seconds (asyncio.TimeoutError on
      expiry).  The loop is closed afterwards.

      The coroutine function is called directly.  asyncio.coroutine(),
      which the original recipe wrapped it with, was removed in
      Python 3.11.
      '''
      import functools

      @functools.wraps(f)
      def wrapper(*args, **kwargs):
          loop = asyncio.new_event_loop()
          try:
              loop.run_until_complete(
                  asyncio.wait_for(f(*args, **kwargs), timeout))
          finally:
              loop.close()
      return wrapper
  class Tests_misc(unittest.TestCase):
      def test_genciphfun(self):
          '''The length decoder must recover what the encoder masked.'''
          enc, dec = _genciphfun(b'0' * 32, b'foobar')

          # A fixed, readable message first ...
          fixed = b'this is a bunch of data'
          prefix = enc(fixed)
          self.assertEqual(len(fixed), dec(prefix + fixed))

          # ... then random payloads, up to the 65535-byte maximum.
          sizes = (20, 1384, 64000, 23839, 65535)
          for size in sizes:
              payload = os.urandom(size)
              self.assertEqual(len(payload), dec(enc(payload) + payload))
  class Tests(unittest.TestCase):
      '''End-to-end test of NoiseForwarder acting as the responder.'''

      def setUp(self):
          # setup temporary directory
          d = os.path.realpath(tempfile.mkdtemp())
          self.basetempdir = d
          self.tempdir = os.path.join(d, 'subdir')
          os.mkdir(self.tempdir)

          # Generate key pairs
          self.server_key_pair = genkeypair()
          self.client_key_pair = genkeypair()

      def tearDown(self):
          shutil.rmtree(self.basetempdir)
          self.tempdir = None

      @async_test
      async def test_server(self):
          # Path that the server will sit on
          servsockpath = os.path.join(self.tempdir, 'servsock')
          servarg = _makeunix(servsockpath)

          # Path that the server will send pt data to.  This is a
          # separate variable; the original reused servsockpath here.
          ptsockpath = os.path.join(self.tempdir, 'servptsock')
          pttarg = _makeunix(ptsockpath)

          # Plaintext-side connections accepted from the forwarder.  The
          # event replaces sleep()-based syncing on their arrival.
          ptsock = []
          ptconnected = asyncio.Event()

          def ptsockaccept(reader, writer):
              ptsock.append((reader, writer))
              ptconnected.set()

          # Bind to pt listener
          lsock = await listensockstr(pttarg, ptsockaccept)

          nfs = []
          nfdone = asyncio.Event()

          async def runnf(rdr, wrr):
              a = await NoiseForwarder('resp', self.server_key_pair[1],
                  (rdr, wrr), pttarg)
              nfs.append(a)
              nfdone.set()

          # Setup server listener
          ssock = await listensockstr(servarg, runnf)
          try:
              # Connect to server
              reader, writer = await connectsockstr(servarg)

              # Create client
              proto = NoiseConnection.from_name(
                  b'Noise_XK_448_ChaChaPoly_SHA256')
              proto.set_as_initiator()

              # Setup required keys
              proto.set_keypair_from_private_bytes(Keypair.STATIC,
                  self.client_key_pair[1])
              proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC,
                  self.server_key_pair[0])

              proto.start_handshake()

              # Send first message
              message = proto.write_message()
              self.assertEqual(len(message), _handshakelens[0])
              writer.write(message)

              # Get response
              respmsg = await reader.readexactly(_handshakelens[1])
              proto.read_message(respmsg)

              # Send final reply
              message = proto.write_message()
              self.assertEqual(len(message), _handshakelens[2])
              writer.write(message)

              # Make sure handshake has completed
              self.assertTrue(proto.handshake_finished)

              # generate the keys for lengths
              enclenfun, _ = _genciphfun(proto.get_handshake_hash(),
                  b'toresp')
              _, declenfun = _genciphfun(proto.get_handshake_hash(),
                  b'toinit')

              async def sendframe(ptmsg):
                  '''Encrypt *ptmsg* and send it as one framed message.'''
                  encmsg = proto.encrypt(ptmsg)
                  writer.write(enclenfun(encmsg))
                  writer.write(encmsg)
                  await writer.drain()

              async def recvframe():
                  '''Read one framed message and return its plaintext.'''
                  # masked length plus the first 16 bytes of the frame
                  encmsg = await reader.readexactly(2 + 16)
                  tlen = declenfun(encmsg)
                  rencmsg = await reader.readexactly(tlen - 16)
                  return proto.decrypt(encmsg[2:] + rencmsg)

              # write a test message
              ptmsg = b'this is a test message that should be a little in length'
              await sendframe(ptmsg)

              # The forwarder connects to the pt listener only after the
              # handshake, so wait for that connection to appear.
              await ptconnected.wait()
              ptrdr, ptwrr = ptsock[0]

              # read the test message
              self.assertEqual(await ptrdr.readexactly(len(ptmsg)), ptmsg)

              # write a different message
              ptmsg = os.urandom(2843)
              await sendframe(ptmsg)
              self.assertEqual(await ptrdr.readexactly(len(ptmsg)), ptmsg)

              # now try the other way
              ptmsg = os.urandom(912)
              ptwrr.write(ptmsg)
              await ptwrr.drain()
              self.assertEqual(await recvframe(), ptmsg)

              # Shut everything down: half-close both ends so that each
              # direction of the forwarder sees EOF and it can return.
              writer.write_eof()
              ptwrr.write_eof()
              await nfdone.wait()
              self.assertEqual(len(nfs), 1)

              writer.close()
              ptwrr.close()
          finally:
              ssock.close()
              lsock.close()