An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

195 lines
5.8 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import twisted.internet.protocol
  11. import mock
# Notes:
# The Noise XK pattern is used so that the connecting party's (client's)
# identity stays hidden, and so that the server's static public key must
# already be known to the client.
  15. def genkeypair():
  16. '''Generates a keypair, and returns a tuple of (public, private).
  17. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  18. key = x448.X448PrivateKey.generate()
  19. enc = serialization.Encoding.Raw
  20. pubformat = serialization.PublicFormat.Raw
  21. privformat = serialization.PrivateFormat.Raw
  22. encalgo = serialization.NoEncryption()
  23. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  24. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  25. return pub, priv
  26. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  27. '''This class acts as a Noise Protocol responder. The factory that
  28. creates this Protocol is required to have the properties server_key
  29. and endpoint.
  30. The server_key propery is the key for the server that the clients are
  31. required to have (due to Noise XK protocol used) to authenticate the
  32. server.
  33. The endpoint property contains the endpoint as a string that will be
  34. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  35. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  36. for information on how to use this property.'''
  37. def connectionMade(self):
  38. # Initialize Noise
  39. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  40. self.noise = noise
  41. noise.set_as_responder()
  42. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  43. # Start Handshake
  44. noise.start_handshake()
  45. def dataReceived(self, data):
  46. if not self.noise.handshake_finished:
  47. self.noise.read_message(data)
  48. if not self.noise.handshake_finished:
  49. self.transport.write(self.noise.write_message())
  50. if self.noise.handshake_finished:
  51. self.transport.pauseProducing()
  52. # start the connection to the endpoint
  53. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  54. epdef = ep.connect(ClientProxyFactory(self))
  55. epdef.addCallback(self.proxyConnected)
  56. else:
  57. r = self.noise.decrypt(data)
  58. self.endpoint.transport.write(r)
  59. def proxyConnected(self, endpoint):
  60. self.endpoint = endpoint
  61. self.transport.resumeProducing()
  62. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  63. pass
  64. class ClientProxyFactory(Factory):
  65. protocol = ClientProxyProtocol
  66. def __init__(self, noiseproto):
  67. self.noiseproto = noiseproto
  68. class TwistedNoiseServerFactory(Factory):
  69. protocol = TwistedNoiseServerProtocol
  70. def __init__(self, server_key, endpoint):
  71. self.server_key = server_key
  72. self.endpoint = endpoint
  73. class TNServerTest(unittest.TestCase):
  74. @defer.inlineCallbacks
  75. def setUp(self):
  76. self.server_key_pair = genkeypair()
  77. self.protos = []
  78. self.connectionmade = defer.Deferred()
  79. class AccProtFactory(Factory):
  80. protocol = proto_helpers.AccumulatingProtocol
  81. def __init__(self, tc):
  82. self.__tc = tc
  83. Factory.__init__(self)
  84. protocolConnectionMade = self.connectionmade
  85. def buildProtocol(self, addr):
  86. r = Factory.buildProtocol(self, addr)
  87. self.__tc.protos.append(r)
  88. return r
  89. for i in range(10000, 20000):
  90. ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1')
  91. try:
  92. lpobj = yield ep.listen(AccProtFactory(self))
  93. except Exception: # pragma: no cover
  94. continue
  95. break
  96. else: # pragma: no cover
  97. raise RuntimeError('all ports occupied')
  98. self.testserv = ep
  99. self.listenportobj = lpobj
  100. self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i
  101. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  102. self.proto = factory.buildProtocol(('127.0.0.1', 0))
  103. self.tr = proto_helpers.StringTransport()
  104. self.proto.makeConnection(self.tr)
  105. self.client_key_pair = genkeypair()
  106. def tearDown(self):
  107. self.listenportobj.stopListening()
  108. @defer.inlineCallbacks
  109. def test_testprotocol(self):
  110. # Create client
  111. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  112. proto.set_as_initiator()
  113. # Setup required keys
  114. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  115. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  116. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  117. proto.start_handshake()
  118. # Send first message
  119. message = proto.write_message()
  120. self.proto.dataReceived(message)
  121. # Get response
  122. resp = self.tr.value()
  123. self.tr.clear()
  124. # And process it
  125. proto.read_message(resp)
  126. # Send second message
  127. message = proto.write_message()
  128. self.proto.dataReceived(message)
  129. # assert handshake finished
  130. self.assertTrue(proto.handshake_finished)
  131. # Make sure incoming data is paused till we establish client
  132. # connection, otherwise no place to write the data
  133. self.assertEqual(self.tr.producerState, 'paused')
  134. # Wait for the connection to be made
  135. d = yield self.connectionmade
  136. d = yield task.deferLater(reactor, .1, bool, 1)
  137. # How to make this ready?
  138. self.assertEqual(self.tr.producerState, 'producing')
  139. # Encrypt the message
  140. ptmsg = b'this is a test message'
  141. encmsg = proto.encrypt(ptmsg)
  142. # Feed it into the protocol
  143. self.proto.dataReceived(encmsg)
  144. d = yield task.deferLater(reactor, .1, bool, 1)
  145. clientend = self.protos[0]
  146. self.assertEqual(clientend.data, ptmsg)
  147. clientend.transport.loseConnection()