An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

378 lines
12 KiB

  1. from noise.connection import NoiseConnection, Keypair
  2. from twisted.trial import unittest
  3. from twisted.test import proto_helpers
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
  15. __author__ = 'John-Mark Gurney'
  16. __copyright__ = 'Copyright 2019 John-Mark Gurney. All rights reserved.'
  17. __license__ = '2-clause BSD license'
  18. # Copyright 2019 John-Mark Gurney.
  19. # All rights reserved.
  20. #
  21. # Redistribution and use in source and binary forms, with or without
  22. # modification, are permitted provided that the following conditions
  23. # are met:
  24. # 1. Redistributions of source code must retain the above copyright
  25. # notice, this list of conditions and the following disclaimer.
  26. # 2. Redistributions in binary form must reproduce the above copyright
  27. # notice, this list of conditions and the following disclaimer in the
  28. # documentation and/or other materials provided with the distribution.
  29. #
  30. # THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
  31. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  32. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  33. # ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
  34. # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
  35. # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
  36. # OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  37. # HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
  38. # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
  39. # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
  40. # SUCH DAMAGE.
  41. # Notes:
# The XK pattern is used so that the connecting (initiator) party's identity
# is hidden, and so that the server's (responder's) static key is known to
# the client in advance.
  44. #
# The "client" and "server" class names refer to the Noise initiator and
# responder sides respectively, even though the client and the server each
# have a server component (listening on a socket to accept connections) and
# a client component (creating an outbound connection).
  49. #
# Each Noise transport message is 16 bytes (the authentication tag) longer
# than the data it carries.
  51. #
# Proposed method to hide message lengths:
# Immediately after the handshake completes, each side generates and sends
# an n-byte key that will be used to encrypt (algorithm TBD) its own
# byte counts. The length field will be encrypted as
# E(pktnum, key) XOR 2-byte length.
  57. #
# Note that authenticating the message length is NOT needed, because the
# Noise message blocks themselves are authenticated. The worst that could
# happen is that a larger read (64k) is done, and then the connection
# aborts because of a decryption failure.
  62. #
  63. def _makeunix(path):
  64. '''Make a properly formed unix path socket string.'''
  65. return 'unix:%s' % path
  66. def genkeypair():
  67. '''Generates a keypair, and returns a tuple of (public, private).
  68. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  69. key = x448.X448PrivateKey.generate()
  70. enc = serialization.Encoding.Raw
  71. pubformat = serialization.PublicFormat.Raw
  72. privformat = serialization.PrivateFormat.Raw
  73. encalgo = serialization.NoEncryption()
  74. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  75. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  76. return pub, priv
  77. class TwistedNoiseProtocol(twisted.internet.protocol.Protocol):
  78. '''This class acts as a Noise Protocol responder. The factory that
  79. creates this Protocol is required to have the properties server_key
  80. and endpoint.
  81. The server_key propery is the key for the server that the clients are
  82. required to have (due to Noise XK protocol used) to authenticate the
  83. server.
  84. The endpoint property contains the endpoint as a string that will be
  85. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  86. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  87. for information on how to use this property.'''
  88. def connectionMade(self):
  89. # Initialize Noise
  90. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  91. self.noise = noise
  92. if self.mode == 'resp':
  93. noise.set_as_responder()
  94. elif self.mode == 'init':
  95. noise.set_as_initiator()
  96. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.priv_key)
  97. if hasattr(self.factory, 'pub_key'):
  98. noise.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.factory.pub_key)
  99. # Start Handshake
  100. noise.start_handshake()
  101. def encData(self, data):
  102. '''Receive plain text data, encrypt it, and send it down the
  103. wire.'''
  104. self.transport.write(self.noise.encrypt(data))
  105. def dataReceived(self, data):
  106. '''Receive encrypted data, and write it to the endpoint that
  107. was connected via the plaintextConnected method.'''
  108. if not self.noise.handshake_finished:
  109. self.noise.read_message(data)
  110. if not self.noise.handshake_finished:
  111. self.transport.write(self.noise.write_message())
  112. if self.noise.handshake_finished:
  113. self.handshakeFinished()
  114. else:
  115. r = self.noise.decrypt(data)
  116. self.endpoint.transport.write(r)
  117. def handshakeFinished(self): # pragma: no cover
  118. '''This function is called when the handshake has been
  119. completed. This is used to start data flowing, and to
  120. do any necessary connection work.'''
  121. raise NotImplementedError
  122. def plaintextConnected(self, endpoint):
  123. '''Connect the plain text endpoint to the factory. All the
  124. decrypted data will be written to this protocol,
  125. (specifically, it's transport).'''
  126. self.endpoint = endpoint
  127. self.transport.resumeProducing()
  128. class TwistedNoiseServerProtocol(TwistedNoiseProtocol):
  129. mode = 'resp'
  130. def handshakeFinished(self):
  131. self.transport.pauseProducing()
  132. # start the connection to the endpoint
  133. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  134. epdef = ep.connect(ServerPTProxyFactory(self))
  135. epdef.addCallback(self.plaintextConnected)
  136. class TwistedNoiseServerFactory(Factory):
  137. protocol = TwistedNoiseServerProtocol
  138. def __init__(self, priv_key, endpoint):
  139. self.priv_key = priv_key
  140. self.endpoint = endpoint
  141. # Supporting classes for TwistedNoiseServer
  142. class PTProxyProtocol(twisted.internet.protocol.Protocol):
  143. '''Simple protocol then when data is received, encrypts the data
  144. w/ the connected noise protocol.'''
  145. def dataReceived(self, data):
  146. self.factory.noiseproto.encData(data)
  147. class ServerPTProxyFactory(Factory):
  148. protocol = PTProxyProtocol
  149. def __init__(self, noiseproto):
  150. self.noiseproto = noiseproto
  151. class TwistedNoiseClientProtocol(TwistedNoiseProtocol):
  152. mode = 'init'
  153. class ClientPTFactory(Factory):
  154. protocol = TwistedNoiseClientProtocol
  155. def __init__(self, priv_key, servpub, sockstr):
  156. self.priv_key = priv_key
  157. self.pub_key = servpub
  158. self.sockstr = sockstr
  159. class TNServerTest(unittest.TestCase):
  160. @defer.inlineCallbacks
  161. def setUp(self):
  162. # setup temporary directory
  163. d = os.path.realpath(tempfile.mkdtemp())
  164. self.basetempdir = d
  165. self.tempdir = os.path.join(d, 'subdir')
  166. os.mkdir(self.tempdir)
  167. # Generate key pairs
  168. self.server_key_pair = genkeypair()
  169. self.client_key_pair = genkeypair()
  170. # Server's PT client will be here
  171. self.protos = []
  172. self.connectionmade = defer.Deferred()
  173. class AccProtFactory(Factory):
  174. protocol = proto_helpers.AccumulatingProtocol
  175. def __init__(self, tc):
  176. self.__tc = tc
  177. Factory.__init__(self)
  178. protocolConnectionMade = self.connectionmade
  179. def buildProtocol(self, addr):
  180. r = Factory.buildProtocol(self, addr)
  181. self.__tc.protos.append(r)
  182. return r
  183. self.AccProtFactory = AccProtFactory
  184. # Setup PT client endpoint
  185. sockpath = os.path.join(self.tempdir, 'servptsock')
  186. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  187. lpobj = yield ep.listen(AccProtFactory(self))
  188. self.testserv = ep
  189. self.listenportobj = lpobj
  190. self.endpoint = _makeunix(sockpath)
  191. # Setup server, and configure where to connect to.
  192. self.servfactory = TwistedNoiseServerFactory(priv_key=self.server_key_pair[1], endpoint=self.endpoint)
  193. @defer.inlineCallbacks
  194. def tearDown(self):
  195. d = yield self.listenportobj.stopListening()
  196. shutil.rmtree(self.basetempdir)
  197. self.tempdir = None
    @defer.inlineCallbacks
    def test_testserver(self):
        '''Drive a full XK handshake and a round trip of data through
        TwistedNoiseServerProtocol, using a bare NoiseConnection as the
        initiator.'''
        #
        # How this test is plumbed:
        #
        # proto (NoiseConnection) -> self.tr (StringTransport) ->
        # self.proto (TwistedNoiseServerProtocol) ->
        # self.proto.endpoint (PTProxyProtocol) -> unix sock ->
        # self.protos[0] (AccumulatingProtocol)
        #
        # Generate a server protocol, and bind it to a string
        # transport for testing
        self.proto = self.servfactory.buildProtocol(None)
        self.tr = proto_helpers.StringTransport()
        self.proto.makeConnection(self.tr)
        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()
        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
        proto.start_handshake()
        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)
        # Get response (the server's reply handshake message)
        resp = self.tr.value()
        self.tr.clear()
        # And process it
        proto.read_message(resp)
        # Send second message (the initiator's second and final handshake
        # message, which completes the handshake on both sides)
        message = proto.write_message()
        self.proto.dataReceived(message)
        # assert handshake finished
        self.assertTrue(proto.handshake_finished)
        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')
        # Wait for the connection to be made
        d = yield self.connectionmade
        # The listening side has accepted, but the server's connect()
        # Deferred (which calls plaintextConnected and resumes producing)
        # may not have fired yet; presumably that is what this sleep covers.
        # NOTE(review): fixed 0.1s sleeps make this test timing-dependent.
        d = yield task.deferLater(reactor, .1, bool, 1)
        # How to make this ready?
        self.assertEqual(self.tr.producerState, 'producing')
        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)
        # Feed it into the protocol
        self.proto.dataReceived(encmsg)
        # XXX - fix
        # wait to pass it through
        d = yield task.deferLater(reactor, .1, bool, 1)
        # fetch remote end out
        clientend = self.protos[0]
        self.assertEqual(clientend.data, ptmsg)
        # send a message the other direction
        rptmsg = b'this is a different test message going the other way'
        clientend.transport.write(rptmsg)
        # XXX - fix
        # wait to pass it through
        d = yield task.deferLater(reactor, .1, bool, 1)
        # receive it and decrypt it
        resp = self.tr.value()
        self.assertEqual(proto.decrypt(resp), rptmsg)
        # Disabled micro-benchmark of encrypt() throughput (cnt blocks of
        # blksz random bytes); change False to True to run it by hand.
        if False:
            import time
            s = time.time()
            cnt = 40000
            blksz = 1024
            rnd = os.urandom(blksz)
            for i in range(0, cnt):
                proto.encrypt(rnd)
            e = time.time()
            print('%f MB/sec' % (1.0 * cnt * blksz / (e - s) / 1024 / 1024))
        # clean up connection
        clientend.transport.loseConnection()
  273. @defer.inlineCallbacks
  274. def test_clientserver(self):
  275. # Path that the client "listener" will sit on.
  276. cptsockpath = os.path.join(self.tempdir, 'clientptsock')
  277. # Path that the server will sit on
  278. servsockpath = os.path.join(self.tempdir, 'servsock')
  279. # Start up the server
  280. servep = endpoints.serverFromString(reactor, _makeunix(servsockpath))
  281. servlpobj = yield servep.listen(self.servfactory)
  282. # Start up the client half
  283. clientep = endpoints.serverFromString(reactor, _makeunix(cptsockpath))
  284. clientlpobj = yield clientep.listen(ClientPTFactory(self.client_key_pair[1], self.server_key_pair[0], _makeunix(servsockpath)))
  285. # Conenct to the client
  286. clptep = endpoints.clientFromString(reactor, _makeunix(cptsockpath))
  287. clptconobj = yield clptep.connect(self.AccProtFactory(self))
  288. # The client plain text connection
  289. clptproto = self.protos[-1]
  290. clptproto.transport.write('this is a test')
  291. # Clean up
  292. d = yield servlpobj.stopListening()
  293. d = yield clientlpobj.stopListening()