A stunnel-like program that utilizes the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

198 lines
6.0 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor,defer
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import twisted.internet.protocol
  11. import mock
  12. # Notes:
  13. # Using XK, so that the connecting party's identity is hidden and that the
  14. # server party's key is known.
  15. def genkeypair():
  16. '''Generates a keypair, and returns a tuple of (public, private).
  17. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  18. key = x448.X448PrivateKey.generate()
  19. enc = serialization.Encoding.Raw
  20. pubformat = serialization.PublicFormat.Raw
  21. privformat = serialization.PrivateFormat.Raw
  22. encalgo = serialization.NoEncryption()
  23. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  24. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  25. return pub, priv
  26. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  27. '''This class acts as a Noise Protocol responder. The factory that
  28. creates this Protocol is required to have the properties server_key
  29. and endpoint.
  30. The server_key propery is the key for the server that the clients are
  31. required to have (due to Noise XK protocol used) to authenticate the
  32. server.
  33. The endpoint property contains the endpoint as a string that will be
  34. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  35. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  36. for information on how to use this property.'''
  37. def connectionMade(self):
  38. # Initialize Noise
  39. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  40. self.noise = noise
  41. noise.set_as_responder()
  42. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  43. # Start Handshake
  44. noise.start_handshake()
  45. def dataReceived(self, data):
  46. if not self.noise.handshake_finished:
  47. self.noise.read_message(data)
  48. if not self.noise.handshake_finished:
  49. self.transport.write(self.noise.write_message())
  50. if self.noise.handshake_finished:
  51. self.transport.pauseProducing()
  52. # start the connection to the endpoint
  53. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  54. epdef = ep.connect(ClientProxyFactory(self))
  55. epdef.addCallback(self.proxyConnected)
  56. else:
  57. r = self.noise.decrypt(data)
  58. self.endpoint.write(r)
  59. def proxyConnected(self, endpoint):
  60. print('pc')
  61. self.endpoint = endpoint
  62. self.transport.resumeProducing()
  63. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  64. pass
  65. class ClientProxyFactory(Factory):
  66. protocol = ClientProxyProtocol
  67. def __init__(self, noiseproto):
  68. self.noiseproto = noiseproto
  69. class TwistedNoiseServerFactory(Factory):
  70. protocol = TwistedNoiseServerProtocol
  71. def __init__(self, server_key, endpoint):
  72. self.server_key = server_key
  73. self.endpoint = endpoint
  74. class TNServerTest(unittest.TestCase):
  75. @defer.inlineCallbacks
  76. def setUp(self):
  77. self.server_key_pair = genkeypair()
  78. self.protos = []
  79. class AccProtFactory(Factory):
  80. protocol = proto_helpers.AccumulatingProtocol
  81. def __init__(self, tc):
  82. self.__tc = tc
  83. Factory.__init__(self)
  84. def buildProtocol(self, addr):
  85. r = Factory.buildProtocol(addr)
  86. self.__tc.append(r)
  87. return r
  88. for i in range(10000, 20000):
  89. ep = endpoints.TCP4ServerEndpoint(reactor, i)
  90. try:
  91. lpobj = yield ep.listen(AccProtFactory(self))
  92. except Exception:
  93. continue
  94. break
  95. else:
  96. raise RuntimeError('all ports occupied')
  97. self.testserv = ep
  98. self.listenportobj = lpobj
  99. self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i
  100. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  101. self.proto = factory.buildProtocol(('127.0.0.1', 0))
  102. self.tr = proto_helpers.StringTransport()
  103. self.proto.makeConnection(self.tr)
  104. self.client_key_pair = genkeypair()
  105. def tearDown(self):
  106. self.listenportobj.stopListening()
    @mock.patch('twisted.internet.endpoints.clientFromString')
    def test_testprotocol(self, cfs):
        # Drives the server protocol by hand: a Noise initiator is created
        # here, the XK handshake messages are passed back and forth via
        # dataReceived and the StringTransport, and the outbound
        # connection is replaced by the clientFromString mock (cfs).

        # Create client
        proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
        proto.set_as_initiator()

        # Setup required keys
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
        proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
        # NOTE(review): same call as two lines up, looks redundant.
        proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])

        proto.start_handshake()

        # Send first message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # Get response
        resp = self.tr.value()
        self.tr.clear()

        # And process it
        proto.read_message(resp)

        # The Deferred that the mocked connect() hands to the server
        # protocol.
        clientconnection = defer.Deferred()
        cfs().connect.return_value = clientconnection

        # Send second message
        message = proto.write_message()
        self.proto.dataReceived(message)

        # assert handshake finished
        self.assertTrue(proto.handshake_finished)

        # Make sure incoming data is paused till we establish client
        # connection, otherwise no place to write the data
        self.assertEqual(self.tr.producerState, 'paused')

        # Make sure that clientFromString is called properly
        cfs.assert_called_with(reactor, self.endpoint)

        # And that it was connect'ed
        cfs().connect.assert_called()

        # and that ClientProxyFactory was called properly
        args = cfs().connect.call_args.args
        self.assertIsInstance(args[0], ClientProxyFactory)
        self.assertIs(args[0].noiseproto, self.proto)

        # Simulate that a connection has happened
        remoteend = proto_helpers.StringTransport()
        remoteproto = args[0].buildProtocol(None)
        remoteproto.makeConnection(remoteend)
        # NOTE(review): clientconnection is never fired in the part of the
        # test visible here, so proxyConnected has not been run by the time
        # the encrypted message is fed in below -- confirm whether a
        # clientconnection.callback(...) is missing.

        # Encrypt the message
        ptmsg = b'this is a test message'
        encmsg = proto.encrypt(ptmsg)

        # Feed it into the protocol
        self.proto.dataReceived(encmsg)

        # The decrypted plaintext should have reached the remote side
        self.assertEqual(remoteend.value(), ptmsg)