@@ -47,7 +47,15 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||||
"proxy is Socks5Addr but proxy_auth is not Socks5Auth" | "proxy is Socks5Addr but proxy_auth is not Socks5Auth" | ||||
) | ) | ||||
if server_hostname is not None and not ssl: | |||||
raise ValueError('server_hostname is only meaningful with ssl') | |||||
if server_hostname is None and ssl: | |||||
# read details: asyncio.create_connection | |||||
server_hostname = dst[0] | |||||
loop = loop or asyncio.get_event_loop() | loop = loop or asyncio.get_event_loop() | ||||
waiter = asyncio.Future(loop=loop) | |||||
def socks_factory(): | def socks_factory(): | ||||
if isinstance(proxy, Socks4Addr): | if isinstance(proxy, Socks4Addr): | ||||
@@ -55,30 +63,24 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||||
else: | else: | ||||
socks_proto = Socks5Protocol | socks_proto = Socks5Protocol | ||||
return socks_proto( | |||||
proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||||
remote_resolve=remote_resolve, loop=loop) | |||||
return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||||
app_protocol_factory=protocol_factory, | |||||
waiter=waiter, remote_resolve=remote_resolve, | |||||
loop=loop, ssl=ssl, server_hostname=server_hostname) | |||||
try: | try: | ||||
transport, protocol = yield from loop.create_connection( | transport, protocol = yield from loop.create_connection( | ||||
socks_factory, proxy.host, proxy.port, ssl=ssl, family=family, | |||||
proto=proto, flags=flags, sock=sock, local_addr=local_addr, | |||||
server_hostname=server_hostname) | |||||
socks_factory, proxy.host, proxy.port, family=family, | |||||
proto=proto, flags=flags, sock=sock, local_addr=local_addr) | |||||
except OSError as exc: | except OSError as exc: | ||||
raise SocksConnectionError( | raise SocksConnectionError( | ||||
'[Errno %s] Can not connect to proxy %s:%d [%s]' % | '[Errno %s] Can not connect to proxy %s:%d [%s]' % | ||||
(exc.errno, proxy.host, proxy.port, exc.strerror)) from exc | (exc.errno, proxy.host, proxy.port, exc.strerror)) from exc | ||||
# Wait until communication with proxy server is finished | |||||
try: | try: | ||||
yield from protocol.negotiate_done() | |||||
except SocksError as exc: | |||||
raise SocksError('Can not connect to %s:%s [%s]' % | |||||
(dst[0], dst[1], exc)) | |||||
if protocol_factory: | |||||
protocol = protocol_factory() | |||||
protocol.connection_made(transport) | |||||
transport._protocol = protocol | |||||
yield from waiter | |||||
except: | |||||
transport.close() | |||||
raise | |||||
return transport, protocol | |||||
return protocol.app_transport, protocol.app_protocol |
@@ -33,12 +33,16 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
@asyncio.coroutine | @asyncio.coroutine | ||||
def _create_connection(self, req): | def _create_connection(self, req): | ||||
if req.ssl: | |||||
sslcontext = self.ssl_context | |||||
else: | |||||
sslcontext = None | |||||
if not self._remote_resolve: | if not self._remote_resolve: | ||||
dst_hosts = yield from self._resolve_host(req.host, req.port) | dst_hosts = yield from self._resolve_host(req.host, req.port) | ||||
dst = dst_hosts[0]['host'], dst_hosts[0]['port'] | dst = dst_hosts[0]['host'], dst_hosts[0]['port'] | ||||
else: | else: | ||||
dst = req.host, req.port | dst = req.host, req.port | ||||
exc = None | |||||
# if self._resolver is AsyncResolver and self._proxy.host | # if self._resolver is AsyncResolver and self._proxy.host | ||||
# is ip address, then aiodns raise DNSError. | # is ip address, then aiodns raise DNSError. | ||||
@@ -56,6 +60,7 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
except ValueError: | except ValueError: | ||||
proxy_hosts = yield from self._resolve_host(self._proxy.host, | proxy_hosts = yield from self._resolve_host(self._proxy.host, | ||||
self._proxy.port) | self._proxy.port) | ||||
exc = None | |||||
for hinfo in proxy_hosts: | for hinfo in proxy_hosts: | ||||
try: | try: | ||||
@@ -65,8 +70,29 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
transp, proto = yield from create_connection( | transp, proto = yield from create_connection( | ||||
self._factory, proxy, self._proxy_auth, dst, | self._factory, proxy, self._proxy_auth, dst, | ||||
loop=self._loop, remote_resolve=self._remote_resolve, | loop=self._loop, remote_resolve=self._remote_resolve, | ||||
ssl=None, family=hinfo['family'], proto=hinfo['proto'], | |||||
flags=hinfo['flags'], local_addr=self._local_addr) | |||||
ssl=sslcontext, family=hinfo['family'], | |||||
proto=hinfo['proto'], flags=hinfo['flags'], | |||||
local_addr=self._local_addr, | |||||
server_hostname=req.host if sslcontext else None) | |||||
has_cert = transp.get_extra_info('sslcontext') | |||||
if has_cert and self._fingerprint: | |||||
sock = transp.get_extra_info('socket') | |||||
if not hasattr(sock, 'getpeercert'): | |||||
# Workaround for asyncio 3.5.0 | |||||
# Starting from 3.5.1 version | |||||
# there is 'ssl_object' extra info in transport | |||||
sock = transp._ssl_protocol._sslpipe.ssl_object | |||||
# gives DER-encoded cert as a sequence of bytes (or None) | |||||
cert = sock.getpeercert(binary_form=True) | |||||
assert cert | |||||
got = self._hashfunc(cert).digest() | |||||
expected = self._fingerprint | |||||
if got != expected: | |||||
transp.close() | |||||
raise aiohttp.FingerprintMismatch( | |||||
expected, got, req.host, 80 | |||||
) | |||||
return transp, proto | return transp, proto | ||||
except (OSError, SocksError, SocksConnectionError) as e: | except (OSError, SocksError, SocksConnectionError) as e: | ||||
@@ -17,7 +17,9 @@ except ImportError: | |||||
class BaseSocksProtocol(asyncio.StreamReaderProtocol): | class BaseSocksProtocol(asyncio.StreamReaderProtocol): | ||||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
remote_resolve=True, loop=None, ssl=False, | |||||
server_hostname=None): | |||||
if not isinstance(dst, (tuple, list)) or len(dst) != 2: | if not isinstance(dst, (tuple, list)) or len(dst) != 2: | ||||
raise ValueError( | raise ValueError( | ||||
'Invalid dst format, tuple("dst_host", dst_port))' | 'Invalid dst format, tuple("dst_host", dst_port))' | ||||
@@ -30,18 +32,79 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
self._loop = loop or asyncio.get_event_loop() | self._loop = loop or asyncio.get_event_loop() | ||||
self._transport = None | self._transport = None | ||||
self._negotiate_done = None | |||||
self._waiter = waiter | |||||
self._negotiate_fut = None | |||||
self._ssl = ssl | |||||
self._server_hostname = server_hostname | |||||
if app_protocol_factory: | |||||
self._app_protocol = app_protocol_factory() | |||||
else: | |||||
self._app_protocol = self | |||||
reader = asyncio.StreamReader(loop=self._loop) | reader = asyncio.StreamReader(loop=self._loop) | ||||
super().__init__(stream_reader=reader, loop=self._loop) | super().__init__(stream_reader=reader, loop=self._loop) | ||||
def connection_made(self, transport): | def connection_made(self, transport): | ||||
# connection_made is called | |||||
if self._transport: | |||||
return | |||||
super().connection_made(transport) | super().connection_made(transport) | ||||
self._transport = transport | self._transport = transport | ||||
def init_app_protocol(fut): | |||||
exc = fut.exception() | |||||
if exc: | |||||
if isinstance(exc, SocksError): | |||||
exc = SocksError('Can not connect to %s:%s. %s' % | |||||
(self._dst_host, self._dst_port, exc)) | |||||
self._waiter.set_exception(exc) | |||||
else: | |||||
if self._ssl: | |||||
sock = self._transport.get_extra_info('socket') | |||||
self._transport = self._loop._make_ssl_transport( | |||||
rawsock=sock, protocol=self._app_protocol, | |||||
sslcontext=self._ssl, server_side=False, | |||||
server_hostname=self._server_hostname, | |||||
waiter=self._waiter) | |||||
else: | |||||
self._app_protocol.connection_made(transport) | |||||
self._waiter.set_result(True) | |||||
req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | ||||
self._negotiate_done = ensure_future(req_coro, loop=self._loop) | |||||
self._negotiate_fut = ensure_future(req_coro, loop=self._loop) | |||||
self._negotiate_fut.add_done_callback(init_app_protocol) | |||||
def connection_lost(self, exc): | |||||
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
self._loop.call_soon(self._app_protocol.connection_lost, exc) | |||||
super().connection_lost(exc) | |||||
def pause_writing(self): | |||||
if self._negotiate_fut.done(): | |||||
self._app_protocol.pause_writing() | |||||
else: | |||||
super().pause_writing() | |||||
def resume_writing(self): | |||||
if self._negotiate_fut.done(): | |||||
self._app_protocol.resume_writing() | |||||
else: | |||||
super().resume_writing() | |||||
def data_received(self, data): | |||||
if self._negotiate_fut.done(): | |||||
self._app_protocol.data_received(data) | |||||
else: | |||||
super().data_received(data) | |||||
def eof_received(self): | |||||
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
self._app_protocol.eof_received() | |||||
super().eof_received() | |||||
@asyncio.coroutine | @asyncio.coroutine | ||||
def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
@@ -74,12 +137,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
raise OSError('getaddrinfo() returned empty list') | raise OSError('getaddrinfo() returned empty list') | ||||
return infos[0][0], infos[0][4][0] | return infos[0][0], infos[0][4][0] | ||||
def negotiate_done(self): | |||||
return self._negotiate_done | |||||
@property | |||||
def app_protocol(self): | |||||
return self._app_protocol | |||||
@property | |||||
def app_transport(self): | |||||
return self._transport | |||||
class Socks4Protocol(BaseSocksProtocol): | class Socks4Protocol(BaseSocksProtocol): | ||||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
remote_resolve=True, loop=None, ssl=False, | |||||
server_hostname=None): | |||||
proxy_auth = proxy_auth or Socks4Auth('') | proxy_auth = proxy_auth or Socks4Auth('') | ||||
if not isinstance(proxy, Socks4Addr): | if not isinstance(proxy, Socks4Addr): | ||||
@@ -88,7 +158,8 @@ class Socks4Protocol(BaseSocksProtocol): | |||||
if not isinstance(proxy_auth, Socks4Auth): | if not isinstance(proxy_auth, Socks4Auth): | ||||
raise ValueError('Invalid proxy_auth format') | raise ValueError('Invalid proxy_auth format') | ||||
super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||||
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
remote_resolve, loop, ssl, server_hostname) | |||||
@asyncio.coroutine | @asyncio.coroutine | ||||
def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
@@ -130,7 +201,9 @@ class Socks4Protocol(BaseSocksProtocol): | |||||
class Socks5Protocol(BaseSocksProtocol): | class Socks5Protocol(BaseSocksProtocol): | ||||
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
remote_resolve=True, loop=None, ssl=False, | |||||
server_hostname=None): | |||||
proxy_auth = proxy_auth or Socks5Auth('', '') | proxy_auth = proxy_auth or Socks5Auth('', '') | ||||
if not isinstance(proxy, Socks5Addr): | if not isinstance(proxy, Socks5Addr): | ||||
@@ -139,7 +212,8 @@ class Socks5Protocol(BaseSocksProtocol): | |||||
if not isinstance(proxy_auth, Socks5Auth): | if not isinstance(proxy_auth, Socks5Auth): | ||||
raise ValueError('Invalid proxy_auth format') | raise ValueError('Invalid proxy_auth format') | ||||
super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||||
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
remote_resolve, loop, ssl, server_hostname) | |||||
@asyncio.coroutine | @asyncio.coroutine | ||||
def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
@@ -0,0 +1,25 @@ | |||||
import asyncio | |||||
import socket | |||||
import functools | |||||
def find_unused_port(): | |||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||||
s.bind(('127.0.0.1', 0)) | |||||
port = s.getsockname()[1] | |||||
s.close() | |||||
return port | |||||
@asyncio.coroutine | |||||
def socks_handler(reader, writer, write_buff): | |||||
writer.write(write_buff) | |||||
@asyncio.coroutine | |||||
def fake_socks_srv(loop, write_buff): | |||||
port = find_unused_port() | |||||
handler = functools.partial(socks_handler, write_buff=write_buff) | |||||
srv = yield from asyncio.start_server( | |||||
handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop) | |||||
return srv, port |
@@ -24,7 +24,12 @@ class TestSocksConnector(unittest.TestCase): | |||||
return mock.Mock(side_effect=coroutine(coro)) | return mock.Mock(side_effect=coroutine(coro)) | ||||
def test_connect_proxy_ip(self): | |||||
@mock.patch('aiosocks.connector.create_connection') | |||||
def test_connect_proxy_ip(self, cr_conn_mock): | |||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
cr_conn_mock.side_effect = \ | |||||
self._fake_coroutine((tr, proto)).side_effect | |||||
loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
@@ -33,21 +38,18 @@ class TestSocksConnector(unittest.TestCase): | |||||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
proto.negotiate_done = self._fake_coroutine(True) | |||||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
self.assertTrue(loop_mock.getaddrinfo.is_called) | self.assertTrue(loop_mock.getaddrinfo.is_called) | ||||
self.assertIs(conn._transport, tr) | self.assertIs(conn._transport, tr) | ||||
self.assertTrue( | |||||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
) | |||||
conn.close() | conn.close() | ||||
def test_connect_proxy_domain(self): | |||||
@mock.patch('aiosocks.connector.create_connection') | |||||
def test_connect_proxy_domain(self, cr_conn_mock): | |||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
cr_conn_mock.side_effect = \ | |||||
self._fake_coroutine((tr, proto)).side_effect | |||||
loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
@@ -56,60 +58,53 @@ class TestSocksConnector(unittest.TestCase): | |||||
connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | ||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
proto.negotiate_done = self._fake_coroutine(True) | |||||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
self.assertTrue(connector._resolve_host.is_called) | self.assertTrue(connector._resolve_host.is_called) | ||||
self.assertEqual(connector._resolve_host.call_count, 1) | self.assertEqual(connector._resolve_host.call_count, 1) | ||||
self.assertIs(conn._transport, tr) | self.assertIs(conn._transport, tr) | ||||
self.assertTrue( | |||||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
) | |||||
conn.close() | conn.close() | ||||
def test_connect_locale_resolve(self): | |||||
loop_mock = mock.Mock() | |||||
@mock.patch('aiosocks.connector.create_connection') | |||||
def test_connect_locale_resolve(self, cr_conn_mock): | |||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
cr_conn_mock.side_effect = \ | |||||
self._fake_coroutine((tr, proto)).side_effect | |||||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), | connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), | ||||
None, loop=loop_mock, remote_resolve=False) | |||||
None, loop=self.loop, remote_resolve=False) | |||||
connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | ||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
proto.negotiate_done = self._fake_coroutine(True) | |||||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
self.assertTrue(connector._resolve_host.is_called) | self.assertTrue(connector._resolve_host.is_called) | ||||
self.assertEqual(connector._resolve_host.call_count, 2) | self.assertEqual(connector._resolve_host.call_count, 2) | ||||
self.assertIs(conn._transport, tr) | |||||
self.assertTrue( | |||||
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
) | |||||
conn.close() | conn.close() | ||||
def test_proxy_connect_fail(self): | |||||
@mock.patch('aiosocks.connector.create_connection') | |||||
def test_proxy_connect_fail(self, cr_conn_mock): | |||||
loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
cr_conn_mock.side_effect = \ | |||||
self._fake_coroutine(aiosocks.SocksConnectionError()).side_effect | |||||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | ||||
None, loop=loop_mock) | None, loop=loop_mock) | ||||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
loop_mock.create_connection = self._fake_coroutine(OSError()) | |||||
with self.assertRaises(aiohttp.ProxyConnectionError): | with self.assertRaises(aiohttp.ProxyConnectionError): | ||||
self.loop.run_until_complete(connector.connect(req)) | self.loop.run_until_complete(connector.connect(req)) | ||||
def test_proxy_negotiate_fail(self): | |||||
@mock.patch('aiosocks.connector.create_connection') | |||||
def test_proxy_negotiate_fail(self, cr_conn_mock): | |||||
loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
cr_conn_mock.side_effect = \ | |||||
self._fake_coroutine(aiosocks.SocksError()).side_effect | |||||
req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | ||||
@@ -117,9 +112,5 @@ class TestSocksConnector(unittest.TestCase): | |||||
loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||||
loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
with self.assertRaises(aiosocks.SocksError): | with self.assertRaises(aiosocks.SocksError): | ||||
self.loop.run_until_complete(connector.connect(req)) | self.loop.run_until_complete(connector.connect(req)) |
@@ -2,6 +2,7 @@ import unittest | |||||
import aiosocks | import aiosocks | ||||
import asyncio | import asyncio | ||||
from unittest import mock | from unittest import mock | ||||
from .socks_serv import fake_socks_srv | |||||
try: | try: | ||||
from asyncio import ensure_future | from asyncio import ensure_future | ||||
@@ -74,6 +75,15 @@ class TestCreateConnection(unittest.TestCase): | |||||
self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', | self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', | ||||
str(ct.exception)) | str(ct.exception)) | ||||
# test ssl, server_hostname | |||||
with self.assertRaises(ValueError) as ct: | |||||
conn = aiosocks.create_connection( | |||||
None, addr, auth, dst, server_hostname='python.org' | |||||
) | |||||
self.loop.run_until_complete(conn) | |||||
self.assertIn('server_hostname is only meaningful with ssl', | |||||
str(ct.exception)) | |||||
def test_connection_fail(self): | def test_connection_fail(self): | ||||
addr = aiosocks.Socks5Addr('localhost') | addr = aiosocks.Socks5Addr('localhost') | ||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | auth = aiosocks.Socks5Auth('usr', 'pwd') | ||||
@@ -88,45 +98,323 @@ class TestCreateConnection(unittest.TestCase): | |||||
) | ) | ||||
self.loop.run_until_complete(conn) | self.loop.run_until_complete(conn) | ||||
def test_negotiate_fail(self): | |||||
addr = aiosocks.Socks5Addr('localhost') | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
class TestCreateSocks4Connection(unittest.TestCase): | |||||
def setUp(self): | |||||
self.loop = asyncio.new_event_loop() | |||||
asyncio.set_event_loop(None) | |||||
def tearDown(self): | |||||
self.loop.close() | |||||
def test_connect_success(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x00\x5a\x04W\x01\x01\x01\x01test') | |||||
) | |||||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks4Auth('usr') | |||||
dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
transp, proto = mock.Mock(), mock.Mock() | |||||
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
loop_mock = mock.Mock() | |||||
loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
self.assertEqual(data, b'test') | |||||
server.close() | |||||
transport.close() | |||||
def test_invalid_ver(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x01\x5a\x04W\x01\x01\x01\x01') | |||||
) | |||||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks4Auth('usr') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | with self.assertRaises(aiosocks.SocksError) as ct: | ||||
conn = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=loop_mock | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid data', str(ct.exception)) | |||||
server.close() | |||||
def test_access_not_granted(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x00\x5b\x04W\x01\x01\x01\x01') | |||||
) | |||||
addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks4Auth('usr') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('0x5b', str(ct.exception)) | |||||
server.close() | |||||
class TestCreateSocks5Connect(unittest.TestCase): | |||||
def setUp(self): | |||||
self.loop = asyncio.new_event_loop() | |||||
asyncio.set_event_loop(None) | |||||
def tearDown(self): | |||||
self.loop.close() | |||||
def test_connect_success_anonymous(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv( | |||||
self.loop, | |||||
b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||||
) | ) | ||||
self.loop.run_until_complete(conn) | |||||
self.assertIn('Can not connect to python.org:80', | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
self.assertEqual(data, b'test') | |||||
server.close() | |||||
transport.close() | |||||
def test_connect_success_usr_pwd(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv( | |||||
self.loop, | |||||
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||||
) | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
self.assertEqual(data, b'test') | |||||
server.close() | |||||
transport.close() | |||||
def test_auth_ver_err(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x04\x02') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid version', str(ct.exception)) | |||||
server.close() | |||||
def test_auth_method_rejected(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\xFF') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('authentication methods were rejected', | |||||
str(ct.exception)) | str(ct.exception)) | ||||
def test_create_protocol(self): | |||||
addr = aiosocks.Socks5Addr('localhost') | |||||
server.close() | |||||
def test_auth_status_invalid(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\xF0') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | auth = aiosocks.Socks5Auth('usr', 'pwd') | ||||
dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
transp, proto = mock.Mock(), mock.Mock() | |||||
proto.negotiate_done = self._fake_coroutine(True) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid data', str(ct.exception)) | |||||
loop_mock = mock.Mock() | |||||
loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||||
server.close() | |||||
def test_auth_status_invalid2(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\x02\x02\x00') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid data', str(ct.exception)) | |||||
user_proto = mock.Mock() | |||||
server.close() | |||||
conn = aiosocks.create_connection( | |||||
lambda: user_proto, addr, auth, dst, loop=loop_mock | |||||
def test_auth_failed(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\x02\x01\x01') | |||||
) | ) | ||||
fut = ensure_future(conn, loop=self.loop) | |||||
self.loop.run_until_complete(fut) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('authentication failed', str(ct.exception)) | |||||
server.close() | |||||
def test_cmd_ver_err(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x04\x00\x00') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid version', str(ct.exception)) | |||||
server.close() | |||||
def test_cmd_not_granted(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x01\x00') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('General SOCKS server failure', str(ct.exception)) | |||||
server.close() | |||||
def test_invalid_address_type(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
with self.assertRaises(aiosocks.SocksError) as ct: | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
transport.close() | |||||
self.assertIn('invalid data', str(ct.exception)) | |||||
server.close() | |||||
def test_atype_ipv4(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv( | |||||
self.loop, | |||||
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' | |||||
) | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
transport.close() | |||||
server.close() | |||||
def test_atype_ipv6(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv( | |||||
self.loop, | |||||
b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' | |||||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W') | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, ('::111', 1111)) | |||||
transport.close() | |||||
server.close() | |||||
def test_atype_domain(self): | |||||
server, port = self.loop.run_until_complete( | |||||
fake_socks_srv( | |||||
self.loop, | |||||
b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' | |||||
) | |||||
) | |||||
addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
dst = ('python.org', 80) | |||||
coro = aiosocks.create_connection( | |||||
None, addr, auth, dst, loop=self.loop) | |||||
transport, protocol = self.loop.run_until_complete(coro) | |||||
_, addr = protocol._negotiate_fut.result() | |||||
self.assertEqual(addr, (b'python.org', 1111)) | |||||
transport, protocol = fut.result() | |||||
self.assertIs(transport, transp) | |||||
self.assertIs(protocol, user_proto) | |||||
self.assertIs(transport._protocol, user_proto) | |||||
transport.close() | |||||
server.close() |
@@ -13,13 +13,24 @@ except ImportError: | |||||
ensure_future = asyncio.async | ensure_future = asyncio.async | ||||
def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||||
def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None): | |||||
dst = dst or ('python.org', 80) | |||||
proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl, | |||||
loop=loop, waiter=waiter, | |||||
app_protocol_factory=ap_factory) | |||||
return proto | |||||
def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'', | |||||
ap_factory=None, whiter=None): | |||||
addr = addr or aiosocks.Socks4Addr('localhost', 1080) | addr = addr or aiosocks.Socks4Addr('localhost', 1080) | ||||
auth = auth or aiosocks.Socks4Auth('user') | auth = auth or aiosocks.Socks4Auth('user') | ||||
dst = dst or ('python.org', 80) | dst = dst or ('python.org', 80) | ||||
proto = aiosocks.Socks4Protocol( | proto = aiosocks.Socks4Protocol( | ||||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||||
loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||||
proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
proto.read_response = mock.Mock( | proto.read_response = mock.Mock( | ||||
side_effect=coro(mock.Mock(return_value=r))) | side_effect=coro(mock.Mock(return_value=r))) | ||||
@@ -30,13 +41,15 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||||
return proto | return proto | ||||
def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None): | |||||
def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, | |||||
ap_factory=None, whiter=None): | |||||
addr = addr or aiosocks.Socks5Addr('localhost', 1080) | addr = addr or aiosocks.Socks5Addr('localhost', 1080) | ||||
auth = auth or aiosocks.Socks5Auth('user', 'pwd') | auth = auth or aiosocks.Socks5Auth('user', 'pwd') | ||||
dst = dst or ('python.org', 80) | dst = dst or ('python.org', 80) | ||||
proto = aiosocks.Socks5Protocol( | proto = aiosocks.Socks5Protocol( | ||||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||||
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||||
loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||||
proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
if not isinstance(r, (list, tuple)): | if not isinstance(r, (list, tuple)): | ||||
@@ -63,17 +76,19 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||||
def test_init(self): | def test_init(self): | ||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
BaseSocksProtocol(None, None, None, loop=self.loop) | |||||
BaseSocksProtocol(None, None, None, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
BaseSocksProtocol(None, None, 123, loop=self.loop) | |||||
BaseSocksProtocol(None, None, 123, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop) | |||||
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
def test_write_request(self): | def test_write_request(self): | ||||
proto = BaseSocksProtocol(None, None, ('python.org', 80), | |||||
loop=self.loop) | |||||
proto = make_base(self.loop) | |||||
proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
proto.write_request([b'\x00', b'\x01\x02', 0x03]) | proto.write_request([b'\x00', b'\x01\x02', 0x03]) | ||||
@@ -82,6 +97,180 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
proto.write_request(['\x00']) | proto.write_request(['\x00']) | ||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_connection_made_os_error(self, ef_mock): | |||||
os_err_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = os_err_fut | |||||
waiter = asyncio.Future(loop=self.loop) | |||||
proto = make_base(self.loop, waiter=waiter) | |||||
proto.connection_made(mock.Mock()) | |||||
self.assertIs(proto._negotiate_fut, os_err_fut) | |||||
with self.assertRaises(OSError): | |||||
os_err_fut.set_exception(OSError('test')) | |||||
self.loop.run_until_complete(os_err_fut) | |||||
self.assertIn('test', str(waiter.exception())) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_connection_made_socks_err(self, ef_mock): | |||||
socks_err_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = socks_err_fut | |||||
waiter = asyncio.Future(loop=self.loop) | |||||
proto = make_base(self.loop, waiter=waiter) | |||||
proto.connection_made(mock.Mock()) | |||||
self.assertIs(proto._negotiate_fut, socks_err_fut) | |||||
with self.assertRaises(aiosocks.SocksError): | |||||
socks_err_fut.set_exception(aiosocks.SocksError('test')) | |||||
self.loop.run_until_complete(socks_err_fut) | |||||
self.assertIn('Can not connect to', str(waiter.exception())) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_connection_made_without_app_proto(self, ef_mock): | |||||
success_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = success_fut | |||||
waiter = asyncio.Future(loop=self.loop) | |||||
proto = make_base(self.loop, waiter=waiter) | |||||
proto.connection_made(mock.Mock()) | |||||
self.assertIs(proto._negotiate_fut, success_fut) | |||||
success_fut.set_result(True) | |||||
self.loop.run_until_complete(success_fut) | |||||
self.assertTrue(waiter.done()) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_connection_made_with_app_proto(self, ef_mock): | |||||
success_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = success_fut | |||||
waiter = asyncio.Future(loop=self.loop) | |||||
proto = make_base(self.loop, waiter=waiter, | |||||
ap_factory=lambda: asyncio.Protocol()) | |||||
proto.connection_made(mock.Mock()) | |||||
self.assertIs(proto._negotiate_fut, success_fut) | |||||
success_fut.set_result(True) | |||||
self.loop.run_until_complete(success_fut) | |||||
self.assertTrue(waiter.done()) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_connection_lost(self, ef_mock): | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
app_proto = mock.Mock() | |||||
loop_mock = mock.Mock() | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
# negotiate not completed | |||||
proto.connection_lost(True) | |||||
self.assertFalse(loop_mock.call_soon.called) | |||||
# negotiate successfully competed | |||||
negotiate_fut.set_result(True) | |||||
proto.connection_lost(True) | |||||
self.assertTrue(loop_mock.call_soon.called) | |||||
# negotiate failed | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
negotiate_fut.set_exception(Exception()) | |||||
proto.connection_lost(True) | |||||
self.assertTrue(loop_mock.call_soon.called) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_pause_writing(self, ef_mock): | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
app_proto = mock.Mock() | |||||
loop_mock = mock.Mock() | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
# negotiate not completed | |||||
proto.pause_writing() | |||||
self.assertFalse(app_proto.pause_writing.called) | |||||
# negotiate successfully competed | |||||
negotiate_fut.set_result(True) | |||||
proto.pause_writing() | |||||
self.assertTrue(app_proto.pause_writing.called) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_resume_writing(self, ef_mock): | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
app_proto = mock.Mock() | |||||
loop_mock = mock.Mock() | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
# negotiate not completed | |||||
with self.assertRaises(AssertionError): | |||||
proto.resume_writing() | |||||
# negotiate fail | |||||
negotiate_fut.set_exception(Exception()) | |||||
proto.resume_writing() | |||||
self.assertTrue(app_proto.resume_writing.called) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_data_received(self, ef_mock): | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
app_proto = mock.Mock() | |||||
loop_mock = mock.Mock() | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
# negotiate not completed | |||||
proto.data_received(b'123') | |||||
self.assertFalse(app_proto.data_received.called) | |||||
# negotiate successfully competed | |||||
negotiate_fut.set_result(True) | |||||
proto.data_received(b'123') | |||||
self.assertTrue(app_proto.data_received.called) | |||||
@mock.patch('aiosocks.protocols.ensure_future') | |||||
def test_eof_received(self, ef_mock): | |||||
negotiate_fut = asyncio.Future(loop=self.loop) | |||||
ef_mock.return_value = negotiate_fut | |||||
app_proto = mock.Mock() | |||||
loop_mock = mock.Mock() | |||||
proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
proto.connection_made(mock.Mock()) | |||||
# negotiate not completed | |||||
proto.eof_received() | |||||
self.assertFalse(app_proto.eof_received.called) | |||||
# negotiate successfully competed | |||||
negotiate_fut.set_result(True) | |||||
proto.eof_received() | |||||
self.assertTrue(app_proto.eof_received.called) | |||||
class TestSocks4Protocol(unittest.TestCase): | class TestSocks4Protocol(unittest.TestCase): | ||||
def setUp(self): | def setUp(self): | ||||
@@ -97,21 +286,27 @@ class TestSocks4Protocol(unittest.TestCase): | |||||
dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop) | |||||
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop) | |||||
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, | aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, | ||||
loop=self.loop) | |||||
loop=self.loop, waiter=None, | |||||
app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, | aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, | ||||
loop=self.loop) | |||||
loop=self.loop, waiter=None, | |||||
app_protocol_factory=None) | |||||
aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop) | |||||
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop) | |||||
aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
def test_request_building(self): | def test_request_building(self): | ||||
resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' | resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' | ||||
@@ -230,21 +425,27 @@ class TestSocks5Protocol(unittest.TestCase): | |||||
dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop) | |||||
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop) | |||||
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), | aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), | ||||
auth, dst, loop=self.loop) | |||||
auth, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), | aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), | ||||
dst, loop=self.loop) | |||||
dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop) | |||||
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop) | |||||
aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop, | |||||
waiter=None, app_protocol_factory=None) | |||||
def test_authenticate(self): | def test_authenticate(self): | ||||
# invalid server version | # invalid server version | ||||