An stunnel-like program that uses the Noise protocol.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

204 lines
6.0 KiB

  1. from twisted.trial import unittest
  2. from twisted.test import proto_helpers
  3. from noise.connection import NoiseConnection, Keypair
  4. from twisted.internet.protocol import Factory
  5. from twisted.internet import endpoints, reactor, defer, task
  6. # XXX - shouldn't need to access the underlying primitives, but that's what
  7. # noiseprotocol module requires.
  8. from cryptography.hazmat.primitives.asymmetric import x448
  9. from cryptography.hazmat.primitives import serialization
  10. import mock
  11. import os.path
  12. import shutil
  13. import tempfile
  14. import twisted.internet.protocol
# Notes:
# Using the Noise XK pattern, so that the connecting party's identity is
# hidden and the server's static key is known to clients in advance.
  18. def genkeypair():
  19. '''Generates a keypair, and returns a tuple of (public, private).
  20. They are encoded as raw bytes, and sutible for use w/ Noise.'''
  21. key = x448.X448PrivateKey.generate()
  22. enc = serialization.Encoding.Raw
  23. pubformat = serialization.PublicFormat.Raw
  24. privformat = serialization.PrivateFormat.Raw
  25. encalgo = serialization.NoEncryption()
  26. pub = key.public_key().public_bytes(encoding=enc, format=pubformat)
  27. priv = key.private_bytes(encoding=enc, format=privformat, encryption_algorithm=encalgo)
  28. return pub, priv
  29. class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
  30. '''This class acts as a Noise Protocol responder. The factory that
  31. creates this Protocol is required to have the properties server_key
  32. and endpoint.
  33. The server_key propery is the key for the server that the clients are
  34. required to have (due to Noise XK protocol used) to authenticate the
  35. server.
  36. The endpoint property contains the endpoint as a string that will be
  37. used w/ clientFromString, see https://twistedmatrix.com/documents/current/api/twisted.internet.endpoints.html#clientFromString
  38. and https://twistedmatrix.com/documents/current/core/howto/endpoints.html#clients
  39. for information on how to use this property.'''
  40. def connectionMade(self):
  41. # Initialize Noise
  42. noise = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  43. self.noise = noise
  44. noise.set_as_responder()
  45. noise.set_keypair_from_private_bytes(Keypair.STATIC, self.factory.server_key)
  46. # Start Handshake
  47. noise.start_handshake()
  48. def dataReceived(self, data):
  49. if not self.noise.handshake_finished:
  50. self.noise.read_message(data)
  51. if not self.noise.handshake_finished:
  52. self.transport.write(self.noise.write_message())
  53. if self.noise.handshake_finished:
  54. self.transport.pauseProducing()
  55. # start the connection to the endpoint
  56. ep = endpoints.clientFromString(reactor, self.factory.endpoint)
  57. epdef = ep.connect(ClientProxyFactory(self))
  58. epdef.addCallback(self.proxyConnected)
  59. else:
  60. r = self.noise.decrypt(data)
  61. self.endpoint.transport.write(r)
  62. def proxyConnected(self, endpoint):
  63. self.endpoint = endpoint
  64. self.transport.resumeProducing()
  65. class ClientProxyProtocol(twisted.internet.protocol.Protocol):
  66. pass
  67. class ClientProxyFactory(Factory):
  68. protocol = ClientProxyProtocol
  69. def __init__(self, noiseproto):
  70. self.noiseproto = noiseproto
  71. class TwistedNoiseServerFactory(Factory):
  72. protocol = TwistedNoiseServerProtocol
  73. def __init__(self, server_key, endpoint):
  74. self.server_key = server_key
  75. self.endpoint = endpoint
  76. class TNServerTest(unittest.TestCase):
  77. @defer.inlineCallbacks
  78. def setUp(self):
  79. d = os.path.realpath(tempfile.mkdtemp())
  80. self.basetempdir = d
  81. self.tempdir = os.path.join(d, 'subdir')
  82. os.mkdir(self.tempdir)
  83. self.server_key_pair = genkeypair()
  84. self.protos = []
  85. self.connectionmade = defer.Deferred()
  86. class AccProtFactory(Factory):
  87. protocol = proto_helpers.AccumulatingProtocol
  88. def __init__(self, tc):
  89. self.__tc = tc
  90. Factory.__init__(self)
  91. protocolConnectionMade = self.connectionmade
  92. def buildProtocol(self, addr):
  93. r = Factory.buildProtocol(self, addr)
  94. self.__tc.protos.append(r)
  95. return r
  96. sockpath = os.path.join(self.tempdir, 'clientsock')
  97. ep = endpoints.UNIXServerEndpoint(reactor, sockpath)
  98. lpobj = yield ep.listen(AccProtFactory(self))
  99. self.testserv = ep
  100. self.listenportobj = lpobj
  101. self.endpoint = 'unix:path=%s' % sockpath
  102. factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint)
  103. self.proto = factory.buildProtocol(None)
  104. self.tr = proto_helpers.StringTransport()
  105. self.proto.makeConnection(self.tr)
  106. self.client_key_pair = genkeypair()
  107. def tearDown(self):
  108. self.listenportobj.stopListening()
  109. shutil.rmtree(self.basetempdir)
  110. self.tempdir = None
  111. @defer.inlineCallbacks
  112. def test_testprotocol(self):
  113. # Create client
  114. proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
  115. proto.set_as_initiator()
  116. # Setup required keys
  117. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  118. proto.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, self.server_key_pair[0])
  119. proto.set_keypair_from_private_bytes(Keypair.STATIC, self.client_key_pair[1])
  120. proto.start_handshake()
  121. # Send first message
  122. message = proto.write_message()
  123. self.proto.dataReceived(message)
  124. # Get response
  125. resp = self.tr.value()
  126. self.tr.clear()
  127. # And process it
  128. proto.read_message(resp)
  129. # Send second message
  130. message = proto.write_message()
  131. self.proto.dataReceived(message)
  132. # assert handshake finished
  133. self.assertTrue(proto.handshake_finished)
  134. # Make sure incoming data is paused till we establish client
  135. # connection, otherwise no place to write the data
  136. self.assertEqual(self.tr.producerState, 'paused')
  137. # Wait for the connection to be made
  138. d = yield self.connectionmade
  139. d = yield task.deferLater(reactor, .1, bool, 1)
  140. # How to make this ready?
  141. self.assertEqual(self.tr.producerState, 'producing')
  142. # Encrypt the message
  143. ptmsg = b'this is a test message'
  144. encmsg = proto.encrypt(ptmsg)
  145. # Feed it into the protocol
  146. self.proto.dataReceived(encmsg)
  147. # wait to pass it through
  148. d = yield task.deferLater(reactor, .1, bool, 1)
  149. # fetch it out
  150. clientend = self.protos[0]
  151. self.assertEqual(clientend.data, ptmsg)
  152. # clean up connection
  153. clientend.transport.loseConnection()