@@ -47,7 +47,15 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||
"proxy is Socks5Addr but proxy_auth is not Socks5Auth" | |||
) | |||
if server_hostname is not None and not ssl: | |||
raise ValueError('server_hostname is only meaningful with ssl') | |||
if server_hostname is None and ssl: | |||
# read details: asyncio.create_connection | |||
server_hostname = dst[0] | |||
loop = loop or asyncio.get_event_loop() | |||
waiter = asyncio.Future(loop=loop) | |||
def socks_factory(): | |||
if isinstance(proxy, Socks4Addr): | |||
@@ -55,30 +63,24 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||
else: | |||
socks_proto = Socks5Protocol | |||
return socks_proto( | |||
proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||
remote_resolve=remote_resolve, loop=loop) | |||
return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||
app_protocol_factory=protocol_factory, | |||
waiter=waiter, remote_resolve=remote_resolve, | |||
loop=loop, ssl=ssl, server_hostname=server_hostname) | |||
try: | |||
transport, protocol = yield from loop.create_connection( | |||
socks_factory, proxy.host, proxy.port, ssl=ssl, family=family, | |||
proto=proto, flags=flags, sock=sock, local_addr=local_addr, | |||
server_hostname=server_hostname) | |||
socks_factory, proxy.host, proxy.port, family=family, | |||
proto=proto, flags=flags, sock=sock, local_addr=local_addr) | |||
except OSError as exc: | |||
raise SocksConnectionError( | |||
'[Errno %s] Can not connect to proxy %s:%d [%s]' % | |||
(exc.errno, proxy.host, proxy.port, exc.strerror)) from exc | |||
# Wait until communication with proxy server is finished | |||
try: | |||
yield from protocol.negotiate_done() | |||
except SocksError as exc: | |||
raise SocksError('Can not connect to %s:%s [%s]' % | |||
(dst[0], dst[1], exc)) | |||
if protocol_factory: | |||
protocol = protocol_factory() | |||
protocol.connection_made(transport) | |||
transport._protocol = protocol | |||
yield from waiter | |||
except: | |||
transport.close() | |||
raise | |||
return transport, protocol | |||
return protocol.app_transport, protocol.app_protocol |
@@ -33,12 +33,16 @@ class SocksConnector(aiohttp.TCPConnector): | |||
@asyncio.coroutine | |||
def _create_connection(self, req): | |||
if req.ssl: | |||
sslcontext = self.ssl_context | |||
else: | |||
sslcontext = None | |||
if not self._remote_resolve: | |||
dst_hosts = yield from self._resolve_host(req.host, req.port) | |||
dst = dst_hosts[0]['host'], dst_hosts[0]['port'] | |||
else: | |||
dst = req.host, req.port | |||
exc = None | |||
# if self._resolver is AsyncResolver and self._proxy.host | |||
# is ip address, then aiodns raise DNSError. | |||
@@ -56,6 +60,7 @@ class SocksConnector(aiohttp.TCPConnector): | |||
except ValueError: | |||
proxy_hosts = yield from self._resolve_host(self._proxy.host, | |||
self._proxy.port) | |||
exc = None | |||
for hinfo in proxy_hosts: | |||
try: | |||
@@ -65,8 +70,29 @@ class SocksConnector(aiohttp.TCPConnector): | |||
transp, proto = yield from create_connection( | |||
self._factory, proxy, self._proxy_auth, dst, | |||
loop=self._loop, remote_resolve=self._remote_resolve, | |||
ssl=None, family=hinfo['family'], proto=hinfo['proto'], | |||
flags=hinfo['flags'], local_addr=self._local_addr) | |||
ssl=sslcontext, family=hinfo['family'], | |||
proto=hinfo['proto'], flags=hinfo['flags'], | |||
local_addr=self._local_addr, | |||
server_hostname=req.host if sslcontext else None) | |||
has_cert = transp.get_extra_info('sslcontext') | |||
if has_cert and self._fingerprint: | |||
sock = transp.get_extra_info('socket') | |||
if not hasattr(sock, 'getpeercert'): | |||
# Workaround for asyncio 3.5.0 | |||
# Starting from 3.5.1 version | |||
# there is 'ssl_object' extra info in transport | |||
sock = transp._ssl_protocol._sslpipe.ssl_object | |||
# gives DER-encoded cert as a sequence of bytes (or None) | |||
cert = sock.getpeercert(binary_form=True) | |||
assert cert | |||
got = self._hashfunc(cert).digest() | |||
expected = self._fingerprint | |||
if got != expected: | |||
transp.close() | |||
raise aiohttp.FingerprintMismatch( | |||
expected, got, req.host, 80 | |||
) | |||
return transp, proto | |||
except (OSError, SocksError, SocksConnectionError) as e: | |||
@@ -17,7 +17,9 @@ except ImportError: | |||
class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
remote_resolve=True, loop=None, ssl=False, | |||
server_hostname=None): | |||
if not isinstance(dst, (tuple, list)) or len(dst) != 2: | |||
raise ValueError( | |||
'Invalid dst format, tuple("dst_host", dst_port))' | |||
@@ -30,18 +32,79 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
self._loop = loop or asyncio.get_event_loop() | |||
self._transport = None | |||
self._negotiate_done = None | |||
self._waiter = waiter | |||
self._negotiate_fut = None | |||
self._ssl = ssl | |||
self._server_hostname = server_hostname | |||
if app_protocol_factory: | |||
self._app_protocol = app_protocol_factory() | |||
else: | |||
self._app_protocol = self | |||
reader = asyncio.StreamReader(loop=self._loop) | |||
super().__init__(stream_reader=reader, loop=self._loop) | |||
def connection_made(self, transport): | |||
# connection_made is called | |||
if self._transport: | |||
return | |||
super().connection_made(transport) | |||
self._transport = transport | |||
def init_app_protocol(fut): | |||
exc = fut.exception() | |||
if exc: | |||
if isinstance(exc, SocksError): | |||
exc = SocksError('Can not connect to %s:%s. %s' % | |||
(self._dst_host, self._dst_port, exc)) | |||
self._waiter.set_exception(exc) | |||
else: | |||
if self._ssl: | |||
sock = self._transport.get_extra_info('socket') | |||
self._transport = self._loop._make_ssl_transport( | |||
rawsock=sock, protocol=self._app_protocol, | |||
sslcontext=self._ssl, server_side=False, | |||
server_hostname=self._server_hostname, | |||
waiter=self._waiter) | |||
else: | |||
self._app_protocol.connection_made(transport) | |||
self._waiter.set_result(True) | |||
req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | |||
self._negotiate_done = ensure_future(req_coro, loop=self._loop) | |||
self._negotiate_fut = ensure_future(req_coro, loop=self._loop) | |||
self._negotiate_fut.add_done_callback(init_app_protocol) | |||
def connection_lost(self, exc): | |||
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||
self._loop.call_soon(self._app_protocol.connection_lost, exc) | |||
super().connection_lost(exc) | |||
def pause_writing(self): | |||
if self._negotiate_fut.done(): | |||
self._app_protocol.pause_writing() | |||
else: | |||
super().pause_writing() | |||
def resume_writing(self): | |||
if self._negotiate_fut.done(): | |||
self._app_protocol.resume_writing() | |||
else: | |||
super().resume_writing() | |||
def data_received(self, data): | |||
if self._negotiate_fut.done(): | |||
self._app_protocol.data_received(data) | |||
else: | |||
super().data_received(data) | |||
def eof_received(self): | |||
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||
self._app_protocol.eof_received() | |||
super().eof_received() | |||
@asyncio.coroutine | |||
def socks_request(self, cmd): | |||
@@ -74,12 +137,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
raise OSError('getaddrinfo() returned empty list') | |||
return infos[0][0], infos[0][4][0] | |||
def negotiate_done(self): | |||
return self._negotiate_done | |||
@property | |||
def app_protocol(self): | |||
return self._app_protocol | |||
@property | |||
def app_transport(self): | |||
return self._transport | |||
class Socks4Protocol(BaseSocksProtocol): | |||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
remote_resolve=True, loop=None, ssl=False, | |||
server_hostname=None): | |||
proxy_auth = proxy_auth or Socks4Auth('') | |||
if not isinstance(proxy, Socks4Addr): | |||
@@ -88,7 +158,8 @@ class Socks4Protocol(BaseSocksProtocol): | |||
if not isinstance(proxy_auth, Socks4Auth): | |||
raise ValueError('Invalid proxy_auth format') | |||
super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
remote_resolve, loop, ssl, server_hostname) | |||
@asyncio.coroutine | |||
def socks_request(self, cmd): | |||
@@ -130,7 +201,9 @@ class Socks4Protocol(BaseSocksProtocol): | |||
class Socks5Protocol(BaseSocksProtocol): | |||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
remote_resolve=True, loop=None, ssl=False, | |||
server_hostname=None): | |||
proxy_auth = proxy_auth or Socks5Auth('', '') | |||
if not isinstance(proxy, Socks5Addr): | |||
@@ -139,7 +212,8 @@ class Socks5Protocol(BaseSocksProtocol): | |||
if not isinstance(proxy_auth, Socks5Auth): | |||
raise ValueError('Invalid proxy_auth format') | |||
super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
remote_resolve, loop, ssl, server_hostname) | |||
@asyncio.coroutine | |||
def socks_request(self, cmd): | |||
@@ -0,0 +1,25 @@ | |||
import asyncio | |||
import socket | |||
import functools | |||
def find_unused_port(): | |||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
s.bind(('127.0.0.1', 0)) | |||
port = s.getsockname()[1] | |||
s.close() | |||
return port | |||
@asyncio.coroutine | |||
def socks_handler(reader, writer, write_buff): | |||
writer.write(write_buff) | |||
@asyncio.coroutine | |||
def fake_socks_srv(loop, write_buff): | |||
port = find_unused_port() | |||
handler = functools.partial(socks_handler, write_buff=write_buff) | |||
srv = yield from asyncio.start_server( | |||
handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop) | |||
return srv, port |
@@ -24,7 +24,12 @@ class TestSocksConnector(unittest.TestCase): | |||
return mock.Mock(side_effect=coroutine(coro)) | |||
def test_connect_proxy_ip(self): | |||
@mock.patch('aiosocks.connector.create_connection') | |||
def test_connect_proxy_ip(self, cr_conn_mock): | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
cr_conn_mock.side_effect = \ | |||
self._fake_coroutine((tr, proto)).side_effect | |||
loop_mock = mock.Mock() | |||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | |||
@@ -33,21 +38,18 @@ class TestSocksConnector(unittest.TestCase): | |||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
proto.negotiate_done = self._fake_coroutine(True) | |||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||
conn = self.loop.run_until_complete(connector.connect(req)) | |||
self.assertTrue(loop_mock.getaddrinfo.is_called) | |||
self.assertIs(conn._transport, tr) | |||
self.assertTrue( | |||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||
) | |||
conn.close() | |||
def test_connect_proxy_domain(self): | |||
@mock.patch('aiosocks.connector.create_connection') | |||
def test_connect_proxy_domain(self, cr_conn_mock): | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
cr_conn_mock.side_effect = \ | |||
self._fake_coroutine((tr, proto)).side_effect | |||
loop_mock = mock.Mock() | |||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | |||
@@ -56,60 +58,53 @@ class TestSocksConnector(unittest.TestCase): | |||
connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
proto.negotiate_done = self._fake_coroutine(True) | |||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||
conn = self.loop.run_until_complete(connector.connect(req)) | |||
self.assertTrue(connector._resolve_host.is_called) | |||
self.assertEqual(connector._resolve_host.call_count, 1) | |||
self.assertIs(conn._transport, tr) | |||
self.assertTrue( | |||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||
) | |||
conn.close() | |||
def test_connect_locale_resolve(self): | |||
loop_mock = mock.Mock() | |||
@mock.patch('aiosocks.connector.create_connection') | |||
def test_connect_locale_resolve(self, cr_conn_mock): | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
cr_conn_mock.side_effect = \ | |||
self._fake_coroutine((tr, proto)).side_effect | |||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | |||
connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), | |||
None, loop=loop_mock, remote_resolve=False) | |||
None, loop=self.loop, remote_resolve=False) | |||
connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
proto.negotiate_done = self._fake_coroutine(True) | |||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||
conn = self.loop.run_until_complete(connector.connect(req)) | |||
self.assertTrue(connector._resolve_host.is_called) | |||
self.assertEqual(connector._resolve_host.call_count, 2) | |||
self.assertIs(conn._transport, tr) | |||
self.assertTrue( | |||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||
) | |||
conn.close() | |||
def test_proxy_connect_fail(self): | |||
@mock.patch('aiosocks.connector.create_connection') | |||
def test_proxy_connect_fail(self, cr_conn_mock): | |||
loop_mock = mock.Mock() | |||
cr_conn_mock.side_effect = \ | |||
self._fake_coroutine(aiosocks.SocksConnectionError()).side_effect | |||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | |||
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | |||
None, loop=loop_mock) | |||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | |||
loop_mock.create_connection = self._fake_coroutine(OSError()) | |||
with self.assertRaises(aiohttp.ProxyConnectionError): | |||
self.loop.run_until_complete(connector.connect(req)) | |||
def test_proxy_negotiate_fail(self): | |||
@mock.patch('aiosocks.connector.create_connection') | |||
def test_proxy_negotiate_fail(self, cr_conn_mock): | |||
loop_mock = mock.Mock() | |||
cr_conn_mock.side_effect = \ | |||
self._fake_coroutine(aiosocks.SocksError()).side_effect | |||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | |||
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | |||
@@ -117,9 +112,5 @@ class TestSocksConnector(unittest.TestCase): | |||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | |||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||
with self.assertRaises(aiosocks.SocksError): | |||
self.loop.run_until_complete(connector.connect(req)) |
@@ -2,6 +2,7 @@ import unittest | |||
import aiosocks | |||
import asyncio | |||
from unittest import mock | |||
from .socks_serv import fake_socks_srv | |||
try: | |||
from asyncio import ensure_future | |||
@@ -74,6 +75,15 @@ class TestCreateConnection(unittest.TestCase): | |||
self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', | |||
str(ct.exception)) | |||
# test ssl, server_hostname | |||
with self.assertRaises(ValueError) as ct: | |||
conn = aiosocks.create_connection( | |||
None, addr, auth, dst, server_hostname='python.org' | |||
) | |||
self.loop.run_until_complete(conn) | |||
self.assertIn('server_hostname is only meaningful with ssl', | |||
str(ct.exception)) | |||
def test_connection_fail(self): | |||
addr = aiosocks.Socks5Addr('localhost') | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
@@ -88,45 +98,323 @@ class TestCreateConnection(unittest.TestCase): | |||
) | |||
self.loop.run_until_complete(conn) | |||
def test_negotiate_fail(self): | |||
addr = aiosocks.Socks5Addr('localhost') | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
class TestCreateSocks4Connection(unittest.TestCase): | |||
def setUp(self): | |||
self.loop = asyncio.new_event_loop() | |||
asyncio.set_event_loop(None) | |||
def tearDown(self): | |||
self.loop.close() | |||
def test_connect_success(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x00\x5a\x04W\x01\x01\x01\x01test') | |||
) | |||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks4Auth('usr') | |||
dst = ('python.org', 80) | |||
transp, proto = mock.Mock(), mock.Mock() | |||
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
loop_mock = mock.Mock() | |||
loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||
self.assertEqual(data, b'test') | |||
server.close() | |||
transport.close() | |||
def test_invalid_ver(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x01\x5a\x04W\x01\x01\x01\x01') | |||
) | |||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks4Auth('usr') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
conn = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=loop_mock | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid data', str(ct.exception)) | |||
server.close() | |||
def test_access_not_granted(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x00\x5b\x04W\x01\x01\x01\x01') | |||
) | |||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks4Auth('usr') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('0x5b', str(ct.exception)) | |||
server.close() | |||
class TestCreateSocks5Connect(unittest.TestCase): | |||
def setUp(self): | |||
self.loop = asyncio.new_event_loop() | |||
asyncio.set_event_loop(None) | |||
def tearDown(self): | |||
self.loop.close() | |||
def test_connect_success_anonymous(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv( | |||
self.loop, | |||
b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||
) | |||
self.loop.run_until_complete(conn) | |||
self.assertIn('Can not connect to python.org:80', | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||
self.assertEqual(data, b'test') | |||
server.close() | |||
transport.close() | |||
def test_connect_success_usr_pwd(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv( | |||
self.loop, | |||
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||
) | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||
self.assertEqual(data, b'test') | |||
server.close() | |||
transport.close() | |||
def test_auth_ver_err(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x04\x02') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid version', str(ct.exception)) | |||
server.close() | |||
def test_auth_method_rejected(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\xFF') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('authentication methods were rejected', | |||
str(ct.exception)) | |||
def test_create_protocol(self): | |||
addr = aiosocks.Socks5Addr('localhost') | |||
server.close() | |||
def test_auth_status_invalid(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\xF0') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
transp, proto = mock.Mock(), mock.Mock() | |||
proto.negotiate_done = self._fake_coroutine(True) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid data', str(ct.exception)) | |||
loop_mock = mock.Mock() | |||
loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||
server.close() | |||
def test_auth_status_invalid2(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\x02\x02\x00') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid data', str(ct.exception)) | |||
user_proto = mock.Mock() | |||
server.close() | |||
conn = aiosocks.create_connection( | |||
lambda: user_proto, addr, auth, dst, loop=loop_mock | |||
def test_auth_failed(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\x02\x01\x01') | |||
) | |||
fut = ensure_future(conn, loop=self.loop) | |||
self.loop.run_until_complete(fut) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('authentication failed', str(ct.exception)) | |||
server.close() | |||
def test_cmd_ver_err(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x04\x00\x00') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid version', str(ct.exception)) | |||
server.close() | |||
def test_cmd_not_granted(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x01\x00') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('General SOCKS server failure', str(ct.exception)) | |||
server.close() | |||
def test_invalid_address_type(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
with self.assertRaises(aiosocks.SocksError) as ct: | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
transport.close() | |||
self.assertIn('invalid data', str(ct.exception)) | |||
server.close() | |||
def test_atype_ipv4(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv( | |||
self.loop, | |||
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' | |||
) | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||
transport.close() | |||
server.close() | |||
def test_atype_ipv6(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv( | |||
self.loop, | |||
b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' | |||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W') | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, ('::111', 1111)) | |||
transport.close() | |||
server.close() | |||
def test_atype_domain(self): | |||
server, port = self.loop.run_until_complete( | |||
fake_socks_srv( | |||
self.loop, | |||
b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' | |||
) | |||
) | |||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||
dst = ('python.org', 80) | |||
coro = aiosocks.create_connection( | |||
None, addr, auth, dst, loop=self.loop) | |||
transport, protocol = self.loop.run_until_complete(coro) | |||
_, addr = protocol._negotiate_fut.result() | |||
self.assertEqual(addr, (b'python.org', 1111)) | |||
transport, protocol = fut.result() | |||
self.assertIs(transport, transp) | |||
self.assertIs(protocol, user_proto) | |||
self.assertIs(transport._protocol, user_proto) | |||
transport.close() | |||
server.close() |
@@ -13,13 +13,24 @@ except ImportError: | |||
ensure_future = asyncio.async | |||
def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||
def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None): | |||
dst = dst or ('python.org', 80) | |||
proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl, | |||
loop=loop, waiter=waiter, | |||
app_protocol_factory=ap_factory) | |||
return proto | |||
def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'', | |||
ap_factory=None, whiter=None): | |||
addr = addr or aiosocks.Socks4Addr('localhost', 1080) | |||
auth = auth or aiosocks.Socks4Auth('user') | |||
dst = dst or ('python.org', 80) | |||
proto = aiosocks.Socks4Protocol( | |||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||
loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||
proto._transport = mock.Mock() | |||
proto.read_response = mock.Mock( | |||
side_effect=coro(mock.Mock(return_value=r))) | |||
@@ -30,13 +41,15 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||
return proto | |||
def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None): | |||
def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, | |||
ap_factory=None, whiter=None): | |||
addr = addr or aiosocks.Socks5Addr('localhost', 1080) | |||
auth = auth or aiosocks.Socks5Auth('user', 'pwd') | |||
dst = dst or ('python.org', 80) | |||
proto = aiosocks.Socks5Protocol( | |||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||
loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||
proto._transport = mock.Mock() | |||
if not isinstance(r, (list, tuple)): | |||
@@ -63,17 +76,19 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||
def test_init(self): | |||
with self.assertRaises(ValueError): | |||
BaseSocksProtocol(None, None, None, loop=self.loop) | |||
BaseSocksProtocol(None, None, None, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
BaseSocksProtocol(None, None, 123, loop=self.loop) | |||
BaseSocksProtocol(None, None, 123, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop) | |||
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
def test_write_request(self): | |||
proto = BaseSocksProtocol(None, None, ('python.org', 80), | |||
loop=self.loop) | |||
proto = make_base(self.loop) | |||
proto._transport = mock.Mock() | |||
proto.write_request([b'\x00', b'\x01\x02', 0x03]) | |||
@@ -82,6 +97,180 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||
with self.assertRaises(ValueError): | |||
proto.write_request(['\x00']) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_connection_made_os_error(self, ef_mock): | |||
os_err_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = os_err_fut | |||
waiter = asyncio.Future(loop=self.loop) | |||
proto = make_base(self.loop, waiter=waiter) | |||
proto.connection_made(mock.Mock()) | |||
self.assertIs(proto._negotiate_fut, os_err_fut) | |||
with self.assertRaises(OSError): | |||
os_err_fut.set_exception(OSError('test')) | |||
self.loop.run_until_complete(os_err_fut) | |||
self.assertIn('test', str(waiter.exception())) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_connection_made_socks_err(self, ef_mock): | |||
socks_err_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = socks_err_fut | |||
waiter = asyncio.Future(loop=self.loop) | |||
proto = make_base(self.loop, waiter=waiter) | |||
proto.connection_made(mock.Mock()) | |||
self.assertIs(proto._negotiate_fut, socks_err_fut) | |||
with self.assertRaises(aiosocks.SocksError): | |||
socks_err_fut.set_exception(aiosocks.SocksError('test')) | |||
self.loop.run_until_complete(socks_err_fut) | |||
self.assertIn('Can not connect to', str(waiter.exception())) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_connection_made_without_app_proto(self, ef_mock): | |||
success_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = success_fut | |||
waiter = asyncio.Future(loop=self.loop) | |||
proto = make_base(self.loop, waiter=waiter) | |||
proto.connection_made(mock.Mock()) | |||
self.assertIs(proto._negotiate_fut, success_fut) | |||
success_fut.set_result(True) | |||
self.loop.run_until_complete(success_fut) | |||
self.assertTrue(waiter.done()) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_connection_made_with_app_proto(self, ef_mock): | |||
success_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = success_fut | |||
waiter = asyncio.Future(loop=self.loop) | |||
proto = make_base(self.loop, waiter=waiter, | |||
ap_factory=lambda: asyncio.Protocol()) | |||
proto.connection_made(mock.Mock()) | |||
self.assertIs(proto._negotiate_fut, success_fut) | |||
success_fut.set_result(True) | |||
self.loop.run_until_complete(success_fut) | |||
self.assertTrue(waiter.done()) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_connection_lost(self, ef_mock): | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
app_proto = mock.Mock() | |||
loop_mock = mock.Mock() | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
# negotiate not completed | |||
proto.connection_lost(True) | |||
self.assertFalse(loop_mock.call_soon.called) | |||
# negotiate successfully competed | |||
negotiate_fut.set_result(True) | |||
proto.connection_lost(True) | |||
self.assertTrue(loop_mock.call_soon.called) | |||
# negotiate failed | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
negotiate_fut.set_exception(Exception()) | |||
proto.connection_lost(True) | |||
self.assertTrue(loop_mock.call_soon.called) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_pause_writing(self, ef_mock): | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
app_proto = mock.Mock() | |||
loop_mock = mock.Mock() | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
# negotiate not completed | |||
proto.pause_writing() | |||
self.assertFalse(app_proto.pause_writing.called) | |||
# negotiate successfully competed | |||
negotiate_fut.set_result(True) | |||
proto.pause_writing() | |||
self.assertTrue(app_proto.pause_writing.called) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_resume_writing(self, ef_mock): | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
app_proto = mock.Mock() | |||
loop_mock = mock.Mock() | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
# negotiate not completed | |||
with self.assertRaises(AssertionError): | |||
proto.resume_writing() | |||
# negotiate fail | |||
negotiate_fut.set_exception(Exception()) | |||
proto.resume_writing() | |||
self.assertTrue(app_proto.resume_writing.called) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_data_received(self, ef_mock): | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
app_proto = mock.Mock() | |||
loop_mock = mock.Mock() | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
# negotiate not completed | |||
proto.data_received(b'123') | |||
self.assertFalse(app_proto.data_received.called) | |||
# negotiate successfully competed | |||
negotiate_fut.set_result(True) | |||
proto.data_received(b'123') | |||
self.assertTrue(app_proto.data_received.called) | |||
@mock.patch('aiosocks.protocols.ensure_future') | |||
def test_eof_received(self, ef_mock): | |||
negotiate_fut = asyncio.Future(loop=self.loop) | |||
ef_mock.return_value = negotiate_fut | |||
app_proto = mock.Mock() | |||
loop_mock = mock.Mock() | |||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||
proto.connection_made(mock.Mock()) | |||
# negotiate not completed | |||
proto.eof_received() | |||
self.assertFalse(app_proto.eof_received.called) | |||
# negotiate successfully competed | |||
negotiate_fut.set_result(True) | |||
proto.eof_received() | |||
self.assertTrue(app_proto.eof_received.called) | |||
class TestSocks4Protocol(unittest.TestCase): | |||
def setUp(self): | |||
@@ -97,21 +286,27 @@ class TestSocks4Protocol(unittest.TestCase): | |||
dst = ('python.org', 80) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop) | |||
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop) | |||
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, | |||
loop=self.loop) | |||
loop=self.loop, waiter=None, | |||
app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, | |||
loop=self.loop) | |||
loop=self.loop, waiter=None, | |||
app_protocol_factory=None) | |||
aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop) | |||
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop) | |||
aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
def test_request_building(self): | |||
resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' | |||
@@ -230,21 +425,27 @@ class TestSocks5Protocol(unittest.TestCase): | |||
dst = ('python.org', 80) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop) | |||
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop) | |||
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), | |||
auth, dst, loop=self.loop) | |||
auth, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
with self.assertRaises(ValueError): | |||
aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), | |||
dst, loop=self.loop) | |||
dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop) | |||
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop) | |||
aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop, | |||
waiter=None, app_protocol_factory=None) | |||
def test_authenticate(self): | |||
# invalid server version | |||