A stunnel-like program that utilizes the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

276 lines
8.9 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
  15. __author__ = 'John-Mark Gurney'
  16. __copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
  17. __license__ = '2-clause BSD license'
  18. # Copyright 2019 John-Mark Gurney.
  19. # All rights reserved.
  20. #
  21. # Redistribution and use in source and binary forms, with or without
  22. # modification, are permitted provided that the following conditions
  23. # are met:
  24. # 1. Redistributions of source code must retain the above copyright
  25. # notice, this list of conditions and the following disclaimer.
  26. # 2. Redistributions in binary form must reproduce the above copyright
  27. # notice, this list of conditions and the following disclaimer in the
  28. # documentation and/or other materials provided with the distribution.
  29. #
  30. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  31. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  32. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  33. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  34. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  35. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  36. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  37. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  38. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  39. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  40. # SUCH DAMAGE.
  41. # Notes:
  42. # Using XK, so that the connecting party's identity is hidden and the
  43. # server party's key is known.
  44. #
  45. # Noise packets are 16 bytes + length of data
  46. #
  47. # Proposed method to hide message lengths:
  48. # Immediately after handshake completes, each side generates and sends
  49. # an n byte key that will be used for encrypting (algo tbd) their own
  50. # byte counts. The length field will be encrypted via
  51. # E(pktnum, key) XOR 2 byte length.
  52. #
  53. # Note that authenticating the message length is NOT needed. This is
  54. # because the noise message blocks themselves are authenticated. The
  55. # worst that could happen is that a larger read (64k) is done, and then
  56. # the connection aborts because of decryption failure.
  57. #
  58. def genkeypair():
  59. '''Generates a keypair, and returns a tuple of (public, private).
  60. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  61. key = x448.X448PrivateKey.generate()
  62. enc = serialization.Encoding.Raw
  63. pubformat = serialization.PublicFormat.Raw
  64. privformat = serialization.PrivateFormat.Raw
  65. encalgo = serialization.NoEncryption()
  66. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  67. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  68. return pub, priv
  69. class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
  70. '''This class acts as a Noise Protocol responder. The factory that
  71. creates this Protocol is required to have the properties server_key
  72. and endpoint.
  73. The server_key propery is the key for the server that the clients are
  74. required to have (due to Noise XK protocol used) to authenticate the
  75. server.
  76. The endpoint property contains the endpoint as a string that will be
  77. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  78. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  79. for information on how to use this property.'''
  80. def connectionMade(self):
  81. # Initialize Noise
  82. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  83. self.noise = noise
  84. noise.set_as_responder()
  85. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  86. # Start Handshake
  87. noise.start_handshake()
  88. def encData(self, data):
  89. self.transport.write(self.noise.encrypt(data))
  90. def dataReceived(self, data):
  91. if not self.noise.handshake_finished:
  92. self.noise.read_message(data)
  93. if not self.noise.handshake_finished:
  94. self.transport.write(self.noise.write_message())
  95. if self.noise.handshake_finished:
  96. self.handshakeFinished()
  97. else:
  98. r = self.noise.decrypt(data)
  99. self.endpoint.transport.write(r)
  100. def handshakeFinished(self):
  101. raise NotImplementedError
  102. def plaintextConnected(self, endpoint):
  103. self.endpoint = endpoint
  104. self.transport.resumeProducing()
  105. class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
  106. def handshakeFinished(self):
  107. self.transport.pauseProducing()
  108. # start the connection to the endpoint
  109. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  110. epdef = ep.connect(ClientProxyFactory(self))
  111. epdef.addCallback(self.plaintextConnected)
  112. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  113. def dataReceived(self, data):
  114. self.factory.noiseproto.encData(data)
  115. class ClientProxyFactory(Factory):
  116. protocol = ClientProxyProtocol
  117. def __init__(self, noiseproto):
  118. self.noiseproto = noiseproto
  119. class TwistedNoiseServerFactory(Factory):
  120. protocol = TwistedNoiseServerProtocol
  121. def __init__(self, server_key, endpoint):
  122. self.server_key = server_key
  123. self.endpoint = endpoint
  124. class TNServerTest(unittest.TestCase):
  125. @defer.inlineCallbacks
  126. def setUp(self):
  127. d = os.path.realpath(tempfile.mkdtemp())
  128. self.basetempdir = d
  129. self.tempdir = os.path.join(d, 'subdir')
  130. os.mkdir(self.tempdir)
  131. self.server_key_pair = genkeypair()
  132. self.protos = []
  133. self.connectionmade = defer.Deferred()
  134. class AccProtFactory(Factory):
  135. protocol = proto_helpers.AccumulatingProtocol
  136. def __init__(self, tc):
  137. self.__tc = tc
  138. Factory.__init__(self)
  139. protocolConnectionMade = self.connectionmade
  140. def buildProtocol(self, addr):
  141. r = Factory.buildProtocol(self, addr)
  142. self.__tc.protos.append(r)
  143. return r
  144. sockpath = os.path.join(self.tempdir, 'clientsock')
  145. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  146. lpobj = yield ep.listen(AccProtFactory(self))
  147. self.testserv = ep
  148. self.listenportobj = lpobj
  149. self.endpoint = 'unix:path=%s' % sockpath
  150. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  151. self.proto = factory.buildProtocol(None)
  152. self.tr = proto_helpers.StringTransport()
  153. self.proto.makeConnection(self.tr)
  154. self.client_key_pair = genkeypair()
  155. def tearDown(self):
  156. self.listenportobj.stopListening()
  157. shutil.rmtree(self.basetempdir)
  158. self.tempdir = None
  159. @defer.inlineCallbacks
  160. def test_testserver(self):
  161. #
  162. # How this test is plumbed:
  163. #
  164. # proto (NoiseConnection) -> self.tr (StringTransport) ->
  165. # self.proto (TwistedNoiseServerProtocol) ->
  166. # self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
  167. # self.protos[0] (AccumulatingProtocol)
  168. #
  169. # Create client
  170. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  171. proto.set_as_initiator()
  172. # Setup required keys
  173. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  174. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  175. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  176. proto.start_handshake()
  177. # Send first message
  178. message = proto.write_message()
  179. self.proto.dataReceived(message)
  180. # Get response
  181. resp = self.tr.value()
  182. self.tr.clear()
  183. # And process it
  184. proto.read_message(resp)
  185. # Send second message
  186. message = proto.write_message()
  187. self.proto.dataReceived(message)
  188. # assert handshake finished
  189. self.assertTrue(proto.handshake_finished)
  190. # Make sure incoming data is paused till we establish client
  191. # connection, otherwise no place to write the data
  192. self.assertEqual(self.tr.producerState, 'paused')
  193. # Wait for the connection to be made
  194. d = yield self.connectionmade
  195. d = yield task.deferLater(reactor, .1, bool, 1)
  196. # How to make this ready?
  197. self.assertEqual(self.tr.producerState, 'producing')
  198. # Encrypt the message
  199. ptmsg = b'this is a test message'
  200. encmsg = proto.encrypt(ptmsg)
  201. # Feed it into the protocol
  202. self.proto.dataReceived(encmsg)
  203. # wait to pass it through
  204. d = yield task.deferLater(reactor, .1, bool, 1)
  205. # fetch remote end out
  206. clientend = self.protos[0]
  207. self.assertEqual(clientend.data, ptmsg)
  208. # send a message the other direction
  209. rptmsg = b'this is a different test message going the other way'
  210. clientend.transport.write(rptmsg)
  211. # wait to pass it through
  212. d = yield task.deferLater(reactor, .1, bool, 1)
  213. # receive it and decrypt it
  214. resp = self.tr.value()
  215. self.assertEqual(proto.decrypt(resp), rptmsg)
  216. # clean up connection
  217. clientend.transport.loseConnection()