From b8c62da7a6084d5b59923885821fe3cfeb225f7f Mon Sep 17 00:00:00 2001 From: nibrag Date: Sat, 15 Apr 2017 19:16:33 +0300 Subject: [PATCH] Dropped supporting python 3.4, added supporting aiohttp 2.0, migrate to pytest --- README.rst | 65 +-- aiosocks/__init__.py | 20 +- aiosocks/connector.py | 153 ++--- aiosocks/protocols.py | 68 +-- aiosocks/test_utils.py | 136 +++++ tests/__init__.py | 0 tests/conftest.py | 1 + tests/helpers.py | 242 -------- tests/test_connector.py | 246 ++++---- tests/test_create_connect.py | 227 +++----- tests/test_functional.py | 617 +++++++++----------- tests/test_protocols.py | 1049 +++++++++++++++++----------------- 12 files changed, 1297 insertions(+), 1527 deletions(-) create mode 100644 aiosocks/test_utils.py delete mode 100644 tests/__init__.py create mode 100644 tests/conftest.py delete mode 100644 tests/helpers.py diff --git a/README.rst b/README.rst index e5e37e7..f750325 100644 --- a/README.rst +++ b/README.rst @@ -11,6 +11,12 @@ SOCKS proxy client for asyncio and aiohttp .. image:: https://badge.fury.io/py/aiosocks.svg :target: https://badge.fury.io/py/aiosocks + +Dependencies +------------ +python 3.5+ +aiohttp 2.0+ + Features -------- - SOCKS4, SOCKS4a and SOCKS5 version @@ -136,56 +142,47 @@ aiohttp usage import asyncio import aiohttp import aiosocks - from aiosocks.connector import ( - SocksConnector, HttpProxyAddr, HttpProxyAuth - ) + from yarl import URL + from aiosocks.connector import ProxyConnecotr, ProxyClientRequest async def load_github_main(): - addr = aiosocks.Socks5Addr('127.0.0.1', 1080) - auth = aiosocks.Socks5Auth('proxyuser1', password='pwd') + auth5 = aiosocks.Socks5Auth('proxyuser1', password='pwd') + auth4 = aiosocks.Socks4Auth('proxyuser1') + ba = aiohttp.BasicAuth('login') # remote resolve - conn = SocksConnector(proxy=addr, proxy_auth=auth, remote_resolve=True) + conn = ProxyConnecotr(remote_resolve=True) # or locale resolve - conn = SocksConnector(proxy=addr, proxy_auth=auth, remote_resolve=False) + conn = SocksConnector(remote_resolve=False) try: - with aiohttp.ClientSession(connector=conn) as session: - async with session.get('http://github.com/') as resp: + with aiohttp.ClientSession(connector=conn, request_class=ProxyClientRequest) as session: + # socks5 proxy + async with session.get('http://github.com/', proxy=URL('socks5://127.0.0.1:1080'), + proxy_auth=auth5) as resp: + if resp.status == 200: + print(await resp.text()) + + # socks4 proxy + async with session.get('http://github.com/', proxy=URL('socks4://127.0.0.1:1081'), + proxy_auth=auth4) as resp: + if resp.status == 200: + print(await resp.text()) + + # http proxy + async with session.get('http://github.com/', proxy=URL('http://127.0.0.1:8080'), + proxy_auth=ba) as resp: if resp.status == 200: print(await resp.text()) except aiohttp.ProxyConnectionError: # connection problem except aiosocks.SocksError: # communication problem - - + + if __name__ == '__main__': loop = asyncio.get_event_loop() loop.run_until_complete(load_github_main()) loop.close() - -proxy_connector -^^^^^^^^^^^^^^^ -A unified method to create `connector`. - -.. code-block:: python - - import asyncio - import aiohttp - import aiosocks - from aiosocks.connector import ( - proxy_connector, HttpProxyAddr, HttpProxyAuth - ) - - # make SocksConnector - conn = proxy_connector(aiosocks.Socks5Addr(...), - remote_resolve=True, verify_ssl=False) - # return SocksConnector instance - - # make aiohttp.ProxyConnector (http proxy) - conn = proxy_connector(HttpProxyAddr('http://proxy'), - HttpProxyAuth('login', 'pwd'), verify_ssl=True) - # return aiohttp.ProxyConnector instance diff --git a/aiosocks/__init__.py b/aiosocks/__init__.py index c4b0a6d..91572b4 100644 --- a/aiosocks/__init__.py +++ b/aiosocks/__init__.py @@ -17,11 +17,10 @@ __all__ = ('Socks4Protocol', 'Socks5Protocol', 'Socks4Auth', 'InvalidServerReply', 'create_connection', 'open_connection') -@asyncio.coroutine -def create_connection(protocol_factory, proxy, proxy_auth, dst, *, - remote_resolve=True, loop=None, ssl=None, family=0, - proto=0, flags=0, sock=None, local_addr=None, - server_hostname=None, reader_limit=DEFAULT_LIMIT): +async def create_connection(protocol_factory, proxy, proxy_auth, dst, *, + remote_resolve=True, loop=None, ssl=None, family=0, + proto=0, flags=0, sock=None, local_addr=None, + server_hostname=None, reader_limit=DEFAULT_LIMIT): assert isinstance(proxy, SocksAddr), ( 'proxy must be Socks4Addr() or Socks5Addr() tuple' ) @@ -70,7 +69,7 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, reader_limit=reader_limit) try: - transport, protocol = yield from loop.create_connection( + transport, protocol = await loop.create_connection( socks_factory, proxy.host, proxy.port, family=family, proto=proto, flags=flags, sock=sock, local_addr=local_addr) except OSError as exc: @@ -79,7 +78,7 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, (exc.errno, proxy.host, proxy.port, exc.strerror)) from exc try: - yield from waiter + await waiter except: transport.close() raise @@ -87,10 +86,9 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, return protocol.app_transport, protocol.app_protocol -@asyncio.coroutine -def open_connection(proxy, proxy_auth, dst, *, remote_resolve=True, - loop=None, limit=DEFAULT_LIMIT, **kwds): - _, protocol = yield from create_connection( +async def open_connection(proxy, proxy_auth, dst, *, remote_resolve=True, + loop=None, limit=DEFAULT_LIMIT, **kwds): + _, protocol = await create_connection( None, proxy, proxy_auth, dst, reader_limit=limit, remote_resolve=remote_resolve, loop=loop, **kwds) diff --git a/aiosocks/connector.py b/aiosocks/connector.py index 2af54d5..4bd9300 100644 --- a/aiosocks/connector.py +++ b/aiosocks/connector.py @@ -1,105 +1,99 @@ try: import aiohttp - from aiohttp.errors import ProxyConnectionError - from aiohttp.helpers import BasicAuth as HttpProxyAuth + from aiohttp.connector import sentinel except ImportError: raise ImportError('aiosocks.SocksConnector require aiohttp library') -import asyncio -from collections import namedtuple from .errors import SocksError, SocksConnectionError -from .helpers import SocksAddr +from .helpers import Socks4Auth, Socks5Auth, Socks4Addr, Socks5Addr from . import create_connection -__all__ = ('SocksConnector', 'HttpProxyAddr', 'HttpProxyAuth') +__all__ = ('ProxyConnector', 'ProxyClientRequest') + + +class ProxyClientRequest(aiohttp.ClientRequest): + def update_proxy(self, proxy, proxy_auth): + if proxy and proxy.scheme not in ['http', 'socks4', 'socks5']: + raise ValueError( + "Only http, socks4 and socks5 proxies are supported") + if proxy and proxy_auth: + if proxy.scheme == 'http' and \ + not isinstance(proxy_auth, aiohttp.BasicAuth): + raise ValueError("proxy_auth must be None or " + "BasicAuth() tuple for http proxy") + if proxy.scheme == 'socks4' and \ + not isinstance(proxy_auth, Socks4Auth): + raise ValueError("proxy_auth must be None or Socks4Auth() " + "tuple for socks4 proxy") + if proxy.scheme == 'socks5' and \ + not isinstance(proxy_auth, Socks5Auth): + raise ValueError("proxy_auth must be None or Socks5Auth() " + "tuple for socks5 proxy") + + self.proxy = proxy + self.proxy_auth = proxy_auth + + +class ProxyConnector(aiohttp.TCPConnector): + def __init__(self, *, verify_ssl=True, fingerprint=None, + resolve=sentinel, use_dns_cache=True, + family=0, ssl_context=None, local_addr=None, + resolver=None, keepalive_timeout=sentinel, + force_close=False, limit=100, limit_per_host=0, + enable_cleanup_closed=False, loop=None, remote_resolve=True): + super().__init__( + verify_ssl=verify_ssl, fingerprint=fingerprint, resolve=resolve, + family=family, ssl_context=ssl_context, local_addr=local_addr, + resolver=resolver, keepalive_timeout=keepalive_timeout, + force_close=force_close, limit=limit, loop=loop, + limit_per_host=limit_per_host, use_dns_cache=use_dns_cache, + enable_cleanup_closed=enable_cleanup_closed) - -class HttpProxyAddr(namedtuple('HttpProxyAddr', ['url'])): - def __new__(cls, url): - if url is None: - raise ValueError('None is not allowed as url value') - return super().__new__(cls, url) - - -class SocksConnector(aiohttp.TCPConnector): - def __init__(self, proxy, proxy_auth=None, *, remote_resolve=True, **kwgs): - super().__init__(**kwgs) - - self._proxy = proxy - self._proxy_auth = proxy_auth self._remote_resolve = remote_resolve - @property - def proxy(self): - """Proxy info. - Should be Socks4Server/Socks5Server instance. - """ - return self._proxy - - @property - def proxy_auth(self): - """Proxy auth info. - Should be Socks4Auth/Socks5Auth instance. - """ - return self._proxy_auth - - def _validate_ssl_fingerprint(self, transport, host): - has_cert = transport.get_extra_info('sslcontext') - if has_cert and self._fingerprint: - sock = transport.get_extra_info('socket') - if not hasattr(sock, 'getpeercert'): - # Workaround for asyncio 3.5.0 - # Starting from 3.5.1 version - # there is 'ssl_object' extra info in transport - sock = transport._ssl_protocol._sslpipe.ssl_object - # gives DER-encoded cert as a sequence of bytes (or None) - cert = sock.getpeercert(binary_form=True) - assert cert - got = self._hashfunc(cert).digest() - expected = self._fingerprint - if got != expected: - transport.close() - raise aiohttp.FingerprintMismatch( - expected, got, host, 80 - ) + async def _create_proxy_connection(self, req): + if req.proxy.scheme == 'http': + return await super()._create_proxy_connection(req) + else: + return await self._create_socks_connection(req) - @asyncio.coroutine - def _create_connection(self, req): + async def _create_socks_connection(self, req): if req.ssl: sslcontext = self.ssl_context else: sslcontext = None if not self._remote_resolve: - dst_hosts = yield from self._resolve_host(req.host, req.port) + dst_hosts = await self._resolve_host(req.host, req.port) dst = dst_hosts[0]['host'], dst_hosts[0]['port'] else: dst = req.host, req.port - proxy_hosts = yield from self._resolve_host(self._proxy.host, - self._proxy.port) + proxy_hosts = await self._resolve_host(req.proxy.host, req.proxy.port) exc = None for hinfo in proxy_hosts: - try: - proxy = self._proxy.__class__(host=hinfo['host'], - port=hinfo['port']) + if req.proxy.scheme == 'socks4': + proxy = Socks4Addr(hinfo['host'], hinfo['port']) + else: + proxy = Socks5Addr(hinfo['host'], hinfo['port']) - transp, proto = yield from create_connection( - self._factory, proxy, self._proxy_auth, dst, + try: + transp, proto = await create_connection( + self._factory, proxy, req.proxy_auth, dst, loop=self._loop, remote_resolve=self._remote_resolve, ssl=sslcontext, family=hinfo['family'], proto=hinfo['proto'], flags=hinfo['flags'], local_addr=self._local_addr, server_hostname=req.host if sslcontext else None) - self._validate_ssl_fingerprint(transp, req.host) + self._validate_ssl_fingerprint(transp, req.host, req.port) return transp, proto except (OSError, SocksError, SocksConnectionError) as e: exc = e else: if isinstance(exc, SocksConnectionError): - raise ProxyConnectionError(*exc.args) + raise aiohttp.ClientProxyConnectionError(*exc.args) if isinstance(exc, SocksError): raise exc else: @@ -107,12 +101,23 @@ class SocksConnector(aiohttp.TCPConnector): exc.errno, 'Can not connect to %s:%s [%s]' % (req.host, req.port, exc.strerror)) from exc - -def proxy_connector(proxy, proxy_auth=None, **kwargs): - if isinstance(proxy, HttpProxyAddr): - return aiohttp.ProxyConnector( - proxy.url, proxy_auth=proxy_auth, **kwargs) - elif isinstance(proxy, SocksAddr): - return SocksConnector(proxy, proxy_auth, **kwargs) - else: - raise ValueError('Unsupported `proxy` format') + def _validate_ssl_fingerprint(self, transp, host, port): + has_cert = transp.get_extra_info('sslcontext') + if has_cert and self._fingerprint: + sock = transp.get_extra_info('socket') + if not hasattr(sock, 'getpeercert'): + # Workaround for asyncio 3.5.0 + # Starting from 3.5.1 version + # there is 'ssl_object' extra info in transport + sock = transp._ssl_protocol._sslpipe.ssl_object + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + assert cert + got = self._hashfunc(cert).digest() + expected = self._fingerprint + if got != expected: + transp.close() + if not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transp) + raise aiohttp.ServerFingerprintMismatch( + expected, got, host, port) diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index 196100c..a1fa4ff 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -10,11 +10,6 @@ from .errors import ( InvalidServerReply, InvalidServerVersion ) -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async - DEFAULT_LIMIT = getattr(asyncio.streams, '_DEFAULT_LIMIT', 2**16) @@ -54,11 +49,10 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): super().__init__(stream_reader=reader, client_connected_cb=self.negotiate, loop=self._loop) - @asyncio.coroutine - def negotiate(self, reader, writer): + async def negotiate(self, reader, writer): try: req = self.socks_request(c.SOCKS_CMD_CONNECT) - self._proxy_peername, self._proxy_sockname = yield from req + self._proxy_peername, self._proxy_sockname = await req except SocksError as exc: exc = SocksError('Can not connect to %s:%s. %s' % (self._dst_host, self._dst_port, exc)) @@ -134,8 +128,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): self._app_protocol.eof_received() super().eof_received() - @asyncio.coroutine - def socks_request(self, cmd): + async def socks_request(self, cmd): raise NotImplementedError def write_request(self, request): @@ -150,17 +143,15 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): raise ValueError('Unsupported item') self._stream_writer.write(bdata) - @asyncio.coroutine - def read_response(self, n): + async def read_response(self, n): try: - return (yield from self._stream_reader.readexactly(n)) + return (await self._stream_reader.readexactly(n)) except asyncio.IncompleteReadError as e: raise InvalidServerReply( 'Server sent fewer bytes than required (%s)' % str(e)) - @asyncio.coroutine - def _get_dst_addr(self): - infos = yield from self._loop.getaddrinfo( + async def _get_dst_addr(self): + infos = await self._loop.getaddrinfo( self._dst_host, self._dst_port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP, flags=socket.AI_ADDRCONFIG) @@ -227,8 +218,7 @@ class Socks4Protocol(BaseSocksProtocol): reader_limit=reader_limit, negotiate_done_cb=negotiate_done_cb) - @asyncio.coroutine - def socks_request(self, cmd): + async def socks_request(self, cmd): # prepare destination addr/port host, port = self._dst_host, self._dst_port port_bytes = struct.pack(b'>H', port) @@ -242,7 +232,7 @@ class Socks4Protocol(BaseSocksProtocol): include_hostname = True else: # it's not an IP number, so it's probably a DNS name. - family, host = yield from self._get_dst_addr() + family, host = await self._get_dst_addr() host_bytes = socket.inet_aton(host) # build and send connect command @@ -254,7 +244,7 @@ class Socks4Protocol(BaseSocksProtocol): self.write_request(req) # read/process result - resp = yield from self.read_response(8) + resp = await self.read_response(8) if resp[0] != c.NULL: raise InvalidServerReply('SOCKS4 proxy server sent invalid data') @@ -285,17 +275,16 @@ class Socks5Protocol(BaseSocksProtocol): reader_limit=reader_limit, negotiate_done_cb=negotiate_done_cb) - @asyncio.coroutine - def socks_request(self, cmd): - yield from self.authenticate() + async def socks_request(self, cmd): + await self.authenticate() # build and send command - dst_addr, resolved = yield from self.build_dst_address( + dst_addr, resolved = await self.build_dst_address( self._dst_host, self._dst_port) self.write_request([c.SOCKS_VER5, cmd, c.RSV] + dst_addr) # read/process command response - resp = yield from self.read_response(3) + resp = await self.read_response(3) if resp[0] != c.SOCKS_VER5: raise InvalidServerVersion( @@ -305,12 +294,11 @@ class Socks5Protocol(BaseSocksProtocol): error = c.SOCKS5_ERRORS.get(resp[1], 'Unknown error') raise SocksError('[Errno {0:#04x}]: {1}'.format(resp[1], error)) - binded = yield from self.read_address() + binded = await self.read_address() return resolved, binded - @asyncio.coroutine - def authenticate(self): + async def authenticate(self): # send available auth methods if self._auth.login and self._auth.password: req = [c.SOCKS_VER5, 0x02, @@ -321,7 +309,7 @@ class Socks5Protocol(BaseSocksProtocol): self.write_request(req) # read/process response and send auth data if necessary - chosen_auth = yield from self.read_response(2) + chosen_auth = await self.read_response(2) if chosen_auth[0] != c.SOCKS_VER5: raise InvalidServerVersion( @@ -333,7 +321,7 @@ class Socks5Protocol(BaseSocksProtocol): chr(len(self._auth.password)).encode(), self._auth.password] self.write_request(req) - auth_status = yield from self.read_response(2) + auth_status = await self.read_response(2) if auth_status[0] != 0x01: raise InvalidServerReply( 'SOCKS5 proxy server sent invalid data' @@ -353,8 +341,7 @@ class Socks5Protocol(BaseSocksProtocol): 'SOCKS5 proxy server sent invalid data' ) - @asyncio.coroutine - def build_dst_address(self, host, port): + async def build_dst_address(self, host, port): family_to_byte = {socket.AF_INET: c.SOCKS5_ATYP_IPv4, socket.AF_INET6: c.SOCKS5_ATYP_IPv6} port_bytes = struct.pack('>H', port) @@ -375,29 +362,28 @@ class Socks5Protocol(BaseSocksProtocol): req = [c.SOCKS5_ATYP_DOMAIN, chr(len(host_bytes)).encode(), host_bytes, port_bytes] else: - family, host_bytes = yield from self._get_dst_addr() + family, host_bytes = await self._get_dst_addr() host_bytes = socket.inet_pton(family, host_bytes) req = [family_to_byte[family], host_bytes, port_bytes] host = socket.inet_ntop(family, host_bytes) return req, (host, port) - @asyncio.coroutine - def read_address(self): - atype = yield from self.read_response(1) + async def read_address(self): + atype = await self.read_response(1) if atype[0] == c.SOCKS5_ATYP_IPv4: - addr = socket.inet_ntoa((yield from self.read_response(4))) + addr = socket.inet_ntoa((await self.read_response(4))) elif atype[0] == c.SOCKS5_ATYP_DOMAIN: - length = yield from self.read_response(1) - addr = yield from self.read_response(ord(length)) + length = await self.read_response(1) + addr = await self.read_response(ord(length)) elif atype[0] == c.SOCKS5_ATYP_IPv6: - addr = yield from self.read_response(16) + addr = await self.read_response(16) addr = socket.inet_ntop(socket.AF_INET6, addr) else: raise InvalidServerReply('SOCKS5 proxy server sent invalid data') - port = yield from self.read_response(2) + port = await self.read_response(2) port = struct.unpack('>H', port)[0] return addr, port diff --git a/aiosocks/test_utils.py b/aiosocks/test_utils.py new file mode 100644 index 0000000..6facce7 --- /dev/null +++ b/aiosocks/test_utils.py @@ -0,0 +1,136 @@ +import asyncio +import struct +import socket +from aiohttp.test_utils import unused_port + + +class FakeSocksSrv: + def __init__(self, loop, write_buff): + self._loop = loop + self._write_buff = write_buff + self._transports = [] + self._srv = None + self.port = unused_port() + + async def __aenter__(self): + transports = self._transports + write_buff = self._write_buff + + class SocksPrimitiveProtocol(asyncio.Protocol): + _transport = None + + def connection_made(self, transport): + self._transport = transport + transports.append(transport) + + def data_received(self, data): + self._transport.write(write_buff) + + def factory(): + return SocksPrimitiveProtocol() + + self._srv = await self._loop.create_server( + factory, '127.0.0.1', self.port) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + for tr in self._transports: + tr.close() + + self._srv.close() + await self._srv.wait_closed() + + +class FakeSocks4Srv: + def __init__(self, loop): + self._loop = loop + self._transports = [] + self._futures = [] + self._srv = None + self.port = unused_port() + + async def __aenter__(self): + transports = self._transports + futures = self._futures + + class Socks4Protocol(asyncio.StreamReaderProtocol): + def __init__(self, _loop): + self._loop = _loop + reader = asyncio.StreamReader(loop=self._loop) + super().__init__(reader, client_connected_cb=self.negotiate, + loop=self._loop) + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + async def negotiate(self, reader, writer): + writer.write(b'\x00\x5a\x04W\x01\x01\x01\x01') + + data = await reader.read(9) + + dst_port = struct.unpack('>H', data[2:4])[0] + dst_addr = data[4:8] + + if data[-1] != 0x00: + while True: + byte = await reader.read(1) + if byte == 0x00: + break + + if dst_addr == b'\x00\x00\x00\x01': + dst_addr = bytearray() + + while True: + byte = await reader.read(1) + if byte == 0x00: + break + dst_addr.append(byte) + else: + dst_addr = socket.inet_ntoa(dst_addr) + + cl_reader, cl_writer = await asyncio.open_connection( + host=dst_addr, port=dst_port, loop=self._loop + ) + transports.append(cl_writer) + + cl_fut = asyncio.ensure_future( + self.retranslator(reader, cl_writer), loop=self._loop) + dst_fut = asyncio.ensure_future( + self.retranslator(cl_reader, writer), loop=self._loop) + + futures.append(cl_fut) + futures.append(dst_fut) + + async def retranslator(self, reader, writer): + data = bytearray() + while True: + try: + byte = await reader.read(10) + if not byte: + break + data.append(byte[0]) + writer.write(byte) + await writer.drain() + except: + break + + def factory(): + return Socks4Protocol(_loop=self._loop) + + self._srv = await self._loop.create_server( + factory, '127.0.0.1', self.port) + + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + for tr in self._transports: + tr.close() + + self._srv.close() + await self._srv.wait_closed() + + for f in self._futures: + if not f.cancelled() or not f.done(): + f.cancel() diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bbc64f0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1 @@ +from aiohttp.pytest_plugin import * # noqa diff --git a/tests/helpers.py b/tests/helpers.py deleted file mode 100644 index cd9ca66..0000000 --- a/tests/helpers.py +++ /dev/null @@ -1,242 +0,0 @@ -import asyncio -import aiohttp -import contextlib -import gc -import os -import socket -import ssl -import struct -import threading -from unittest import mock -from aiohttp.server import ServerHttpProtocol -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async - - -def fake_coroutine(return_value): - def coro(*args, **kwargs): - if isinstance(return_value, Exception): - raise return_value - return return_value - - return mock.Mock(side_effect=asyncio.coroutine(coro)) - - -def find_unused_port(): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('127.0.0.1', 0)) - port = s.getsockname()[1] - s.close() - return port - - -@contextlib.contextmanager -def fake_socks_srv(loop, write_buff): - transports = [] - - class SocksPrimitiveProtocol(asyncio.Protocol): - _transport = None - - def connection_made(self, transport): - self._transport = transport - transports.append(transport) - - def data_received(self, data): - self._transport.write(write_buff) - - port = find_unused_port() - - def factory(): - return SocksPrimitiveProtocol() - - srv = loop.run_until_complete( - loop.create_server(factory, '127.0.0.1', port)) - - yield port - - for tr in transports: - tr.close() - - srv.close() - loop.run_until_complete(srv.wait_closed()) - gc.collect() - - -@contextlib.contextmanager -def fake_socks4_srv(loop): - port = find_unused_port() - transports = [] - futures = [] - - class Socks4Protocol(asyncio.StreamReaderProtocol): - def __init__(self, _loop): - self._loop = _loop - reader = asyncio.StreamReader(loop=self._loop) - super().__init__(reader, client_connected_cb=self.negotiate, - loop=self._loop) - - def connection_made(self, transport): - transports.append(transport) - super().connection_made(transport) - - @asyncio.coroutine - def negotiate(self, reader, writer): - writer.write(b'\x00\x5a\x04W\x01\x01\x01\x01') - - data = yield from reader.read(9) - - dst_port = struct.unpack('>H', data[2:4])[0] - dst_addr = data[4:8] - - if data[-1] != 0x00: - while True: - byte = yield from reader.read(1) - if byte == 0x00: - break - - if dst_addr == b'\x00\x00\x00\x01': - dst_addr = bytearray() - - while True: - byte = yield from reader.read(1) - if byte == 0x00: - break - dst_addr.append(byte) - else: - dst_addr = socket.inet_ntoa(dst_addr) - - cl_reader, cl_writer = yield from asyncio.open_connection( - host=dst_addr, port=dst_port, loop=self._loop - ) - transports.append(cl_writer) - - cl_fut = ensure_future( - self.retranslator(reader, cl_writer), loop=self._loop) - dst_fut = ensure_future( - self.retranslator(cl_reader, writer), loop=self._loop) - futures.append(cl_fut) - futures.append(dst_fut) - - @asyncio.coroutine - def retranslator(self, reader, writer): - data = bytearray() - while True: - try: - byte = yield from reader.read(1) - if not byte: - break - data.append(byte[0]) - writer.write(byte) - yield from writer.drain() - except: - break - - def run(_fut): - thread_loop = asyncio.new_event_loop() - asyncio.set_event_loop(thread_loop) - - srv_coroutine = thread_loop.create_server( - lambda: Socks4Protocol(thread_loop), '127.0.0.1', port) - srv = thread_loop.run_until_complete(srv_coroutine) - - waiter = asyncio.Future(loop=thread_loop) - loop.call_soon_threadsafe( - _fut.set_result, (thread_loop, waiter)) - - try: - thread_loop.run_until_complete(waiter) - finally: - # close opened transports - for tr in transports: - tr.close() - for ft in futures: - if not ft.done(): - ft.set_result(1) - - srv.close() - thread_loop.stop() - thread_loop.close() - gc.collect() - - fut = asyncio.Future(loop=loop) - srv_thread = threading.Thread(target=run, args=(fut,)) - srv_thread.start() - - _thread_loop, _waiter = loop.run_until_complete(fut) - - yield port - _thread_loop.call_soon_threadsafe(_waiter.set_result, None) - srv_thread.join() - - -@contextlib.contextmanager -def http_srv(loop, *, listen_addr=('127.0.0.1', 0), use_ssl=False): - transports = [] - - class TestHttpServer(ServerHttpProtocol): - - def connection_made(self, transport): - transports.append(transport) - super().connection_made(transport) - - @asyncio.coroutine - def handle_request(self, message, payload): - response = aiohttp.Response(self.writer, 200, message.version) - - text = b'Test message' - response.add_header('Content-type', 'text/plain') - response.add_header('Content-length', str(len(text))) - response.send_headers() - response.write(text) - response.write_eof() - - if use_ssl: - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - keyfile = os.path.join(here, 'sample.key') - certfile = os.path.join(here, 'sample.crt') - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain(certfile, keyfile) - else: - sslcontext = None - - def run(_fut): - thread_loop = asyncio.new_event_loop() - asyncio.set_event_loop(thread_loop) - - host, port = listen_addr - - srv_coroutine = thread_loop.create_server( - lambda: TestHttpServer(), host, port, ssl=sslcontext) - srv = thread_loop.run_until_complete(srv_coroutine) - - waiter = asyncio.Future(loop=thread_loop) - loop.call_soon_threadsafe( - _fut.set_result, (thread_loop, waiter, - srv.sockets[0].getsockname())) - - try: - thread_loop.run_until_complete(waiter) - finally: - # close opened transports - for tr in transports: - tr.close() - - srv.close() - thread_loop.stop() - thread_loop.close() - gc.collect() - - fut = asyncio.Future(loop=loop) - srv_thread = threading.Thread(target=run, args=(fut,)) - srv_thread.start() - - _thread_loop, _waiter, _addr = loop.run_until_complete(fut) - - url = '{}://{}:{}'.format( - 'https' if use_ssl else 'http', *_addr) - - yield url - _thread_loop.call_soon_threadsafe(_waiter.set_result, None) - srv_thread.join() diff --git a/tests/test_connector.py b/tests/test_connector.py index 765a123..5109f6f 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1,155 +1,179 @@ -import unittest -import asyncio import aiosocks import aiohttp +import pytest +from yarl import URL from unittest import mock -from aiohttp.client_reqrep import ClientRequest -from aiosocks.connector import SocksConnector, proxy_connector, HttpProxyAddr -from .helpers import fake_coroutine +from aiohttp.test_utils import make_mocked_coro +from aiohttp import BasicAuth +from aiosocks.connector import ProxyConnector, ProxyClientRequest +from aiosocks.helpers import Socks4Auth, Socks5Auth -class TestSocksConnector(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_properties(self): - addr = aiosocks.Socks4Addr('localhost') - auth = aiosocks.Socks4Auth('login') - conn = SocksConnector(addr, auth, loop=self.loop) - self.assertIs(conn.proxy, addr) - self.assertIs(conn.proxy_auth, auth) - - @mock.patch('aiosocks.connector.create_connection') - def test_connect_proxy_ip(self, cr_conn_mock): - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - cr_conn_mock.side_effect = \ - fake_coroutine((tr, proto)).side_effect +async def test_connect_proxy_ip(): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + with mock.patch('aiosocks.connector.create_connection', + make_mocked_coro((tr, proto))): loop_mock = mock.Mock() + loop_mock.getaddrinfo = make_mocked_coro( + [[0, 0, 0, 0, ['127.0.0.1', 1080]]]) - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), - None, loop=loop_mock) + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop_mock, + proxy=URL('socks5://proxy.org')) + connector = ProxyConnector(loop=loop_mock) + conn = await connector.connect(req) - loop_mock.getaddrinfo = fake_coroutine([mock.MagicMock()]) + assert loop_mock.getaddrinfo.called + assert conn.protocol is proto - conn = self.loop.run_until_complete(connector.connect(req)) + conn.close() - self.assertTrue(loop_mock.getaddrinfo.is_called) - self.assertIs(conn._transport, tr) - conn.close() +async def test_connect_proxy_domain(): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - @mock.patch('aiosocks.connector.create_connection') - def test_connect_proxy_domain(self, cr_conn_mock): - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - cr_conn_mock.side_effect = \ - fake_coroutine((tr, proto)).side_effect + with mock.patch('aiosocks.connector.create_connection', + make_mocked_coro((tr, proto))): loop_mock = mock.Mock() - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), - None, loop=loop_mock) + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop_mock, + proxy=URL('socks5://proxy.example')) + connector = ProxyConnector(loop=loop_mock) - connector._resolve_host = fake_coroutine([mock.MagicMock()]) + connector._resolve_host = make_mocked_coro([mock.MagicMock()]) + conn = await connector.connect(req) - conn = self.loop.run_until_complete(connector.connect(req)) + assert connector._resolve_host.call_count == 1 + assert conn.protocol is proto - self.assertTrue(connector._resolve_host.is_called) - self.assertEqual(connector._resolve_host.call_count, 1) - self.assertIs(conn._transport, tr) + conn.close() - conn.close() - @mock.patch('aiosocks.connector.create_connection') - def test_connect_remote_resolve(self, cr_conn_mock): - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - cr_conn_mock.side_effect = \ - fake_coroutine((tr, proto)).side_effect +async def test_connect_remote_resolve(loop): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), - None, loop=self.loop, remote_resolve=True) + with mock.patch('aiosocks.connector.create_connection', + make_mocked_coro((tr, proto))): + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks5://127.0.0.1')) + connector = ProxyConnector(loop=loop, remote_resolve=True) + connector._resolve_host = make_mocked_coro([mock.MagicMock()]) - connector._resolve_host = fake_coroutine([mock.MagicMock()]) + conn = await connector.connect(req) - conn = self.loop.run_until_complete(connector.connect(req)) + assert connector._resolve_host.call_count == 1 + assert conn.protocol is proto - self.assertEqual(connector._resolve_host.call_count, 1) + conn.close() - conn.close() - @mock.patch('aiosocks.connector.create_connection') - def test_connect_locale_resolve(self, cr_conn_mock): - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - cr_conn_mock.side_effect = \ - fake_coroutine((tr, proto)).side_effect +async def test_connect_locale_resolve(loop): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), - None, loop=self.loop, remote_resolve=False) + with mock.patch('aiosocks.connector.create_connection', + make_mocked_coro((tr, proto))): + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks5://proxy.example')) + connector = ProxyConnector(loop=loop, remote_resolve=False) + connector._resolve_host = make_mocked_coro([mock.MagicMock()]) - connector._resolve_host = fake_coroutine([mock.MagicMock()]) + conn = await connector.connect(req) - conn = self.loop.run_until_complete(connector.connect(req)) + assert connector._resolve_host.call_count == 2 + assert conn.protocol is proto - self.assertTrue(connector._resolve_host.is_called) - self.assertEqual(connector._resolve_host.call_count, 2) + conn.close() - conn.close() - @mock.patch('aiosocks.connector.create_connection') - def test_proxy_connect_fail(self, cr_conn_mock): - loop_mock = mock.Mock() - cr_conn_mock.side_effect = \ - fake_coroutine(aiosocks.SocksConnectionError()).side_effect +async def test_proxy_connect_fail(loop): + loop_mock = mock.Mock() + loop_mock.getaddrinfo = make_mocked_coro( + [[0, 0, 0, 0, ['127.0.0.1', 1080]]]) + cc_coro = make_mocked_coro( + raise_exception=aiosocks.SocksConnectionError()) - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), - None, loop=loop_mock) + with mock.patch('aiosocks.connector.create_connection', cc_coro): + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks5://127.0.0.1')) + connector = ProxyConnector(loop=loop_mock) - loop_mock.getaddrinfo = fake_coroutine([mock.MagicMock()]) + with pytest.raises(aiohttp.ClientConnectionError): + await connector.connect(req) - with self.assertRaises(aiohttp.ProxyConnectionError): - self.loop.run_until_complete(connector.connect(req)) - @mock.patch('aiosocks.connector.create_connection') - def test_proxy_negotiate_fail(self, cr_conn_mock): - loop_mock = mock.Mock() - cr_conn_mock.side_effect = \ - fake_coroutine(aiosocks.SocksError()).side_effect +async def test_proxy_negotiate_fail(loop): + loop_mock = mock.Mock() + loop_mock.getaddrinfo = make_mocked_coro( + [[0, 0, 0, 0, ['127.0.0.1', 1080]]]) - req = ClientRequest('GET', 'http://python.org', loop=self.loop) - connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), - None, loop=loop_mock) + with mock.patch('aiosocks.connector.create_connection', + make_mocked_coro(raise_exception=aiosocks.SocksError())): + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks5://127.0.0.1')) + connector = ProxyConnector(loop=loop_mock) - loop_mock.getaddrinfo = fake_coroutine([mock.MagicMock()]) + with pytest.raises(aiosocks.SocksError): + await connector.connect(req) - with self.assertRaises(aiosocks.SocksError): - self.loop.run_until_complete(connector.connect(req)) - def test_proxy_connector(self): - socks4_addr = aiosocks.Socks4Addr('h') - socks5_addr = aiosocks.Socks5Addr('h') - http_addr = HttpProxyAddr('http://proxy') +async def test_proxy_connect_http(loop): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + loop_mock = mock.Mock() + loop_mock.getaddrinfo = make_mocked_coro([ + [0, 0, 0, 0, ['127.0.0.1', 1080]]]) + loop_mock.create_connection = make_mocked_coro((tr, proto)) - self.assertIsInstance(proxy_connector(socks4_addr, loop=self.loop), - SocksConnector) - self.assertIsInstance(proxy_connector(socks5_addr, loop=self.loop), - SocksConnector) - self.assertIsInstance(proxy_connector(http_addr, loop=self.loop), - aiohttp.ProxyConnector) + req = ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('http://127.0.0.1')) + connector = ProxyConnector(loop=loop_mock) - with self.assertRaises(ValueError): - proxy_connector(None) + await connector.connect(req) - def test_http_proxy_addr(self): - addr = HttpProxyAddr('http://proxy') - self.assertEqual(addr.url, 'http://proxy') - with self.assertRaises(ValueError): - HttpProxyAddr(None) +@pytest.mark.parametrize('proxy', [ + (URL('socks4://proxy.org'), Socks4Auth('login')), + (URL('socks5://proxy.org'), Socks5Auth('login', 'password')), + (URL('http://proxy.org'), BasicAuth('login')), (None, BasicAuth('login')), + (URL('socks4://proxy.org'), None), (None, None)]) +def test_proxy_client_request_valid(proxy, loop): + proxy, proxy_auth = proxy + p = ProxyClientRequest('GET', URL('http://python.org'), + proxy=proxy, proxy_auth=proxy_auth, loop=loop) + assert p.proxy is proxy + assert p.proxy_auth is proxy_auth + + +def test_proxy_client_request_invalid(loop): + with pytest.raises(ValueError) as cm: + ProxyClientRequest( + 'GET', URL('http://python.org'), + proxy=URL('socks6://proxy.org'), proxy_auth=None, loop=loop) + assert 'Only http, socks4 and socks5 proxies are supported' in str(cm) + + with pytest.raises(ValueError) as cm: + ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('http://proxy.org'), proxy_auth=Socks4Auth('l')) + assert 'proxy_auth must be None or BasicAuth() ' \ + 'tuple for http proxy' in str(cm) + + with pytest.raises(ValueError) as cm: + ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks4://proxy.org'), proxy_auth=BasicAuth('l')) + assert 'proxy_auth must be None or Socks4Auth() ' \ + 'tuple for socks4 proxy' in str(cm) + + with pytest.raises(ValueError) as cm: + ProxyClientRequest( + 'GET', URL('http://python.org'), loop=loop, + proxy=URL('socks5://proxy.org'), proxy_auth=Socks4Auth('l')) + assert 'proxy_auth must be None or Socks5Auth() ' \ + 'tuple for socks5 proxy' in str(cm) diff --git a/tests/test_create_connect.py b/tests/test_create_connect.py index cb43f75..c872bbf 100644 --- a/tests/test_create_connect.py +++ b/tests/test_create_connect.py @@ -1,134 +1,97 @@ -import unittest +import pytest import aiosocks -import asyncio +from aiohttp.test_utils import make_mocked_coro from unittest import mock -from .helpers import fake_coroutine - -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async - - -class TestCreateConnection(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_init(self): - addr = aiosocks.Socks5Addr('localhost') - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - # proxy argument - with self.assertRaises(AssertionError) as ct: - conn = aiosocks.create_connection(None, None, auth, dst) - self.loop.run_until_complete(conn) - self.assertEqual(str(ct.exception), - 'proxy must be Socks4Addr() or Socks5Addr() tuple') - - with self.assertRaises(AssertionError) as ct: - conn = aiosocks.create_connection(None, auth, auth, dst) - self.loop.run_until_complete(conn) - self.assertEqual(str(ct.exception), - 'proxy must be Socks4Addr() or Socks5Addr() tuple') - - # proxy_auth - with self.assertRaises(AssertionError) as ct: - conn = aiosocks.create_connection(None, addr, addr, dst) - self.loop.run_until_complete(conn) - self.assertIn('proxy_auth must be None or Socks4Auth()', - str(ct.exception)) - - # dst - with self.assertRaises(AssertionError) as ct: - conn = aiosocks.create_connection(None, addr, auth, None) - self.loop.run_until_complete(conn) - self.assertIn('invalid dst format, tuple("dst_host", dst_port))', - str(ct.exception)) - - # addr and auth compatibility - with self.assertRaises(ValueError) as ct: - conn = aiosocks.create_connection( - None, addr, aiosocks.Socks4Auth(''), dst - ) - self.loop.run_until_complete(conn) - self.assertIn('proxy is Socks5Addr but proxy_auth is not Socks5Auth', - str(ct.exception)) - - with self.assertRaises(ValueError) as ct: - conn = aiosocks.create_connection( - None, aiosocks.Socks4Addr(''), auth, dst - ) - self.loop.run_until_complete(conn) - self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', - str(ct.exception)) - - # test ssl, server_hostname - with self.assertRaises(ValueError) as ct: - conn = aiosocks.create_connection( - None, addr, auth, dst, server_hostname='python.org' - ) - self.loop.run_until_complete(conn) - self.assertIn('server_hostname is only meaningful with ssl', - str(ct.exception)) - - def test_connection_fail(self): - addr = aiosocks.Socks5Addr('localhost') - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - loop_mock = mock.Mock() - loop_mock.create_connection = fake_coroutine(OSError()) - - with self.assertRaises(aiosocks.SocksConnectionError): - conn = aiosocks.create_connection( - None, addr, auth, dst, loop=loop_mock - ) - self.loop.run_until_complete(conn) - - @mock.patch('aiosocks.asyncio.Future') - def test_negotiate_fail(self, future_mock): - addr = aiosocks.Socks5Addr('localhost') - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - loop_mock = mock.Mock() - loop_mock.create_connection = fake_coroutine( - (mock.Mock(), mock.Mock()) - ) - - fut = fake_coroutine(aiosocks.SocksError()) - future_mock.side_effect = fut.side_effect - - with self.assertRaises(aiosocks.SocksError): - conn = aiosocks.create_connection( - None, addr, auth, dst, loop=loop_mock - ) - self.loop.run_until_complete(conn) - - @mock.patch('aiosocks.asyncio.Future') - def test_open_connection(self, future_mock): - addr = aiosocks.Socks5Addr('localhost') - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - transp, proto = mock.Mock(), mock.Mock() - reader, writer = mock.Mock(), mock.Mock() - - proto.app_protocol.reader, proto.app_protocol.writer = reader, writer - - loop_mock = mock.Mock() - loop_mock.create_connection = fake_coroutine((transp, proto)) - - fut = fake_coroutine(True) - future_mock.side_effect = fut.side_effect - - conn = aiosocks.open_connection(addr, auth, dst, loop=loop_mock) - r, w = self.loop.run_until_complete(conn) - - self.assertIs(reader, r) - self.assertIs(writer, w) + + +async def test_create_connection_init(): + addr = aiosocks.Socks5Addr('localhost') + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + # proxy argument + with pytest.raises(AssertionError) as ct: + await aiosocks.create_connection(None, None, auth, dst) + assert 'proxy must be Socks4Addr() or Socks5Addr() tuple' in str(ct) + + with pytest.raises(AssertionError) as ct: + await aiosocks.create_connection(None, auth, auth, dst) + assert 'proxy must be Socks4Addr() or Socks5Addr() tuple' in str(ct) + + # proxy_auth + with pytest.raises(AssertionError) as ct: + await aiosocks.create_connection(None, addr, addr, dst) + assert 'proxy_auth must be None or Socks4Auth()' in str(ct) + + # dst + with pytest.raises(AssertionError) as ct: + await aiosocks.create_connection(None, addr, auth, None) + assert 'invalid dst format, tuple("dst_host", dst_port))' in str(ct) + + # addr and auth compatibility + with pytest.raises(ValueError) as ct: + await aiosocks.create_connection( + None, addr, aiosocks.Socks4Auth(''), dst) + assert 'proxy is Socks5Addr but proxy_auth is not Socks5Auth' in str(ct) + + with pytest.raises(ValueError) as ct: + await aiosocks.create_connection( + None, aiosocks.Socks4Addr(''), auth, dst) + assert 'proxy is Socks4Addr but proxy_auth is not Socks4Auth' in str(ct) + + # test ssl, server_hostname + with pytest.raises(ValueError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, server_hostname='python.org') + assert 'server_hostname is only meaningful with ssl' in str(ct) + + +async def test_connection_fail(): + addr = aiosocks.Socks5Addr('localhost') + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + loop_mock = mock.Mock() + loop_mock.create_connection = make_mocked_coro(raise_exception=OSError()) + + with pytest.raises(aiosocks.SocksConnectionError): + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop_mock) + + +async def test_negotiate_fail(): + addr = aiosocks.Socks5Addr('localhost') + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + loop_mock = mock.Mock() + loop_mock.create_connection = make_mocked_coro((mock.Mock(), mock.Mock())) + + with mock.patch('aiosocks.asyncio.Future') as future_mock: + future_mock.side_effect = make_mocked_coro( + raise_exception=aiosocks.SocksError()) + + with pytest.raises(aiosocks.SocksError): + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop_mock) + + +async def test_open_connection(): + addr = aiosocks.Socks5Addr('localhost') + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + transp, proto = mock.Mock(), mock.Mock() + reader, writer = mock.Mock(), mock.Mock() + + proto.app_protocol.reader, proto.app_protocol.writer = reader, writer + + loop_mock = mock.Mock() + loop_mock.create_connection = make_mocked_coro((transp, proto)) + + with mock.patch('aiosocks.asyncio.Future') as future_mock: + future_mock.side_effect = make_mocked_coro(True) + r, w = await aiosocks.open_connection(addr, auth, dst, loop=loop_mock) + + assert reader is r + assert writer is w diff --git a/tests/test_functional.py b/tests/test_functional.py index 277f1ae..798de9b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,368 +1,293 @@ +import pytest import aiosocks -import asyncio import aiohttp -import unittest -from aiosocks.connector import SocksConnector - -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async - -from .helpers import fake_socks_srv, fake_socks4_srv, http_srv - - -class TestCreateSocks4Connection(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_connect_success(self): - with fake_socks_srv(self.loop, - b'\x00\x5a\x04W\x01\x01\x01\x01test') as port: - addr = aiosocks.Socks4Addr('127.0.0.1', port) - auth = aiosocks.Socks4Auth('usr') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, ('1.1.1.1', 1111)) - - data = self.loop.run_until_complete( - protocol._stream_reader.read(4)) - self.assertEqual(data, b'test') - - transport.close() - - def test_invalid_data(self): - with fake_socks_srv(self.loop, - b'\x01\x5a\x04W\x01\x01\x01\x01') as port: - addr = aiosocks.Socks4Addr('127.0.0.1', port) - auth = aiosocks.Socks4Auth('usr') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid data', str(ct.exception)) - - def test_socks_srv_error(self): - with fake_socks_srv(self.loop, - b'\x00\x5b\x04W\x01\x01\x01\x01') as port: - addr = aiosocks.Socks4Addr('127.0.0.1', port) - auth = aiosocks.Socks4Auth('usr') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('0x5b', str(ct.exception)) - - -class TestCreateSocks5Connect(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_connect_success_anonymous(self): - with fake_socks_srv( - self.loop, - b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, ('1.1.1.1', 1111)) - - data = self.loop.run_until_complete( - protocol._stream_reader.read(4)) - self.assertEqual(data, b'test') - - transport.close() - - def test_connect_success_usr_pwd(self): - with fake_socks_srv( - self.loop, - b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' - ) as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, ('1.1.1.1', 1111)) - - data = self.loop.run_until_complete( - protocol._stream_reader.read(4)) - self.assertEqual(data, b'test') - transport.close() - - def test_auth_ver_err(self): - with fake_socks_srv(self.loop, b'\x04\x02') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid version', str(ct.exception)) - - def test_auth_method_rejected(self): - with fake_socks_srv(self.loop, b'\x05\xFF') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('authentication methods were rejected', - str(ct.exception)) - - def test_auth_status_invalid(self): - with fake_socks_srv(self.loop, b'\x05\xF0') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid data', str(ct.exception)) - - def test_auth_status_invalid2(self): - with fake_socks_srv(self.loop, b'\x05\x02\x02\x00') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid data', str(ct.exception)) - - def test_auth_failed(self): - with fake_socks_srv(self.loop, b'\x05\x02\x01\x01') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('authentication failed', str(ct.exception)) - - def test_cmd_ver_err(self): - with fake_socks_srv(self.loop, - b'\x05\x02\x01\x00\x04\x00\x00') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid version', str(ct.exception)) - - def test_cmd_not_granted(self): - with fake_socks_srv(self.loop, - b'\x05\x02\x01\x00\x05\x01\x00') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('General SOCKS server failure', str(ct.exception)) - - def test_invalid_address_type(self): - with fake_socks_srv(self.loop, - b'\x05\x02\x01\x00\x05\x00\x00\xFF') as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - with self.assertRaises(aiosocks.SocksError) as ct: - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - self.loop.run_until_complete(coro) - self.assertIn('invalid data', str(ct.exception)) - - def test_atype_ipv4(self): - with fake_socks_srv( - self.loop, - b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' - ) as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, ('1.1.1.1', 1111)) - - transport.close() - - def test_atype_ipv6(self): - with fake_socks_srv( - self.loop, - b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W' - ) as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, ('::111', 1111)) - - transport.close() - - def test_atype_domain(self): - with fake_socks_srv( - self.loop, - b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' - ) as port: - addr = aiosocks.Socks5Addr('127.0.0.1', port) - auth = aiosocks.Socks5Auth('usr', 'pwd') - dst = ('python.org', 80) - - coro = aiosocks.create_connection( - None, addr, auth, dst, loop=self.loop) - transport, protocol = self.loop.run_until_complete(coro) - - self.assertEqual(protocol.proxy_sockname, (b'python.org', 1111)) - - transport.close() - - -class TestSocksConnector(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_http_connect(self): - with fake_socks4_srv(self.loop) as proxy_port: - addr = aiosocks.Socks4Addr('127.0.0.1', proxy_port) - - conn = SocksConnector(proxy=addr, proxy_auth=None, loop=self.loop, - remote_resolve=False) - - with http_srv(self.loop) as url: - with aiohttp.ClientSession(connector=conn, - loop=self.loop) as ses: - @asyncio.coroutine - def make_req(): - return (yield from ses.request('get', url=url)) - - resp = self.loop.run_until_complete(make_req()) - - self.assertEqual(resp.status, 200) +import os +import ssl +from aiohttp import web +from aiohttp.test_utils import RawTestServer +from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv +from aiosocks.connector import ProxyConnector, ProxyClientRequest - content = self.loop.run_until_complete(resp.text()) - self.assertEqual(content, 'Test message') - resp.close() +async def test_socks4_connect_success(loop): + pld = b'\x00\x5a\x04W\x01\x01\x01\x01test' - def test_https_connect(self): - with fake_socks4_srv(self.loop) as proxy_port: - addr = aiosocks.Socks4Addr('127.0.0.1', proxy_port) + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks4Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks4Auth('usr') + dst = ('python.org', 80) - conn = SocksConnector(proxy=addr, proxy_auth=None, loop=self.loop, - remote_resolve=False, verify_ssl=False) + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) - with http_srv(self.loop, use_ssl=True) as url: - with aiohttp.ClientSession(connector=conn, - loop=self.loop) as ses: - @asyncio.coroutine - def make_req(): - return (yield from ses.request('get', url=url)) + assert protocol.proxy_sockname == ('1.1.1.1', 1111) - resp = self.loop.run_until_complete(make_req()) + data = await protocol._stream_reader.read(4) + assert data == b'test' - self.assertEqual(resp.status, 200) + transport.close() - content = self.loop.run_until_complete(resp.text()) - self.assertEqual(content, 'Test message') - resp.close() +async def test_socks4_invalid_data(loop): + pld = b'\x01\x5a\x04W\x01\x01\x01\x01' - def test_fingerprint_success(self): - with fake_socks4_srv(self.loop) as proxy_port: - addr = aiosocks.Socks4Addr('127.0.0.1', proxy_port) - fp = (b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9' - b'\x1a\xe3\xc5\x7f\x89\xe7l\xf9') + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks4Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks4Auth('usr') + dst = ('python.org', 80) - conn = SocksConnector(proxy=addr, proxy_auth=None, loop=self.loop, - remote_resolve=False, verify_ssl=False, - fingerprint=fp) + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid data' in str(ct) - with http_srv(self.loop, use_ssl=True) as url: - with aiohttp.ClientSession(connector=conn, - loop=self.loop) as ses: - @asyncio.coroutine - def make_req(): - return (yield from ses.request('get', url=url)) - resp = self.loop.run_until_complete(make_req()) +async def test_socks4_srv_error(loop): + pld = b'\x00\x5b\x04W\x01\x01\x01\x01' - self.assertEqual(resp.status, 200) + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks4Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks4Auth('usr') + dst = ('python.org', 80) - content = self.loop.run_until_complete(resp.text()) - self.assertEqual(content, 'Test message') + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert '0x5b' in str(ct) - resp.close() - def test_fingerprint_fail(self): - with fake_socks4_srv(self.loop) as proxy_port: - addr = aiosocks.Socks4Addr('127.0.0.1', proxy_port) - fp = (b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9' - b'\x1a\xe3\xc5\x7f\x89\xe7l\x10') +async def test_socks5_connect_success_anonymous(loop): + pld = b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' - conn = SocksConnector(proxy=addr, proxy_auth=None, loop=self.loop, - remote_resolve=False, verify_ssl=False, - fingerprint=fp) + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) - with http_srv(self.loop, use_ssl=True) as url: - with aiohttp.ClientSession(connector=conn, - loop=self.loop) as ses: - @asyncio.coroutine - def make_req(): - return (yield from ses.request('get', url=url)) + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) - with self.assertRaises(aiohttp.FingerprintMismatch): - self.loop.run_until_complete(make_req()) + assert protocol.proxy_sockname == ('1.1.1.1', 1111) + + data = await protocol._stream_reader.read(4) + assert data == b'test' + + transport.close() + + +async def test_socks5_connect_success_usr_pwd(loop): + pld = b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' + + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert protocol.proxy_sockname == ('1.1.1.1', 1111) + + data = await protocol._stream_reader.read(4) + assert data == b'test' + transport.close() + + +async def test_socks5_auth_ver_err(loop): + async with FakeSocksSrv(loop, b'\x04\x02') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid version' in str(ct) + + +async def test_socks5_auth_method_rejected(loop): + async with FakeSocksSrv(loop, b'\x05\xFF') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'authentication methods were rejected' in str(ct) + + +async def test_socks5_auth_status_invalid(loop): + async with FakeSocksSrv(loop, b'\x05\xF0') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid data' in str(ct) + + +async def test_socks5_auth_status_invalid2(loop): + async with FakeSocksSrv(loop, b'\x05\x02\x02\x00') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid data' in str(ct) + + +async def test_socks5_auth_failed(loop): + async with FakeSocksSrv(loop, b'\x05\x02\x01\x01') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'authentication failed' in str(ct) + + +async def test_socks5_cmd_ver_err(loop): + async with FakeSocksSrv(loop, b'\x05\x02\x01\x00\x04\x00\x00') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid version' in str(ct) + + +async def test_socks5_cmd_not_granted(loop): + async with FakeSocksSrv(loop, b'\x05\x02\x01\x00\x05\x01\x00') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'General SOCKS server failure' in str(ct) + + +async def test_socks5_invalid_address_type(loop): + async with FakeSocksSrv(loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF') as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with pytest.raises(aiosocks.SocksError) as ct: + await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert 'invalid data' in str(ct) + + +async def test_socks5_atype_ipv4(loop): + pld = b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' + + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert protocol.proxy_sockname == ('1.1.1.1', 1111) + + transport.close() + + +async def test_socks5_atype_ipv6(loop): + pld = b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' \ + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W' + + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert protocol.proxy_sockname == ('::111', 1111) + + transport.close() + + +async def test_socks5_atype_domain(loop): + pld = b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' + + async with FakeSocksSrv(loop, pld) as srv: + addr = aiosocks.Socks5Addr('127.0.0.1', srv.port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + transport, protocol = await aiosocks.create_connection( + None, addr, auth, dst, loop=loop) + assert protocol.proxy_sockname == (b'python.org', 1111) + + transport.close() + + +async def test_http_connect(loop): + async def handler(request): + return web.Response(text='Test message') + + async with RawTestServer(handler, host='127.0.0.1', loop=loop) as ws: + async with FakeSocks4Srv(loop) as srv: + conn = ProxyConnector(loop=loop, remote_resolve=False) + + async with aiohttp.ClientSession( + connector=conn, loop=loop, + request_class=ProxyClientRequest) as ses: + proxy = 'socks4://127.0.0.1:{}'.format(srv.port) + + async with ses.get(ws.make_url('/'), proxy=proxy) as resp: + assert resp.status == 200 + assert (await resp.text()) == 'Test message' + + +async def test_https_connect(loop): + async def handler(request): + return web.Response(text='Test message') + + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + + ws = RawTestServer(handler, scheme='https', host='127.0.0.1', loop=loop) + await ws.start_server(loop=loop, ssl=sslcontext) + + v_fp = b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\xf9' + inv_fp = b's\x93\xfd:\xed\x08\x1do\xa9\xaeq9\x1a\xe3\xc5\x7f\x89\xe7l\x10' + + async with FakeSocks4Srv(loop) as srv: + v_conn = ProxyConnector(loop=loop, remote_resolve=False, + verify_ssl=False, fingerprint=v_fp) + inv_conn = ProxyConnector(loop=loop, remote_resolve=False, + verify_ssl=False, fingerprint=inv_fp) + + async with aiohttp.ClientSession( + connector=v_conn, loop=loop, + request_class=ProxyClientRequest) as ses: + proxy = 'socks4://127.0.0.1:{}'.format(srv.port) + + async with ses.get(ws.make_url('/'), proxy=proxy) as resp: + assert resp.status == 200 + assert (await resp.text()) == 'Test message' + + async with aiohttp.ClientSession( + connector=inv_conn, loop=loop, + request_class=ProxyClientRequest) as ses: + proxy = 'socks4://127.0.0.1:{}'.format(srv.port) + + with pytest.raises(aiohttp.ServerFingerprintMismatch): + async with ses.get(ws.make_url('/'), proxy=proxy) as resp: + assert resp.status == 200 diff --git a/tests/test_protocols.py b/tests/test_protocols.py index c451be2..ee529f7 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -1,18 +1,13 @@ import asyncio import aiosocks -import unittest +import pytest import socket import ssl as ssllib from unittest import mock from asyncio import coroutine as coro +from aiohttp.test_utils import make_mocked_coro import aiosocks.constants as c from aiosocks.protocols import BaseSocksProtocol -from .helpers import fake_coroutine - -try: - from asyncio import ensure_future -except ImportError: - ensure_future = asyncio.async def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None): @@ -39,7 +34,6 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'', proto._get_dst_addr = mock.Mock( side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1'))) ) - return proto @@ -53,7 +47,7 @@ def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop, app_protocol_factory=ap_factory, waiter=whiter) proto._stream_writer = mock.Mock() - proto._stream_writer.drain = fake_coroutine(True) + proto._stream_writer.drain = make_mocked_coro(True) if not isinstance(r, (list, tuple)): proto.read_response = mock.Mock( @@ -65,577 +59,560 @@ def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, proto._get_dst_addr = mock.Mock( side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1'))) ) - return proto -class TestBaseSocksProtocol(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) +def test_base_ctor(loop): + with pytest.raises(ValueError): + BaseSocksProtocol(None, None, None, loop=loop, + waiter=None, app_protocol_factory=None) + + with pytest.raises(ValueError): + BaseSocksProtocol(None, None, 123, loop=loop, + waiter=None, app_protocol_factory=None) + + with pytest.raises(ValueError): + BaseSocksProtocol(None, None, ('python.org',), loop=loop, + waiter=None, app_protocol_factory=None) + + +def test_base_write_request(loop): + proto = make_base(loop) + proto._stream_writer = mock.Mock() + + proto.write_request([b'\x00', b'\x01\x02', 0x03]) + proto._stream_writer.write.assert_called_with(b'\x00\x01\x02\x03') + + with pytest.raises(ValueError): + proto.write_request(['\x00']) + + +async def test_base_negotiate_os_error(loop): + waiter = asyncio.Future(loop=loop) + proto = make_base(loop, waiter=waiter) + proto.socks_request = make_mocked_coro(raise_exception=OSError('test')) + await proto.negotiate(None, None) + + with pytest.raises(OSError) as ct: + await waiter + assert 'test' in str(ct) + + +async def test_base_negotiate_socks_err(loop): + waiter = asyncio.Future(loop=loop) + proto = make_base(loop, waiter=waiter) + proto.socks_request = make_mocked_coro( + raise_exception=aiosocks.SocksError('test')) + await proto.negotiate(None, None) + + with pytest.raises(aiosocks.SocksError) as ct: + await waiter + assert 'Can not connect to' in str(ct) + + +async def test_base_negotiate_without_app_proto(loop): + waiter = asyncio.Future(loop=loop) + proto = make_base(loop, waiter=waiter) + proto.socks_request = make_mocked_coro((None, None)) + proto._transport = True + + await proto.negotiate(None, None) + await waiter + assert waiter.done() + + +async def test_base_negotiate_with_app_proto(loop): + waiter = asyncio.Future(loop=loop) + proto = make_base(loop, waiter=waiter, + ap_factory=lambda: asyncio.Protocol()) + proto.socks_request = make_mocked_coro((None, None)) + + await proto.negotiate(None, None) + await waiter + assert waiter.done() + + +def test_base_connection_lost(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + + # negotiate not completed + proto._negotiate_done = False + proto.connection_lost(True) + assert not loop_mock.call_soon.called + + # negotiate successfully competed + loop_mock.reset_mock() + proto._negotiate_done = True + proto.connection_lost(True) + assert loop_mock.call_soon.called + + # don't call connect_lost, if app_protocol == self + # otherwise recursion + loop_mock.reset_mock() + proto = make_base(loop_mock, ap_factory=None) + proto._negotiate_done = True + proto.connection_lost(True) + assert not loop_mock.call_soon.called + + +def test_base_pause_writing(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + + # negotiate not completed + proto._negotiate_done = False + proto.pause_writing() + assert not proto._app_protocol.pause_writing.called + + # negotiate successfully competed + app_proto.reset_mock() + proto._negotiate_done = True + proto.pause_writing() + assert proto._app_protocol.pause_writing.called - def tearDown(self): - self.loop.close() + # don't call pause_writing, if app_protocol == self + # otherwise recursion + app_proto.reset_mock() + proto = make_base(loop_mock) + proto._negotiate_done = True + proto.pause_writing() - def test_init(self): - with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, None, loop=self.loop, - waiter=None, app_protocol_factory=None) - - with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, 123, loop=self.loop, - waiter=None, app_protocol_factory=None) - - with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, ('python.org',), loop=self.loop, - waiter=None, app_protocol_factory=None) - def test_write_request(self): - proto = make_base(self.loop) - proto._stream_writer = mock.Mock() - - proto.write_request([b'\x00', b'\x01\x02', 0x03]) - proto._stream_writer.write.assert_called_with(b'\x00\x01\x02\x03') - - with self.assertRaises(ValueError): - proto.write_request(['\x00']) - - def test_negotiate_os_error(self): - waiter = asyncio.Future(loop=self.loop) - proto = make_base(self.loop, waiter=waiter) - proto.socks_request = fake_coroutine(OSError('test')) - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertIn('test', str(waiter.exception())) - - def test_negotiate_socks_err(self): - waiter = asyncio.Future(loop=self.loop) - proto = make_base(self.loop, waiter=waiter) - proto.socks_request = fake_coroutine(aiosocks.SocksError('test')) - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertIn('Can not connect to', str(waiter.exception())) - - def test_negotiate_without_app_proto(self): - waiter = asyncio.Future(loop=self.loop) - proto = make_base(self.loop, waiter=waiter) - proto.socks_request = fake_coroutine((None, None)) - proto._transport = True - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertTrue(waiter.done()) - - def test_negotiate_with_app_proto(self): - waiter = asyncio.Future(loop=self.loop) - proto = make_base(self.loop, waiter=waiter, - ap_factory=lambda: asyncio.Protocol()) - proto.socks_request = fake_coroutine((None, None)) - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertTrue(waiter.done()) - - def test_connection_lost(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - proto = make_base(loop_mock, ap_factory=lambda: app_proto) - - # negotiate not completed - proto._negotiate_done = False - proto.connection_lost(True) - self.assertFalse(loop_mock.call_soon.called) - - # negotiate successfully competed - loop_mock.reset_mock() - proto._negotiate_done = True - proto.connection_lost(True) - self.assertTrue(loop_mock.call_soon.called) - - # don't call connect_lost, if app_protocol == self - # otherwise recursion - loop_mock.reset_mock() - proto = make_base(loop_mock, ap_factory=None) - proto._negotiate_done = True - proto.connection_lost(True) - self.assertFalse(loop_mock.call_soon.called) - - def test_pause_writing(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - proto = make_base(loop_mock, ap_factory=lambda: app_proto) - - # negotiate not completed - proto._negotiate_done = False - proto.pause_writing() - self.assertFalse(proto._app_protocol.pause_writing.called) - - # negotiate successfully competed - app_proto.reset_mock() - proto._negotiate_done = True - proto.pause_writing() - self.assertTrue(proto._app_protocol.pause_writing.called) - - # don't call pause_writing, if app_protocol == self - # otherwise recursion - app_proto.reset_mock() - proto = make_base(loop_mock) - proto._negotiate_done = True - proto.pause_writing() - - def test_resume_writing(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - proto = make_base(loop_mock, ap_factory=lambda: app_proto) - - # negotiate not completed - proto._negotiate_done = False - # negotiate not completed - with self.assertRaises(AssertionError): - proto.resume_writing() - self.assertFalse(proto._app_protocol.resume_writing.called) - - # negotiate successfully competed - loop_mock.reset_mock() - proto._negotiate_done = True +def test_base_resume_writing(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + + # negotiate not completed + proto._negotiate_done = False + # negotiate not completed + with pytest.raises(AssertionError): proto.resume_writing() - self.assertTrue(proto._app_protocol.resume_writing.called) - - # don't call resume_writing, if app_protocol == self - # otherwise recursion - loop_mock.reset_mock() - proto = make_base(loop_mock) - proto._negotiate_done = True - with self.assertRaises(AssertionError): - proto.resume_writing() - - def test_data_received(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - proto = make_base(loop_mock, ap_factory=lambda: app_proto) - - # negotiate not completed - proto._negotiate_done = False - proto.data_received(b'123') - self.assertFalse(proto._app_protocol.data_received.called) - - # negotiate successfully competed - app_proto.reset_mock() - proto._negotiate_done = True - proto.data_received(b'123') - self.assertTrue(proto._app_protocol.data_received.called) - - # don't call data_received, if app_protocol == self - # otherwise recursion - loop_mock.reset_mock() - proto = make_base(loop_mock) - proto._negotiate_done = True - proto.data_received(b'123') - - def test_eof_received(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - proto = make_base(loop_mock, ap_factory=lambda: app_proto) - - # negotiate not completed - proto._negotiate_done = False - proto.eof_received() - self.assertFalse(proto._app_protocol.eof_received.called) - - # negotiate successfully competed - app_proto.reset_mock() - proto._negotiate_done = True - proto.eof_received() - self.assertTrue(proto._app_protocol.eof_received.called) - - # don't call pause_writing, if app_protocol == self - # otherwise recursion - app_proto.reset_mock() - proto = make_base(loop_mock) - proto._negotiate_done = True - proto.eof_received() - - def test_make_ssl_proto(self): - loop_mock = mock.Mock() - app_proto = mock.Mock() - - ssl_context = ssllib.create_default_context() - proto = make_base(loop_mock, - ap_factory=lambda: app_proto, ssl=ssl_context) - proto.socks_request = fake_coroutine((None, None)) - proto._transport = mock.Mock() - self.loop.run_until_complete(proto.negotiate(None, None)) - - self.assertTrue(loop_mock._make_ssl_transport.called) - self.assertIs(loop_mock._make_ssl_transport.call_args[1]['sslcontext'], - ssl_context) - - @mock.patch('aiosocks.protocols.asyncio.Task') - def test_func_negotiate_cb_call(self, task_mock): - loop_mock = mock.Mock() - waiter = mock.Mock() - - proto = make_base(loop_mock, waiter=waiter) - proto.socks_request = fake_coroutine((None, None)) - proto._negotiate_done_cb = mock.Mock() - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertTrue(proto._negotiate_done_cb.called) - self.assertFalse(task_mock.called) - - @mock.patch('aiosocks.protocols.asyncio.Task') - def test_coro_negotiate_cb_call(self, task_mock): - loop_mock = mock.Mock() - waiter = mock.Mock() - - proto = make_base(loop_mock, waiter=waiter) - proto.socks_request = fake_coroutine((None, None)) - proto._negotiate_done_cb = fake_coroutine(None) - - self.loop.run_until_complete(proto.negotiate(None, None)) - self.assertTrue(proto._negotiate_done_cb.called) - self.assertTrue(task_mock.called) - - def test_reader_limit(self): - proto = BaseSocksProtocol(None, None, ('python.org', 80), - None, None, reader_limit=10, - loop=self.loop) - self.assertEqual(proto.reader._limit, 10) - - proto = BaseSocksProtocol(None, None, ('python.org', 80), - None, None, reader_limit=15, - loop=self.loop) - self.assertEqual(proto.reader._limit, 15) - - def test_incomplete_error(self): - proto = BaseSocksProtocol(None, None, ('python.org', 80), - None, None, reader_limit=10, - loop=self.loop) - proto._stream_reader.readexactly = fake_coroutine( - asyncio.IncompleteReadError(b'part', 5)) - with self.assertRaises(aiosocks.InvalidServerReply): - self.loop.run_until_complete(proto.read_response(4)) - - -class TestSocks4Protocol(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_init(self): - addr = aiosocks.Socks4Addr('localhost', 1080) - auth = aiosocks.Socks4Auth('user') - dst = ('python.org', 80) - - with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(None, None, dst, loop=self.loop, - waiter=None, app_protocol_factory=None) - - with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop, - waiter=None, app_protocol_factory=None) - - with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, - loop=self.loop, waiter=None, - app_protocol_factory=None) - - with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, - loop=self.loop, waiter=None, - app_protocol_factory=None) - - aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop, + assert not proto._app_protocol.resume_writing.called + + # negotiate successfully competed + loop_mock.reset_mock() + proto._negotiate_done = True + proto.resume_writing() + assert proto._app_protocol.resume_writing.called + + # don't call resume_writing, if app_protocol == self + # otherwise recursion + loop_mock.reset_mock() + proto = make_base(loop_mock) + proto._negotiate_done = True + with pytest.raises(AssertionError): + proto.resume_writing() + + +def test_base_data_received(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + + # negotiate not completed + proto._negotiate_done = False + proto.data_received(b'123') + assert not proto._app_protocol.data_received.called + + # negotiate successfully competed + app_proto.reset_mock() + proto._negotiate_done = True + proto.data_received(b'123') + assert proto._app_protocol.data_received.called + + # don't call data_received, if app_protocol == self + # otherwise recursion + loop_mock.reset_mock() + proto = make_base(loop_mock) + proto._negotiate_done = True + proto.data_received(b'123') + + +def test_base_eof_received(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + + # negotiate not completed + proto._negotiate_done = False + proto.eof_received() + assert not proto._app_protocol.eof_received.called + + # negotiate successfully competed + app_proto.reset_mock() + proto._negotiate_done = True + proto.eof_received() + assert proto._app_protocol.eof_received.called + + # don't call pause_writing, if app_protocol == self + # otherwise recursion + app_proto.reset_mock() + proto = make_base(loop_mock) + proto._negotiate_done = True + proto.eof_received() + + +async def test_base_make_ssl_proto(): + loop_mock = mock.Mock() + app_proto = mock.Mock() + + ssl_context = ssllib.create_default_context() + proto = make_base(loop_mock, + ap_factory=lambda: app_proto, ssl=ssl_context) + proto.socks_request = make_mocked_coro((None, None)) + proto._transport = mock.Mock() + await proto.negotiate(None, None) + + mtr = loop_mock._make_ssl_transport + + assert mtr.called + assert mtr.call_args[1]['sslcontext'] is ssl_context + + +async def test_base_func_negotiate_cb_call(): + loop_mock = mock.Mock() + waiter = mock.Mock() + + proto = make_base(loop_mock, waiter=waiter) + proto.socks_request = make_mocked_coro((None, None)) + proto._negotiate_done_cb = mock.Mock() + + with mock.patch('aiosocks.protocols.asyncio.Task') as task_mock: + await proto.negotiate(None, None) + assert proto._negotiate_done_cb.called + assert not task_mock.called + + +async def test_base_coro_negotiate_cb_call(): + loop_mock = mock.Mock() + waiter = mock.Mock() + + proto = make_base(loop_mock, waiter=waiter) + proto.socks_request = make_mocked_coro((None, None)) + proto._negotiate_done_cb = make_mocked_coro(None) + + with mock.patch('aiosocks.protocols.asyncio.Task') as task_mock: + await proto.negotiate(None, None) + assert proto._negotiate_done_cb.called + assert task_mock.called + + +async def test_base_reader_limit(loop): + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=10, loop=loop) + assert proto.reader._limit == 10 + + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=15, loop=loop) + assert proto.reader._limit == 15 + + +async def test_base_incomplete_error(loop): + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=10, loop=loop) + proto._stream_reader.readexactly = make_mocked_coro( + raise_exception=asyncio.IncompleteReadError(b'part', 5)) + with pytest.raises(aiosocks.InvalidServerReply): + await proto.read_response(4) + + +def test_socks4_ctor(loop): + addr = aiosocks.Socks4Addr('localhost', 1080) + auth = aiosocks.Socks4Auth('user') + dst = ('python.org', 80) + + with pytest.raises(ValueError): + aiosocks.Socks4Protocol(None, None, dst, loop=loop, waiter=None, app_protocol_factory=None) - aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop, + + with pytest.raises(ValueError): + aiosocks.Socks4Protocol(None, auth, dst, loop=loop, waiter=None, app_protocol_factory=None) - def test_dst_domain_with_remote_resolve(self): - proto = make_socks4(self.loop, dst=('python.org', 80), - r=b'\x00\x5a\x00P\x7f\x00\x00\x01') + with pytest.raises(ValueError): + aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, + loop=loop, waiter=None, + app_protocol_factory=None) + + with pytest.raises(ValueError): + aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, + loop=loop, waiter=None, + app_protocol_factory=None) + + aiosocks.Socks4Protocol(addr, None, dst, loop=loop, + waiter=None, app_protocol_factory=None) + aiosocks.Socks4Protocol(addr, auth, dst, loop=loop, + waiter=None, app_protocol_factory=None) + - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) +async def test_socks4_dst_domain_with_remote_resolve(loop): + proto = make_socks4(loop, dst=('python.org', 80), + r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - proto._stream_writer.write.assert_called_with( - b'\x04\x01\x00P\x00\x00\x00\x01user\x00python.org\x00' - ) + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01\x00P\x00\x00\x00\x01user\x00python.org\x00') - def test_dst_domain_with_local_resolve(self): - proto = make_socks4(self.loop, dst=('python.org', 80), - rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) +async def test_socks4_dst_domain_with_local_resolve(loop): + proto = make_socks4(loop, dst=('python.org', 80), + rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - proto._stream_writer.write.assert_called_with( - b'\x04\x01\x00P\x7f\x00\x00\x01user\x00' - ) + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01\x00P\x7f\x00\x00\x01user\x00') - def test_dst_ip_with_remote_resolve(self): - proto = make_socks4(self.loop, dst=('127.0.0.1', 8800), - r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) - proto._stream_writer.write.assert_called_with( - b'\x04\x01"`\x7f\x00\x00\x01user\x00' - ) +async def test_socks4_dst_ip_with_remote_resolve(loop): + proto = make_socks4(loop, dst=('127.0.0.1', 8800), + r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - def test_dst_ip_with_locale_resolve(self): - proto = make_socks4(self.loop, dst=('127.0.0.1', 8800), - rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01"`\x7f\x00\x00\x01user\x00') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) - proto._stream_writer.write.assert_called_with( - b'\x04\x01"`\x7f\x00\x00\x01user\x00' - ) +async def test_socks4_dst_ip_with_locale_resolve(loop): + proto = make_socks4(loop, dst=('127.0.0.1', 8800), + rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - def test_dst_domain_without_user(self): - proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''), - dst=('python.org', 80), - r=b'\x00\x5a\x00P\x7f\x00\x00\x01') + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01"`\x7f\x00\x00\x01user\x00') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) - proto._stream_writer.write.assert_called_with( - b'\x04\x01\x00P\x00\x00\x00\x01\x00python.org\x00' - ) +async def test_socks4_dst_domain_without_user(loop): + proto = make_socks4(loop, auth=aiosocks.Socks4Auth(''), + dst=('python.org', 80), + r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - def test_dst_ip_without_user(self): - proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''), - dst=('127.0.0.1', 8800), - r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - self.loop.run_until_complete(req) + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01\x00P\x00\x00\x00\x01\x00python.org\x00') - proto._stream_writer.write.assert_called_with( - b'\x04\x01"`\x7f\x00\x00\x01\x00' - ) - def test_valid_resp_handling(self): - proto = make_socks4(self.loop, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - req = ensure_future( - proto.socks_request(c.SOCKS_CMD_CONNECT), loop=self.loop) - self.loop.run_until_complete(req) +async def test_socks4_dst_ip_without_user(loop): + proto = make_socks4(loop, auth=aiosocks.Socks4Auth(''), + dst=('127.0.0.1', 8800), + r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80))) + await proto.socks_request(c.SOCKS_CMD_CONNECT) + proto._stream_writer.write.assert_called_with( + b'\x04\x01"`\x7f\x00\x00\x01\x00') - def test_invalid_reply_resp_handling(self): - proto = make_socks4(self.loop, r=b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - with self.assertRaises(aiosocks.InvalidServerReply): - self.loop.run_until_complete(req) +async def test_socks4_valid_resp_handling(loop): + proto = make_socks4(loop, r=b'\x00\x5a\x00P\x7f\x00\x00\x01') - def test_socks_err_resp_handling(self): - proto = make_socks4(self.loop, r=b'\x00\x5b\x00P\x7f\x00\x00\x01') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) + r = await proto.socks_request(c.SOCKS_CMD_CONNECT) + assert r == (('python.org', 80), ('127.0.0.1', 80)) - with self.assertRaises(aiosocks.SocksError) as cm: - self.loop.run_until_complete(req) - self.assertTrue('0x5b' in str(cm.exception)) +async def test_socks4_invalid_reply_resp_handling(loop): + proto = make_socks4(loop, r=b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF') - def test_unknown_err_resp_handling(self): - proto = make_socks4(self.loop, r=b'\x00\x5e\x00P\x7f\x00\x00\x01') - req = proto.socks_request(c.SOCKS_CMD_CONNECT) + with pytest.raises(aiosocks.InvalidServerReply): + await proto.socks_request(c.SOCKS_CMD_CONNECT) - with self.assertRaises(aiosocks.SocksError) as cm: - self.loop.run_until_complete(req) - self.assertTrue('Unknown error' in str(cm.exception)) +async def test_socks_err_resp_handling(loop): + proto = make_socks4(loop, r=b'\x00\x5b\x00P\x7f\x00\x00\x01') + with pytest.raises(aiosocks.SocksError) as cm: + await proto.socks_request(c.SOCKS_CMD_CONNECT) + assert '0x5b' in str(cm) -class TestSocks5Protocol(unittest.TestCase): - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - def tearDown(self): - self.loop.close() +async def test_socks4_unknown_err_resp_handling(loop): + proto = make_socks4(loop, r=b'\x00\x5e\x00P\x7f\x00\x00\x01') - def test_init(self): - addr = aiosocks.Socks5Addr('localhost', 1080) - auth = aiosocks.Socks5Auth('user', 'pwd') - dst = ('python.org', 80) + with pytest.raises(aiosocks.SocksError) as cm: + await proto.socks_request(c.SOCKS_CMD_CONNECT) + assert 'Unknown error' in str(cm) - with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(None, None, dst, loop=self.loop, - waiter=None, app_protocol_factory=None) - with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop, - waiter=None, app_protocol_factory=None) +def test_socks5_ctor(loop): + addr = aiosocks.Socks5Addr('localhost', 1080) + auth = aiosocks.Socks5Auth('user', 'pwd') + dst = ('python.org', 80) - with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), - auth, dst, loop=self.loop, - waiter=None, app_protocol_factory=None) + with pytest.raises(ValueError): + aiosocks.Socks5Protocol(None, None, dst, loop=loop, + waiter=None, app_protocol_factory=None) - with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), - dst, loop=self.loop, - waiter=None, app_protocol_factory=None) + with pytest.raises(ValueError): + aiosocks.Socks5Protocol(None, auth, dst, loop=loop, + waiter=None, app_protocol_factory=None) - aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop, + with pytest.raises(ValueError): + aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), + auth, dst, loop=loop, waiter=None, app_protocol_factory=None) - aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop, + + with pytest.raises(ValueError): + aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), + dst, loop=loop, waiter=None, app_protocol_factory=None) - def test_auth_inv_srv_ver(self): - proto = make_socks5(self.loop, r=b'\x00\x00') - req = proto.authenticate() - - with self.assertRaises(aiosocks.InvalidServerVersion): - self.loop.run_until_complete(req) - - def test_auth_no_acceptable_auth_methods(self): - proto = make_socks5(self.loop, r=b'\x05\xFF') - req = proto.authenticate() - with self.assertRaises(aiosocks.NoAcceptableAuthMethods): - self.loop.run_until_complete(req) - - def test_auth_unsupported_auth_method(self): - proto = make_socks5(self.loop, r=b'\x05\xF0') - req = proto.authenticate() - with self.assertRaises(aiosocks.InvalidServerReply): - self.loop.run_until_complete(req) - - def test_auth_usr_pwd_granted(self): - proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x00',)) - self.loop.run_until_complete(proto.authenticate()) - proto._stream_writer.write.assert_has_calls([ - mock.call(b'\x05\x02\x00\x02'), - mock.call(b'\x01\x04user\x03pwd') - ]) - - def test_auth_invalid_reply(self): - proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x00\x00',)) - req = proto.authenticate() - with self.assertRaises(aiosocks.InvalidServerReply): - self.loop.run_until_complete(req) - - def test_auth_access_denied(self): - proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x01',)) - req = proto.authenticate() - with self.assertRaises(aiosocks.LoginAuthenticationFailed): - self.loop.run_until_complete(req) - - def test_auth_anonymous_granted(self): - proto = make_socks5(self.loop, r=b'\x05\x00') - req = proto.authenticate() - self.loop.run_until_complete(req) - - def test_build_dst_addr_ipv4(self): - proto = make_socks5(self.loop) - c = proto.build_dst_address('127.0.0.1', 80) - dst_req, resolved = self.loop.run_until_complete(c) - - self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P']) - self.assertEqual(resolved, ('127.0.0.1', 80)) - - def test_build_dst_addr_ipv6(self): - proto = make_socks5(self.loop) - c = proto.build_dst_address( - '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80) - dst_req, resolved = self.loop.run_until_complete(c) - - self.assertEqual(dst_req, [ - 0x04, b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', b'\x00P']) - self.assertEqual(resolved, - ('2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)) - - def test_build_dst_addr_domain_with_remote_resolve(self): - proto = make_socks5(self.loop) - c = proto.build_dst_address('python.org', 80) - dst_req, resolved = self.loop.run_until_complete(c) - - self.assertEqual(dst_req, [0x03, b'\n', b'python.org', b'\x00P']) - self.assertEqual(resolved, ('python.org', 80)) - - def test_build_dst_addr_domain_with_locale_resolve(self): - proto = make_socks5(self.loop, rr=False) - c = proto.build_dst_address('python.org', 80) - dst_req, resolved = self.loop.run_until_complete(c) - - self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P']) - self.assertEqual(resolved, ('127.0.0.1', 80)) - - def test_rd_addr_ipv4(self): - proto = make_socks5( - self.loop, r=[b'\x01', b'\x7f\x00\x00\x01', b'\x00P']) - req = ensure_future(proto.read_address(), loop=self.loop) - self.loop.run_until_complete(req) - - self.assertEqual(req.result(), ('127.0.0.1', 80)) - - def test_rd_addr_ipv6(self): - resp = [ - b'\x04', - b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', - b'\x00P' - ] - proto = make_socks5(self.loop, r=resp) - req = ensure_future(proto.read_address(), loop=self.loop) - self.loop.run_until_complete(req) - - self.assertEqual( - req.result(), ('2001:db8:11a3:9d7:1f34:8a2e:7a0:765d', 80)) - - def test_rd_addr_domain(self): - proto = make_socks5( - self.loop, r=[b'\x03', b'\n', b'python.org', b'\x00P']) - req = ensure_future(proto.read_address(), loop=self.loop) - self.loop.run_until_complete(req) - - self.assertEqual(req.result(), (b'python.org', 80)) - - def test_socks_req_inv_ver(self): - proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x04\x00\x00']) - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - with self.assertRaises(aiosocks.InvalidServerVersion): - self.loop.run_until_complete(req) - - def test_socks_req_socks_srv_err(self): - proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\x02\x00']) - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - with self.assertRaises(aiosocks.SocksError) as ct: - self.loop.run_until_complete(req) - - self.assertTrue( - 'Connection not allowed by ruleset' in str(ct.exception)) - - def test_socks_req_unknown_err(self): - proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\xFF\x00']) - req = proto.socks_request(c.SOCKS_CMD_CONNECT) - with self.assertRaises(aiosocks.SocksError) as ct: - self.loop.run_until_complete(req) - - self.assertTrue('Unknown error' in str(ct.exception)) - - def test_socks_req_cmd_granted(self): - # cmd granted - resp = [b'\x05\x00', - b'\x05\x00\x00', - b'\x01', b'\x7f\x00\x00\x01', - b'\x00P'] - proto = make_socks5(self.loop, r=resp) - req = ensure_future(proto.socks_request(c.SOCKS_CMD_CONNECT), - loop=self.loop) - self.loop.run_until_complete(req) - - self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80))) - proto._stream_writer.write.assert_has_calls([ - mock.call(b'\x05\x02\x00\x02'), - mock.call(b'\x05\x01\x00\x03\npython.org\x00P') - ]) + aiosocks.Socks5Protocol(addr, None, dst, loop=loop, + waiter=None, app_protocol_factory=None) + aiosocks.Socks5Protocol(addr, auth, dst, loop=loop, + waiter=None, app_protocol_factory=None) + + +async def test_socks5_auth_inv_srv_ver(loop): + proto = make_socks5(loop, r=b'\x00\x00') + + with pytest.raises(aiosocks.InvalidServerVersion): + await proto.authenticate() + + +async def test_socks5_auth_no_acceptable_auth_methods(loop): + proto = make_socks5(loop, r=b'\x05\xFF') + + with pytest.raises(aiosocks.NoAcceptableAuthMethods): + await proto.authenticate() + + +async def test_socks5_auth_unsupported_auth_method(loop): + proto = make_socks5(loop, r=b'\x05\xF0') + + with pytest.raises(aiosocks.InvalidServerReply): + await proto.authenticate() + + +async def test_socks5_auth_usr_pwd_granted(loop): + proto = make_socks5(loop, r=(b'\x05\x02', b'\x01\x00',)) + await proto.authenticate() + + proto._stream_writer.write.assert_has_calls([ + mock.call(b'\x05\x02\x00\x02'), + mock.call(b'\x01\x04user\x03pwd') + ]) + + +async def test_socks5_auth_invalid_reply(loop): + proto = make_socks5(loop, r=(b'\x05\x02', b'\x00\x00',)) + + with pytest.raises(aiosocks.InvalidServerReply): + await proto.authenticate() + + +async def test_socks5_auth_access_denied(loop): + proto = make_socks5(loop, r=(b'\x05\x02', b'\x01\x01',)) + + with pytest.raises(aiosocks.LoginAuthenticationFailed): + await proto.authenticate() + + +async def test_socks5_auth_anonymous_granted(loop): + proto = make_socks5(loop, r=b'\x05\x00') + await proto.authenticate() + + +async def test_socks5_build_dst_addr_ipv4(loop): + proto = make_socks5(loop) + dst_req, resolved = await proto.build_dst_address('127.0.0.1', 80) + + assert dst_req == [0x01, b'\x7f\x00\x00\x01', b'\x00P'] + assert resolved == ('127.0.0.1', 80) + + +async def test_socks5_build_dst_addr_ipv6(loop): + proto = make_socks5(loop) + dst_req, resolved = await proto.build_dst_address( + '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80) + + assert dst_req == [ + 0x04, b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', b'\x00P'] + assert resolved == ('2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80) + + +async def test_socks5_build_dst_addr_domain_with_remote_resolve(loop): + proto = make_socks5(loop) + dst_req, resolved = await proto.build_dst_address('python.org', 80) + + assert dst_req == [0x03, b'\n', b'python.org', b'\x00P'] + assert resolved == ('python.org', 80) + + +async def test_socks5_build_dst_addr_domain_with_locale_resolve(loop): + proto = make_socks5(loop, rr=False) + dst_req, resolved = await proto.build_dst_address('python.org', 80) + + assert dst_req == [0x01, b'\x7f\x00\x00\x01', b'\x00P'] + assert resolved == ('127.0.0.1', 80) + + +async def test_socks5_rd_addr_ipv4(loop): + proto = make_socks5(loop, r=[b'\x01', b'\x7f\x00\x00\x01', b'\x00P']) + r = await proto.read_address() + + assert r == ('127.0.0.1', 80) + + +async def test_socks5_rd_addr_ipv6(loop): + resp = [ + b'\x04', + b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', + b'\x00P' + ] + proto = make_socks5(loop, r=resp) + r = await proto.read_address() + + assert r == ('2001:db8:11a3:9d7:1f34:8a2e:7a0:765d', 80) + + +async def test_socks5_rd_addr_domain(loop): + proto = make_socks5(loop, r=[b'\x03', b'\n', b'python.org', b'\x00P']) + r = await proto.read_address() + + assert r == (b'python.org', 80) + + +async def test_socks5_socks_req_inv_ver(loop): + proto = make_socks5(loop, r=[b'\x05\x00', b'\x04\x00\x00']) + + with pytest.raises(aiosocks.InvalidServerVersion): + await proto.socks_request(c.SOCKS_CMD_CONNECT) + + +async def test_socks5_socks_req_socks_srv_err(loop): + proto = make_socks5(loop, r=[b'\x05\x00', b'\x05\x02\x00']) + + with pytest.raises(aiosocks.SocksError) as ct: + await proto.socks_request(c.SOCKS_CMD_CONNECT) + assert 'Connection not allowed by ruleset' in str(ct) + + +async def test_socks5_socks_req_unknown_err(loop): + proto = make_socks5(loop, r=[b'\x05\x00', b'\x05\xFF\x00']) + + with pytest.raises(aiosocks.SocksError) as ct: + await proto.socks_request(c.SOCKS_CMD_CONNECT) + assert 'Unknown error' in str(ct) + + +async def test_socks_req_cmd_granted(loop): + # cmd granted + resp = [b'\x05\x00', + b'\x05\x00\x00', + b'\x01', b'\x7f\x00\x00\x01', + b'\x00P'] + proto = make_socks5(loop, r=resp) + r = await proto.socks_request(c.SOCKS_CMD_CONNECT) + + assert r == (('python.org', 80), ('127.0.0.1', 80)) + proto._stream_writer.write.assert_has_calls([ + mock.call(b'\x05\x02\x00\x02'), + mock.call(b'\x05\x01\x00\x03\npython.org\x00P') + ])