Browse Source

make a basic version of this work...

tags/v0.1.0
John-Mark Gurney 5 years ago
parent
commit
2888bc3a73
1 changed files with 24 additions and 27 deletions
  1. +24
    -27
      twistednoise.py

+ 24
- 27
twistednoise.py View File

@@ -2,7 +2,7 @@ from twisted.trial import unittest
from twisted.test import proto_helpers from twisted.test import proto_helpers
from noise.connection import NoiseConnection, Keypair from noise.connection import NoiseConnection, Keypair
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
from twisted.internet import endpoints, reactor,defer
from twisted.internet import endpoints, reactor, defer, task
# XXX - shouldn't need to access the underlying primitives, but that's what # XXX - shouldn't need to access the underlying primitives, but that's what
# noiseprotocol module requires. # noiseprotocol module requires.
from cryptography.hazmat.primitives.asymmetric import x448 from cryptography.hazmat.primitives.asymmetric import x448
@@ -70,10 +70,9 @@ class TwistedNoiseServerProtocol(twisted.internet.protocol.Protocol):
else: else:
r = self.noise.decrypt(data) r = self.noise.decrypt(data)


self.endpoint.write(r)
self.endpoint.transport.write(r)


def proxyConnected(self, endpoint): def proxyConnected(self, endpoint):
print('pc')
self.endpoint = endpoint self.endpoint = endpoint
self.transport.resumeProducing() self.transport.resumeProducing()


@@ -98,6 +97,8 @@ class TNServerTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.server_key_pair = genkeypair() self.server_key_pair = genkeypair()
self.protos = [] self.protos = []
self.connectionmade = defer.Deferred()

class AccProtFactory(Factory): class AccProtFactory(Factory):
protocol = proto_helpers.AccumulatingProtocol protocol = proto_helpers.AccumulatingProtocol


@@ -105,19 +106,21 @@ class TNServerTest(unittest.TestCase):
self.__tc = tc self.__tc = tc
Factory.__init__(self) Factory.__init__(self)


protocolConnectionMade = self.connectionmade

def buildProtocol(self, addr): def buildProtocol(self, addr):
r = Factory.buildProtocol(addr)
self.__tc.append(r)
r = Factory.buildProtocol(self, addr)
self.__tc.protos.append(r)
return r return r


for i in range(10000, 20000): for i in range(10000, 20000):
ep = endpoints.TCP4ServerEndpoint(reactor, i)
ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1')
try: try:
lpobj = yield ep.listen(AccProtFactory(self)) lpobj = yield ep.listen(AccProtFactory(self))
except Exception:
except Exception: # pragma: no cover
continue continue
break break
else:
else: # pragma: no cover
raise RuntimeError('all ports occupied') raise RuntimeError('all ports occupied')


self.testserv = ep self.testserv = ep
@@ -133,8 +136,8 @@ class TNServerTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.listenportobj.stopListening() self.listenportobj.stopListening()


@mock.patch('twisted.internet.endpoints.clientFromString')
def test_testprotocol(self, cfs):
@defer.inlineCallbacks
def test_testprotocol(self):
# Create client # Create client
proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256') proto = NoiseConnection.from_name(b'Noise_XK_448_ChaChaPoly_SHA256')
proto.set_as_initiator() proto.set_as_initiator()
@@ -157,9 +160,6 @@ class TNServerTest(unittest.TestCase):
# And process it # And process it
proto.read_message(resp) proto.read_message(resp)


clientconnection = defer.Deferred()
cfs().connect.return_value = clientconnection

# Send second message # Send second message
message = proto.write_message() message = proto.write_message()
self.proto.dataReceived(message) self.proto.dataReceived(message)
@@ -171,21 +171,13 @@ class TNServerTest(unittest.TestCase):
# connection, otherwise no place to write the data # connection, otherwise no place to write the data
self.assertEqual(self.tr.producerState, 'paused') self.assertEqual(self.tr.producerState, 'paused')


# Make sure that clientFromString is called properly
cfs.assert_called_with(reactor, self.endpoint)

# And that it was connect'ed
cfs().connect.assert_called()
# Wait for the connection to be made
d = yield self.connectionmade


# and that ClientProxyFactory was called properly
args = cfs().connect.call_args.args
self.assertIsInstance(args[0], ClientProxyFactory)
self.assertIs(args[0].noiseproto, self.proto)
d = yield task.deferLater(reactor, .1, bool, 1)


# Simulate that a connection has happened
remoteend = proto_helpers.StringTransport()
remoteproto = args[0].buildProtocol(None)
remoteproto.makeConnection(remoteend)
# How to make this ready?
self.assertEqual(self.tr.producerState, 'producing')


# Encrypt the message # Encrypt the message
ptmsg = b'this is a test message' ptmsg = b'this is a test message'
@@ -194,4 +186,9 @@ class TNServerTest(unittest.TestCase):
# Feed it into the protocol # Feed it into the protocol
self.proto.dataReceived(encmsg) self.proto.dataReceived(encmsg)


self.assertEqual(remoteend.value(), ptmsg)
d = yield task.deferLater(reactor, .1, bool, 1)

clientend = self.protos[0]
self.assertEqual(clientend.data, ptmsg)

clientend.transport.loseConnection()

Loading…
Cancel
Save