|
- import asyncio
- import aiohttp
- import contextlib
- import gc
- import os
- import socket
- import ssl
- import struct
- import threading
- from unittest import mock
- from aiohttp.server import ServerHttpProtocol
try:
    from asyncio import ensure_future
except ImportError:
    # Very old asyncio releases only expose this helper under the name
    # ``async``.  Writing ``asyncio.async`` is a SyntaxError on every
    # interpreter where ``async`` is a reserved word, which would make the
    # whole module unparsable there even though this branch never runs;
    # looking the attribute up by string keeps the file importable everywhere.
    ensure_future = getattr(asyncio, 'async')
-
-
def fake_coroutine(return_value):
    """Return a ``Mock`` whose calls produce a coroutine for *return_value*.

    Each call of the returned mock records its arguments (as any ``Mock``
    does) and hands back a fresh generator-based coroutine.  Driving that
    coroutine (``yield from``, ``await`` or an event loop) either returns
    *return_value* or, when *return_value* is an ``Exception`` instance,
    raises it.  The exception is raised when the coroutine is driven, not
    when the mock is called.

    :param return_value: value to return, or exception instance to raise.
    :return: ``unittest.mock.Mock`` wrapping the coroutine function.
    """
    def coro(*args, **kwargs):
        if isinstance(return_value, Exception):
            raise return_value
        return return_value
        # Unreachable on purpose: the mere presence of ``yield`` turns
        # ``coro`` into a generator function, so no decorator has to wrap it.
        yield

    try:
        # Marks the generator as awaitable; does not depend on
        # ``asyncio.coroutine``, which newer Python versions no longer ship.
        from types import coroutine as mark_coroutine
    except ImportError:
        # Interpreters without ``types.coroutine`` still have the asyncio one.
        mark_coroutine = asyncio.coroutine

    return mock.Mock(side_effect=mark_coroutine(coro))
-
-
def find_unused_port():
    """Return a TCP port number that was free on 127.0.0.1 a moment ago.

    A throw-away socket is bound to port 0 so the OS picks a free port; the
    socket is closed again before returning, so the port is only *likely*
    to still be free when the caller binds it.

    :return: the port number as ``int``.
    """
    # ``with`` guarantees the probe socket is closed even if bind() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(('127.0.0.1', 0))
        return probe.getsockname()[1]
-
-
@contextlib.contextmanager
def fake_socks_srv(loop, write_buff):
    """Serve a canned reply on 127.0.0.1 for the duration of a ``with`` block.

    The server runs on *loop* and answers every chunk of received data by
    writing *write_buff* back on the same connection.

    :param loop: event loop used to create and later shut down the server;
        it must not be running when the context is entered or left, because
        ``run_until_complete`` is called on it.
    :param write_buff: bytes written in response to each received chunk.
    :return: context manager yielding the port the server listens on.
    """
    transports = []

    class SocksPrimitiveProtocol(asyncio.Protocol):
        _transport = None

        def connection_made(self, transport):
            self._transport = transport
            # Remembered so the connection can be closed on teardown.
            transports.append(transport)

        def data_received(self, data):
            self._transport.write(write_buff)

    # Bind to port 0 and ask the listening socket for the real port: this
    # avoids the window in which a separately probed "free" port could be
    # taken by someone else before the server binds it.
    srv = loop.run_until_complete(
        loop.create_server(SocksPrimitiveProtocol, '127.0.0.1', 0))
    port = srv.sockets[0].getsockname()[1]

    try:
        yield port
    finally:
        # Runs even when the with-body raises, so nothing is left listening.
        for tr in transports:
            tr.close()

        srv.close()
        loop.run_until_complete(srv.wait_closed())
        gc.collect()
-
-
@contextlib.contextmanager
def fake_socks4_srv(loop):
    """Run a minimal SOCKS4/SOCKS4a relay in a background thread.

    The relay listens on 127.0.0.1, owns a private event loop inside its
    thread, sends a fixed "request granted" reply as soon as a client
    connects, parses the client's request, connects to the requested
    destination and then copies bytes in both directions.

    :param loop: the caller's event loop; used only to wait for the server
        thread to report that it is listening (``run_until_complete`` is
        called on it, so it must not already be running).
    :return: context manager yielding the port the relay listens on.
    :raises Exception: whatever the thread hit while creating the server.
    """
    transports = []
    futures = []

    class Socks4Protocol(asyncio.StreamReaderProtocol):
        def __init__(self, _loop):
            self._loop = _loop
            reader = asyncio.StreamReader(loop=self._loop)
            super().__init__(reader, client_connected_cb=self.negotiate,
                             loop=self._loop)

        def connection_made(self, transport):
            transports.append(transport)
            super().connection_made(transport)

        @asyncio.coroutine
        def read_until_null(self, reader):
            """Read up to (and consume) the next NUL byte; stop at EOF too."""
            collected = bytearray()
            while True:
                byte = yield from reader.read(1)
                # ``read`` returns bytes, so compare against bytes; an empty
                # result means EOF and must end the loop as well.
                if not byte or byte == b'\x00':
                    return bytes(collected)
                collected.extend(byte)

        @asyncio.coroutine
        def negotiate(self, reader, writer):
            # The canned reply is sent up front, before the request is read.
            writer.write(b'\x00\x5a\x04W\x01\x01\x01\x01')

            # Fixed header (8 bytes) plus the first byte of the user id,
            # which is the terminating NUL when the user id is empty.
            data = yield from reader.readexactly(9)

            dst_port = struct.unpack('>H', data[2:4])[0]
            dst_addr = data[4:8]

            if data[-1] != 0x00:
                # Non-empty user id: discard the rest of it.
                yield from self.read_until_null(reader)

            if dst_addr == b'\x00\x00\x00\x01':
                # SOCKS4a: a NUL-terminated host name follows the user id.
                raw_host = yield from self.read_until_null(reader)
                dst_addr = raw_host.decode('utf-8')
            else:
                dst_addr = socket.inet_ntoa(dst_addr)

            # No explicit ``loop=``: the thread installs its loop as the
            # current one before serving, and the keyword is not accepted
            # by every asyncio version.
            cl_reader, cl_writer = yield from asyncio.open_connection(
                host=dst_addr, port=dst_port)
            transports.append(cl_writer)

            cl_fut = ensure_future(
                self.retranslator(reader, cl_writer), loop=self._loop)
            dst_fut = ensure_future(
                self.retranslator(cl_reader, writer), loop=self._loop)
            futures.extend((cl_fut, dst_fut))

        @asyncio.coroutine
        def retranslator(self, reader, writer):
            """Copy bytes from *reader* to *writer* until EOF or an error."""
            while True:
                try:
                    byte = yield from reader.read(1)
                    if not byte:
                        break
                    writer.write(byte)
                    yield from writer.drain()
                except Exception:
                    # Best effort: any I/O failure simply ends the relay.
                    break

    def run(_fut):
        thread_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(thread_loop)

        try:
            # Port 0 lets the OS choose; the real port is reported below.
            srv = thread_loop.run_until_complete(thread_loop.create_server(
                lambda: Socks4Protocol(thread_loop), '127.0.0.1', 0))
        except Exception as exc:
            # Report the failure instead of leaving the caller waiting
            # forever on a future nobody will complete.
            thread_loop.close()
            loop.call_soon_threadsafe(_fut.set_exception, exc)
            return

        waiter = asyncio.Future(loop=thread_loop)
        loop.call_soon_threadsafe(
            _fut.set_result,
            (thread_loop, waiter, srv.sockets[0].getsockname()[1]))

        try:
            thread_loop.run_until_complete(waiter)
        finally:
            srv.close()
            # close opened transports
            for tr in transports:
                tr.close()
            # The relay futures are Tasks: they are stopped by cancelling
            # them and letting the loop process the cancellation, not by
            # forcing a result onto them.
            pending = [ft for ft in futures if not ft.done()]
            for ft in pending:
                ft.cancel()
            if pending:
                thread_loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True))

            thread_loop.close()
            gc.collect()

    fut = asyncio.Future(loop=loop)
    srv_thread = threading.Thread(target=run, args=(fut,))
    srv_thread.start()

    try:
        _thread_loop, _waiter, port = loop.run_until_complete(fut)
    except Exception:
        # Startup failed inside the thread; it has already returned.
        srv_thread.join()
        raise

    try:
        yield port
    finally:
        # Always stop and join the thread, even if the with-body raised.
        _thread_loop.call_soon_threadsafe(_waiter.set_result, None)
        srv_thread.join()
-
-
@contextlib.contextmanager
def http_srv(loop, *, listen_addr=('127.0.0.1', 0), use_ssl=False):
    """Run a tiny HTTP(S) server in a background thread for a ``with`` block.

    Every request is answered with status 200 and the plain-text body
    ``Test message``.  The server owns a private event loop inside its
    thread.

    :param loop: the caller's event loop; used only to wait for the server
        thread to report that it is listening (``run_until_complete`` is
        called on it, so it must not already be running).
    :param listen_addr: ``(host, port)`` to bind; port 0 lets the OS choose.
    :param use_ssl: serve over TLS using ``sample.key`` / ``sample.crt``
        from the ``../tests`` directory relative to this file.
    :return: context manager yielding the base URL of the server.
    :raises Exception: whatever the thread hit while creating the server.
    """
    transports = []

    class TestHttpServer(ServerHttpProtocol):

        def connection_made(self, transport):
            # Remembered so the connection can be closed on teardown.
            transports.append(transport)
            super().connection_made(transport)

        @asyncio.coroutine
        def handle_request(self, message, payload):
            response = aiohttp.Response(self.writer, 200, message.version)

            text = b'Test message'
            response.add_header('Content-type', 'text/plain')
            response.add_header('Content-length', str(len(text)))
            response.send_headers()
            response.write(text)
            response.write_eof()

    if use_ssl:
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        keyfile = os.path.join(here, 'sample.key')
        certfile = os.path.join(here, 'sample.crt')
        sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        sslcontext.load_cert_chain(certfile, keyfile)
    else:
        sslcontext = None

    def run(_fut):
        thread_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(thread_loop)

        host, port = listen_addr

        try:
            srv = thread_loop.run_until_complete(thread_loop.create_server(
                TestHttpServer, host, port, ssl=sslcontext))
        except Exception as exc:
            # Report the failure (e.g. address already in use) instead of
            # leaving the caller waiting forever on the handshake future.
            thread_loop.close()
            loop.call_soon_threadsafe(_fut.set_exception, exc)
            return

        waiter = asyncio.Future(loop=thread_loop)
        loop.call_soon_threadsafe(
            _fut.set_result, (thread_loop, waiter,
                              srv.sockets[0].getsockname()))

        try:
            thread_loop.run_until_complete(waiter)
        finally:
            # close opened transports
            for tr in transports:
                tr.close()

            srv.close()
            thread_loop.close()
            gc.collect()

    fut = asyncio.Future(loop=loop)
    srv_thread = threading.Thread(target=run, args=(fut,))
    srv_thread.start()

    try:
        _thread_loop, _waiter, _addr = loop.run_until_complete(fut)
    except Exception:
        # Startup failed inside the thread; it has already returned.
        srv_thread.join()
        raise

    url = '{}://{}:{}'.format(
        'https' if use_ssl else 'http', *_addr)

    try:
        yield url
    finally:
        # Always stop and join the thread, even if the with-body raised;
        # otherwise the non-daemon thread would keep the process alive.
        _thread_loop.call_soon_threadsafe(_waiter.set_result, None)
        srv_thread.join()
|