A stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

327 lines
10 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
  15. __author__ = 'John-Mark Gurney'
  16. __copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
  17. __license__ = '2-clause BSD license'
  18. # Copyright 2019 John-Mark Gurney.
  19. # All rights reserved.
  20. #
  21. # Redistribution and use in source and binary forms, with or without
  22. # modification, are permitted provided that the following conditions
  23. # are met:
  24. # 1. Redistributions of source code must retain the above copyright
  25. # notice, this list of conditions and the following disclaimer.
  26. # 2. Redistributions in binary form must reproduce the above copyright
  27. # notice, this list of conditions and the following disclaimer in the
  28. # documentation and/or other materials provided with the distribution.
  29. #
  30. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  31. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  32. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  33. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  34. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  35. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  36. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  37. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  38. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  39. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  40. # SUCH DAMAGE.
  41. # Notes:
  42. # Using XK, so that the connecting party's identity is hidden and the
  43. # server's static key is known to the client in advance.
  44. #
  45. # Noise packets are 16 bytes + length of data
  46. #
  47. # Proposed method to hide message lengths:
  48. # Immediately after handshake completes, each side generates and sends
  49. # an n byte key that will be used for encrypting (algo tbd) their own
  50. # byte counts. The length field will be encrypted via
  51. # E(pktnum, key) XOR 2 byte length.
  52. #
  53. # Note that authenticating the message length is NOT needed. This is
  54. # because the noise message blocks themselves are authenticated. The
  55. # worst that could happen is that a larger read (64k) is done, and then
  56. # the connection aborts because of decryption failure.
  57. #
  58. def _makeunix(path):
  59. '''Make a properly formed unix path socket string.'''
  60. return 'unix:%s' % path
  61. def genkeypair():
  62. '''Generates a keypair, and returns a tuple of (public, private).
  63. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  64. key = x448.X448PrivateKey.generate()
  65. enc = serialization.Encoding.Raw
  66. pubformat = serialization.PublicFormat.Raw
  67. privformat = serialization.PrivateFormat.Raw
  68. encalgo = serialization.NoEncryption()
  69. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  70. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  71. return pub, priv
  72. class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
  73. '''This class acts as a Noise Protocol responder. The factory that
  74. creates this Protocol is required to have the properties server_key
  75. and endpoint.
  76. The server_key propery is the key for the server that the clients are
  77. required to have (due to Noise XK protocol used) to authenticate the
  78. server.
  79. The endpoint property contains the endpoint as a string that will be
  80. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  81. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  82. for information on how to use this property.'''
  83. def connectionMade(self):
  84. # Initialize Noise
  85. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  86. self.noise = noise
  87. if self.mode == 'resp':
  88. noise.set_as_responder()
  89. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)
  90. # Start Handshake
  91. noise.start_handshake()
  92. def encData(self, data):
  93. '''Receive plain text data, encrypt it, and send it down the
  94. wire.'''
  95. self.transport.write(self.noise.encrypt(data))
  96. def dataReceived(self, data):
  97. '''Receive encrypted data, and write it to the endpoint that
  98. was connected via the plaintextConnected method.'''
  99. if not self.noise.handshake_finished:
  100. self.noise.read_message(data)
  101. if not self.noise.handshake_finished:
  102. self.transport.write(self.noise.write_message())
  103. if self.noise.handshake_finished:
  104. self.handshakeFinished()
  105. else:
  106. r = self.noise.decrypt(data)
  107. self.endpoint.transport.write(r)
  108. def handshakeFinished(self): # pragma: no cover
  109. '''This function is called when the handshake has been
  110. completed. This is used to start data flowing, and to
  111. do any necessary connection work.'''
  112. raise NotImplementedError
  113. def plaintextConnected(self, endpoint):
  114. '''Connect the plain text endpoint to the factory. All the
  115. decrypted data will be written to this protocol,
  116. (specifically, it's transport).'''
  117. self.endpoint = endpoint
  118. self.transport.resumeProducing()
  119. class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
  120. mode = 'resp'
  121. def handshakeFinished(self):
  122. self.transport.pauseProducing()
  123. # start the connection to the endpoint
  124. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  125. epdef = ep.connect(ServerPTProxyFactory(self))
  126. epdef.addCallback(self.plaintextConnected)
  127. class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
  128. mode = 'init'
  129. class ServerPTProxyProtocol(twisted.internet.protocol.Protocol):
  130. def dataReceived(self, data):
  131. self.factory.noiseproto.encData(data)
  132. class ServerPTProxyFactory(Factory):
  133. protocol = ServerPTProxyProtocol
  134. def __init__(self, noiseproto):
  135. self.noiseproto = noiseproto
  136. class TwistedNoiseServerFactory(Factory):
  137. protocol = TwistedNoiseServerProtocol
  138. def __init__(self, priv_key, endpoint):
  139. self.priv_key = priv_key
  140. self.endpoint = endpoint
  141. class TNServerTest(unittest.TestCase):
  142. @defer.inlineCallbacks
  143. def setUp(self):
  144. # setup temporary directory
  145. d = os.path.realpath(tempfile.mkdtemp())
  146. self.basetempdir = d
  147. self.tempdir = os.path.join(d, 'subdir')
  148. os.mkdir(self.tempdir)
  149. # Generate key pairs
  150. self.server_key_pair = genkeypair()
  151. self.client_key_pair = genkeypair()
  152. # Server's PT client will be here
  153. self.protos = []
  154. self.connectionmade = defer.Deferred()
  155. class AccProtFactory(Factory):
  156. protocol = proto_helpers.AccumulatingProtocol
  157. def __init__(self, tc):
  158. self.__tc = tc
  159. Factory.__init__(self)
  160. protocolConnectionMade = self.connectionmade
  161. def buildProtocol(self, addr):
  162. r = Factory.buildProtocol(self, addr)
  163. self.__tc.protos.append(r)
  164. return r
  165. # Setup PT client endpoint
  166. sockpath = os.path.join(self.tempdir, 'servptsock')
  167. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  168. lpobj = yield ep.listen(AccProtFactory(self))
  169. self.testserv = ep
  170. self.listenportobj = lpobj
  171. self.endpoint = _makeunix(sockpath)
  172. # Setup server, and configure where to connect to.
  173. self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)
  174. @defer.inlineCallbacks
  175. def tearDown(self):
  176. d = yield self.listenportobj.stopListening()
  177. shutil.rmtree(self.basetempdir)
  178. self.tempdir = None
  179. @defer.inlineCallbacks
  180. def test_testserver(self):
  181. #
  182. # How this test is plumbed:
  183. #
  184. # proto (NoiseConnection) -> self.tr (StringTransport) ->
  185. # self.proto (TwistedNoiseServerProtocol) ->
  186. # self.proto.endpoint (ServerPTProxyProtocol) -> unix sock ->
  187. # self.protos[0] (AccumulatingProtocol)
  188. #
  189. # Generate a server protocol, and bind it to a string
  190. # transport for testing
  191. self.proto = self.servfactory.buildProtocol(None)
  192. self.tr = proto_helpers.StringTransport()
  193. self.proto.makeConnection(self.tr)
  194. # Create client
  195. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  196. proto.set_as_initiator()
  197. # Setup required keys
  198. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  199. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  200. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  201. proto.start_handshake()
  202. # Send first message
  203. message = proto.write_message()
  204. self.proto.dataReceived(message)
  205. # Get response
  206. resp = self.tr.value()
  207. self.tr.clear()
  208. # And process it
  209. proto.read_message(resp)
  210. # Send second message
  211. message = proto.write_message()
  212. self.proto.dataReceived(message)
  213. # assert handshake finished
  214. self.assertTrue(proto.handshake_finished)
  215. # Make sure incoming data is paused till we establish client
  216. # connection, otherwise no place to write the data
  217. self.assertEqual(self.tr.producerState, 'paused')
  218. # Wait for the connection to be made
  219. d = yield self.connectionmade
  220. d = yield task.deferLater(reactor, .1, bool, 1)
  221. # How to make this ready?
  222. self.assertEqual(self.tr.producerState, 'producing')
  223. # Encrypt the message
  224. ptmsg = b'this is a test message'
  225. encmsg = proto.encrypt(ptmsg)
  226. # Feed it into the protocol
  227. self.proto.dataReceived(encmsg)
  228. # XXX - fix
  229. # wait to pass it through
  230. d = yield task.deferLater(reactor, .1, bool, 1)
  231. # fetch remote end out
  232. clientend = self.protos[0]
  233. self.assertEqual(clientend.data, ptmsg)
  234. # send a message the other direction
  235. rptmsg = b'this is a different test message going the other way'
  236. clientend.transport.write(rptmsg)
  237. # XXX - fix
  238. # wait to pass it through
  239. d = yield task.deferLater(reactor, .1, bool, 1)
  240. # receive it and decrypt it
  241. resp = self.tr.value()
  242. self.assertEqual(proto.decrypt(resp), rptmsg)
  243. # clean up connection
  244. clientend.transport.loseConnection()
  245. @defer.inlineCallbacks
  246. def test_clientserver(self):
  247. # Path that the client "listener" sits on.
  248. cptsockpath = os.path.join(self.tempdir, 'clientptsock')
  249. # Path that the server sits on
  250. servsockpath = os.path.join(self.tempdir, 'servsock')
  251. servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
  252. servlpobj = yield servep.listen(self.servfactory)
  253. d = yield servlpobj.stopListening()