An stunnel-like program that utilizes the Noise protocol.
 
 

227 lines
6.7 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
  15. # Notes:
  16. # Using XK, so that the connecting party's identity is hidden and the
  17. # server party's key is known.
  18. def genkeypair():
  19. '''Generates a keypair, and returns a tuple of (public, private).
  20. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  21. key = x448.X448PrivateKey.generate()
  22. enc = serialization.Encoding.Raw
  23. pubformat = serialization.PublicFormat.Raw
  24. privformat = serialization.PrivateFormat.Raw
  25. encalgo = serialization.NoEncryption()
  26. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  27. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  28. return pub, priv
  29. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  30. '''This class acts as a Noise Protocol responder. The factory that
  31. creates this Protocol is required to have the properties server_key
  32. and endpoint.
  33. The server_key propery is the key for the server that the clients are
  34. required to have (due to Noise XK protocol used) to authenticate the
  35. server.
  36. The endpoint property contains the endpoint as a string that will be
  37. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  38. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  39. for information on how to use this property.'''
  40. def connectionMade(self):
  41. # Initialize Noise
  42. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  43. self.noise = noise
  44. noise.set_as_responder()
  45. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  46. # Start Handshake
  47. noise.start_handshake()
  48. def encData(self, data):
  49. self.transport.write(self.noise.encrypt(data))
  50. def dataReceived(self, data):
  51. if not self.noise.handshake_finished:
  52. self.noise.read_message(data)
  53. if not self.noise.handshake_finished:
  54. self.transport.write(self.noise.write_message())
  55. if self.noise.handshake_finished:
  56. self.transport.pauseProducing()
  57. # start the connection to the endpoint
  58. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  59. epdef = ep.connect(ClientProxyFactory(self))
  60. epdef.addCallback(self.proxyConnected)
  61. else:
  62. r = self.noise.decrypt(data)
  63. self.endpoint.transport.write(r)
  64. def proxyConnected(self, endpoint):
  65. self.endpoint = endpoint
  66. self.transport.resumeProducing()
  67. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  68. def dataReceived(self, data):
  69. self.factory.noiseproto.encData(data)
  70. class ClientProxyFactory(Factory):
  71. protocol = ClientProxyProtocol
  72. def __init__(self, noiseproto):
  73. self.noiseproto = noiseproto
  74. class TwistedNoiseServerFactory(Factory):
  75. protocol = TwistedNoiseServerProtocol
  76. def __init__(self, server_key, endpoint):
  77. self.server_key = server_key
  78. self.endpoint = endpoint
  79. class TNServerTest(unittest.TestCase):
  80. @defer.inlineCallbacks
  81. def setUp(self):
  82. d = os.path.realpath(tempfile.mkdtemp())
  83. self.basetempdir = d
  84. self.tempdir = os.path.join(d, 'subdir')
  85. os.mkdir(self.tempdir)
  86. self.server_key_pair = genkeypair()
  87. self.protos = []
  88. self.connectionmade = defer.Deferred()
  89. class AccProtFactory(Factory):
  90. protocol = proto_helpers.AccumulatingProtocol
  91. def __init__(self, tc):
  92. self.__tc = tc
  93. Factory.__init__(self)
  94. protocolConnectionMade = self.connectionmade
  95. def buildProtocol(self, addr):
  96. r = Factory.buildProtocol(self, addr)
  97. self.__tc.protos.append(r)
  98. return r
  99. sockpath = os.path.join(self.tempdir, 'clientsock')
  100. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  101. lpobj = yield ep.listen(AccProtFactory(self))
  102. self.testserv = ep
  103. self.listenportobj = lpobj
  104. self.endpoint = 'unix:path=%s' % sockpath
  105. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  106. self.proto = factory.buildProtocol(None)
  107. self.tr = proto_helpers.StringTransport()
  108. self.proto.makeConnection(self.tr)
  109. self.client_key_pair = genkeypair()
  110. def tearDown(self):
  111. self.listenportobj.stopListening()
  112. shutil.rmtree(self.basetempdir)
  113. self.tempdir = None
  114. @defer.inlineCallbacks
  115. def test_testprotocol(self):
  116. #
  117. # How this test is plumbed:
  118. #
  119. # proto (NoiseConnection) -> self.tr (StringTransport) ->
  120. # self.proto (TwistedNoiseServerProtocol) ->
  121. # self.proto.endpoint (ClientProxyProtocol) -> unix sock ->
  122. # self.protos[0] (AccumulatingProtocol)
  123. #
  124. # Create client
  125. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  126. proto.set_as_initiator()
  127. # Setup required keys
  128. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  129. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  130. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  131. proto.start_handshake()
  132. # Send first message
  133. message = proto.write_message()
  134. self.proto.dataReceived(message)
  135. # Get response
  136. resp = self.tr.value()
  137. self.tr.clear()
  138. # And process it
  139. proto.read_message(resp)
  140. # Send second message
  141. message = proto.write_message()
  142. self.proto.dataReceived(message)
  143. # assert handshake finished
  144. self.assertTrue(proto.handshake_finished)
  145. # Make sure incoming data is paused till we establish client
  146. # connection, otherwise no place to write the data
  147. self.assertEqual(self.tr.producerState, 'paused')
  148. # Wait for the connection to be made
  149. d = yield self.connectionmade
  150. d = yield task.deferLater(reactor, .1, bool, 1)
  151. # How to make this ready?
  152. self.assertEqual(self.tr.producerState, 'producing')
  153. # Encrypt the message
  154. ptmsg = b'this is a test message'
  155. encmsg = proto.encrypt(ptmsg)
  156. # Feed it into the protocol
  157. self.proto.dataReceived(encmsg)
  158. # wait to pass it through
  159. d = yield task.deferLater(reactor, .1, bool, 1)
  160. # fetch remote end out
  161. clientend = self.protos[0]
  162. self.assertEqual(clientend.data, ptmsg)
  163. # send a message the other direction
  164. rptmsg = b'this is a different test message going the other way'
  165. clientend.transport.write(rptmsg)
  166. # wait to pass it through
  167. d = yield task.deferLater(reactor, .1, bool, 1)
  168. # receive it and decrypt it
  169. resp = self.tr.value()
  170. self.assertEqual(proto.decrypt(resp), rptmsg)
  171. # clean up connection
  172. clientend.transport.loseConnection()