diff --git a/tests/helpers.py b/tests/helpers.py index a379c56..cd9ca66 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,8 +1,14 @@ import asyncio +import aiohttp import contextlib +import gc +import os import socket +import ssl +import struct +import threading from unittest import mock -import gc +from aiohttp.server import ServerHttpProtocol try: from asyncio import ensure_future except ImportError: @@ -56,3 +62,181 @@ def fake_socks_srv(loop, write_buff): srv.close() loop.run_until_complete(srv.wait_closed()) gc.collect() + + +@contextlib.contextmanager +def fake_socks4_srv(loop): + port = find_unused_port() + transports = [] + futures = [] + + class Socks4Protocol(asyncio.StreamReaderProtocol): + def __init__(self, _loop): + self._loop = _loop + reader = asyncio.StreamReader(loop=self._loop) + super().__init__(reader, client_connected_cb=self.negotiate, + loop=self._loop) + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + @asyncio.coroutine + def negotiate(self, reader, writer): + writer.write(b'\x00\x5a\x04W\x01\x01\x01\x01') + + data = yield from reader.read(9) + + dst_port = struct.unpack('>H', data[2:4])[0] + dst_addr = data[4:8] + + if data[-1] != 0x00: + while True: + byte = yield from reader.read(1) + if byte == 0x00: + break + + if dst_addr == b'\x00\x00\x00\x01': + dst_addr = bytearray() + + while True: + byte = yield from reader.read(1) + if byte == 0x00: + break + dst_addr.append(byte) + else: + dst_addr = socket.inet_ntoa(dst_addr) + + cl_reader, cl_writer = yield from asyncio.open_connection( + host=dst_addr, port=dst_port, loop=self._loop + ) + transports.append(cl_writer) + + cl_fut = ensure_future( + self.retranslator(reader, cl_writer), loop=self._loop) + dst_fut = ensure_future( + self.retranslator(cl_reader, writer), loop=self._loop) + futures.append(cl_fut) + futures.append(dst_fut) + + @asyncio.coroutine + def retranslator(self, reader, writer): + data = bytearray() + while True: + try: + byte = yield from reader.read(1) + if not byte: + break + data.append(byte[0]) + writer.write(byte) + yield from writer.drain() + except: + break + + def run(_fut): + thread_loop = asyncio.new_event_loop() + asyncio.set_event_loop(thread_loop) + + srv_coroutine = thread_loop.create_server( + lambda: Socks4Protocol(thread_loop), '127.0.0.1', port) + srv = thread_loop.run_until_complete(srv_coroutine) + + waiter = asyncio.Future(loop=thread_loop) + loop.call_soon_threadsafe( + _fut.set_result, (thread_loop, waiter)) + + try: + thread_loop.run_until_complete(waiter) + finally: + # close opened transports + for tr in transports: + tr.close() + for ft in futures: + if not ft.done(): + ft.set_result(1) + + srv.close() + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = asyncio.Future(loop=loop) + srv_thread = threading.Thread(target=run, args=(fut,)) + srv_thread.start() + + _thread_loop, _waiter = loop.run_until_complete(fut) + + yield port + _thread_loop.call_soon_threadsafe(_waiter.set_result, None) + srv_thread.join() + + +@contextlib.contextmanager +def http_srv(loop, *, listen_addr=('127.0.0.1', 0), use_ssl=False): + transports = [] + + class TestHttpServer(ServerHttpProtocol): + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + @asyncio.coroutine + def handle_request(self, message, payload): + response = aiohttp.Response(self.writer, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(_fut): + thread_loop = asyncio.new_event_loop() + asyncio.set_event_loop(thread_loop) + + host, port = listen_addr + + srv_coroutine = thread_loop.create_server( + lambda: TestHttpServer(), host, port, ssl=sslcontext) + srv = thread_loop.run_until_complete(srv_coroutine) + + waiter = asyncio.Future(loop=thread_loop) + loop.call_soon_threadsafe( + _fut.set_result, (thread_loop, waiter, + srv.sockets[0].getsockname())) + + try: + thread_loop.run_until_complete(waiter) + finally: + # close opened transports + for tr in transports: + tr.close() + + srv.close() + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = asyncio.Future(loop=loop) + srv_thread = threading.Thread(target=run, args=(fut,)) + srv_thread.start() + + _thread_loop, _waiter, _addr = loop.run_until_complete(fut) + + url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', *_addr) + + yield url + _thread_loop.call_soon_threadsafe(_waiter.set_result, None) + srv_thread.join() diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 0000000..0acd5fc --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- \ No newline at end of file diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 0000000..a75ac3c --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- \ No newline at end of file