From ad0a6d52fd94ec2b3b6e0176a23a2a69db41564f Mon Sep 17 00:00:00 2001 From: nibrag Date: Sat, 14 May 2016 14:00:37 +0300 Subject: [PATCH] Added fake_coroutine helper. Use protocol for fake_socks_srv --- tests/helpers.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ tests/socks_serv.py | 25 ---------------------- 2 files changed, 51 insertions(+), 25 deletions(-) create mode 100644 tests/helpers.py delete mode 100644 tests/socks_serv.py diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..7d81462 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,51 @@ +import asyncio +import socket +from unittest import mock +try: + from asyncio import ensure_future +except ImportError: + ensure_future = asyncio.async + + +def fake_coroutine(return_value): + def coro(*args, **kwargs): + if isinstance(return_value, Exception): + raise return_value + return return_value + + return mock.Mock(side_effect=asyncio.coroutine(coro)) + + +def find_unused_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() + return port + + +class SocksPrimitiveProtocol(asyncio.Protocol): + def __init__(self, write_buff): + self._write_buff = write_buff + self._transport = None + + def connection_made(self, transport): + self._transport = transport + + def data_received(self, data): + self._transport.write(self._write_buff) + + def connection_lost(self, exc): + self._transport.close() + + +@asyncio.coroutine +def fake_socks_srv(loop, write_buff): + port = find_unused_port() + + def factory(): + return SocksPrimitiveProtocol(write_buff) + + server = yield from loop.create_server(factory, '127.0.0.1', port) + return server, port + diff --git a/tests/socks_serv.py b/tests/socks_serv.py deleted file mode 100644 index 87628db..0000000 --- a/tests/socks_serv.py +++ /dev/null @@ -1,25 +0,0 @@ -import asyncio -import socket -import functools - - -def find_unused_port(): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('127.0.0.1', 0)) - port = s.getsockname()[1] - s.close() - return port - - -@asyncio.coroutine -def socks_handler(reader, writer, write_buff): - writer.write(write_buff) - - -@asyncio.coroutine -def fake_socks_srv(loop, write_buff): - port = find_unused_port() - handler = functools.partial(socks_handler, write_buff=write_buff) - srv = yield from asyncio.start_server( - handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop) - return srv, port