|
|
@@ -0,0 +1,51 @@ |
|
|
|
import asyncio |
|
|
|
import socket |
|
|
|
from unittest import mock |
|
|
|
try: |
|
|
|
from asyncio import ensure_future |
|
|
|
except ImportError: |
|
|
|
ensure_future = asyncio.async |
|
|
|
|
|
|
|
|
|
|
|
def fake_coroutine(return_value): |
|
|
|
def coro(*args, **kwargs): |
|
|
|
if isinstance(return_value, Exception): |
|
|
|
raise return_value |
|
|
|
return return_value |
|
|
|
|
|
|
|
return mock.Mock(side_effect=asyncio.coroutine(coro)) |
|
|
|
|
|
|
|
|
|
|
|
def find_unused_port(): |
|
|
|
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
|
|
|
s.bind(('127.0.0.1', 0)) |
|
|
|
port = s.getsockname()[1] |
|
|
|
s.close() |
|
|
|
return port |
|
|
|
|
|
|
|
|
|
|
|
class SocksPrimitiveProtocol(asyncio.Protocol): |
|
|
|
def __init__(self, write_buff): |
|
|
|
self._write_buff = write_buff |
|
|
|
self._transport = None |
|
|
|
|
|
|
|
def connection_made(self, transport): |
|
|
|
self._transport = transport |
|
|
|
|
|
|
|
def data_received(self, data): |
|
|
|
self._transport.write(self._write_buff) |
|
|
|
|
|
|
|
def connection_lost(self, exc): |
|
|
|
self._transport.close() |
|
|
|
|
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def fake_socks_srv(loop, write_buff): |
|
|
|
port = find_unused_port() |
|
|
|
|
|
|
|
def factory(): |
|
|
|
return SocksPrimitiveProtocol(write_buff) |
|
|
|
|
|
|
|
server = yield from loop.create_server(factory, '127.0.0.1', port) |
|
|
|
return server, port |
|
|
|
|