diff --git a/twistednoise.py b/twistednoise.py index 4343ae4..e51fd3d 100644 --- a/twistednoise.py +++ b/twistednoise.py @@ -8,8 +8,11 @@ from twisted.internet import endpoints, reactor, defer, task from cryptography.hazmat.primitives.asymmetric import x448 from cryptography.hazmat.primitives import serialization -import twisted.internet.protocol import mock +import os.path +import shutil +import tempfile +import twisted.internet.protocol # Notes: # Using XK, so that the connecting party's identity is hidden and that the @@ -95,6 +98,11 @@ class TwistedNoiseServerFactory(Factory): class TNServerTest(unittest.TestCase): @defer.inlineCallbacks def setUp(self): + d = os.path.realpath(tempfile.mkdtemp()) + self.basetempdir = d + self.tempdir = os.path.join(d, 'subdir') + os.mkdir(self.tempdir) + self.server_key_pair = genkeypair() self.protos = [] self.connectionmade = defer.Deferred() @@ -113,21 +121,16 @@ class TNServerTest(unittest.TestCase): self.__tc.protos.append(r) return r - for i in range(10000, 20000): - ep = endpoints.TCP4ServerEndpoint(reactor, i, interface='127.0.0.1') - try: - lpobj = yield ep.listen(AccProtFactory(self)) - except Exception: # pragma: no cover - continue - break - else: # pragma: no cover - raise RuntimeError('all ports occupied') + sockpath = os.path.join(self.tempdir, 'clientsock') + ep = endpoints.UNIXServerEndpoint(reactor, sockpath) + lpobj = yield ep.listen(AccProtFactory(self)) self.testserv = ep self.listenportobj = lpobj - self.endpoint = 'tcp:host=127.0.0.1:port=%d' % i + self.endpoint = 'unix:path=%s' % sockpath + factory = TwistedNoiseServerFactory(server_key=self.server_key_pair[1], endpoint=self.endpoint) - self.proto = factory.buildProtocol(('127.0.0.1', 0)) + self.proto = factory.buildProtocol(None) self.tr = proto_helpers.StringTransport() self.proto.makeConnection(self.tr) @@ -136,6 +139,9 @@ class TNServerTest(unittest.TestCase): def tearDown(self): self.listenportobj.stopListening() + shutil.rmtree(self.basetempdir) + self.tempdir = None + @defer.inlineCallbacks def test_testprotocol(self): # Create client