diff --git a/aiosocks/__init__.py b/aiosocks/__init__.py index ec49af0..9aa27a5 100644 --- a/aiosocks/__init__.py +++ b/aiosocks/__init__.py @@ -47,7 +47,15 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, "proxy is Socks5Addr but proxy_auth is not Socks5Auth" ) + if server_hostname is not None and not ssl: + raise ValueError('server_hostname is only meaningful with ssl') + + if server_hostname is None and ssl: + # read details: asyncio.create_connection + server_hostname = dst[0] + loop = loop or asyncio.get_event_loop() + waiter = asyncio.Future(loop=loop) def socks_factory(): if isinstance(proxy, Socks4Addr): @@ -55,30 +63,24 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, else: socks_proto = Socks5Protocol - return socks_proto( - proxy=proxy, proxy_auth=proxy_auth, dst=dst, - remote_resolve=remote_resolve, loop=loop) + return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst, + app_protocol_factory=protocol_factory, + waiter=waiter, remote_resolve=remote_resolve, + loop=loop, ssl=ssl, server_hostname=server_hostname) try: transport, protocol = yield from loop.create_connection( - socks_factory, proxy.host, proxy.port, ssl=ssl, family=family, - proto=proto, flags=flags, sock=sock, local_addr=local_addr, - server_hostname=server_hostname) + socks_factory, proxy.host, proxy.port, family=family, + proto=proto, flags=flags, sock=sock, local_addr=local_addr) except OSError as exc: raise SocksConnectionError( '[Errno %s] Can not connect to proxy %s:%d [%s]' % (exc.errno, proxy.host, proxy.port, exc.strerror)) from exc - # Wait until communication with proxy server is finished try: - yield from protocol.negotiate_done() - except SocksError as exc: - raise SocksError('Can not connect to %s:%s [%s]' % - (dst[0], dst[1], exc)) - - if protocol_factory: - protocol = protocol_factory() - protocol.connection_made(transport) - transport._protocol = protocol + yield from waiter + except: + transport.close() + raise - return transport, protocol + return protocol.app_transport, protocol.app_protocol diff --git a/aiosocks/connector.py b/aiosocks/connector.py index 0a72d4b..502fb10 100644 --- a/aiosocks/connector.py +++ b/aiosocks/connector.py @@ -33,12 +33,16 @@ class SocksConnector(aiohttp.TCPConnector): @asyncio.coroutine def _create_connection(self, req): + if req.ssl: + sslcontext = self.ssl_context + else: + sslcontext = None + if not self._remote_resolve: dst_hosts = yield from self._resolve_host(req.host, req.port) dst = dst_hosts[0]['host'], dst_hosts[0]['port'] else: dst = req.host, req.port - exc = None # if self._resolver is AsyncResolver and self._proxy.host # is ip address, then aiodns raise DNSError. @@ -56,6 +60,7 @@ class SocksConnector(aiohttp.TCPConnector): except ValueError: proxy_hosts = yield from self._resolve_host(self._proxy.host, self._proxy.port) + exc = None for hinfo in proxy_hosts: try: @@ -65,8 +70,29 @@ class SocksConnector(aiohttp.TCPConnector): transp, proto = yield from create_connection( self._factory, proxy, self._proxy_auth, dst, loop=self._loop, remote_resolve=self._remote_resolve, - ssl=None, family=hinfo['family'], proto=hinfo['proto'], - flags=hinfo['flags'], local_addr=self._local_addr) + ssl=sslcontext, family=hinfo['family'], + proto=hinfo['proto'], flags=hinfo['flags'], + local_addr=self._local_addr, + server_hostname=req.host if sslcontext else None) + + has_cert = transp.get_extra_info('sslcontext') + if has_cert and self._fingerprint: + sock = transp.get_extra_info('socket') + if not hasattr(sock, 'getpeercert'): + # Workaround for asyncio 3.5.0 + # Starting from 3.5.1 version + # there is 'ssl_object' extra info in transport + sock = transp._ssl_protocol._sslpipe.ssl_object + # gives DER-encoded cert as a sequence of bytes (or None) + cert = sock.getpeercert(binary_form=True) + assert cert + got = self._hashfunc(cert).digest() + expected = self._fingerprint + if got != expected: + transp.close() + raise aiohttp.FingerprintMismatch( + expected, got, req.host, 80 + ) return transp, proto except (OSError, SocksError, SocksConnectionError) as e: diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index 72a0434..bd7cbcf 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -17,7 +17,9 @@ except ImportError: class BaseSocksProtocol(asyncio.StreamReaderProtocol): - def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): + def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, + remote_resolve=True, loop=None, ssl=False, + server_hostname=None): if not isinstance(dst, (tuple, list)) or len(dst) != 2: raise ValueError( 'Invalid dst format, tuple("dst_host", dst_port))' @@ -30,18 +32,79 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): self._loop = loop or asyncio.get_event_loop() self._transport = None - self._negotiate_done = None + self._waiter = waiter + self._negotiate_fut = None + self._ssl = ssl + self._server_hostname = server_hostname + + if app_protocol_factory: + self._app_protocol = app_protocol_factory() + else: + self._app_protocol = self reader = asyncio.StreamReader(loop=self._loop) super().__init__(stream_reader=reader, loop=self._loop) def connection_made(self, transport): + # connection_made is called + if self._transport: + return + super().connection_made(transport) self._transport = transport + def init_app_protocol(fut): + exc = fut.exception() + if exc: + if isinstance(exc, SocksError): + exc = SocksError('Can not connect to %s:%s. %s' % + (self._dst_host, self._dst_port, exc)) + self._waiter.set_exception(exc) + else: + if self._ssl: + sock = self._transport.get_extra_info('socket') + + self._transport = self._loop._make_ssl_transport( + rawsock=sock, protocol=self._app_protocol, + sslcontext=self._ssl, server_side=False, + server_hostname=self._server_hostname, + waiter=self._waiter) + else: + self._app_protocol.connection_made(transport) + self._waiter.set_result(True) + req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) - self._negotiate_done = ensure_future(req_coro, loop=self._loop) + self._negotiate_fut = ensure_future(req_coro, loop=self._loop) + self._negotiate_fut.add_done_callback(init_app_protocol) + + def connection_lost(self, exc): + if self._negotiate_fut.done() and not self._negotiate_fut.exception(): + self._loop.call_soon(self._app_protocol.connection_lost, exc) + super().connection_lost(exc) + + def pause_writing(self): + if self._negotiate_fut.done(): + self._app_protocol.pause_writing() + else: + super().pause_writing() + + def resume_writing(self): + if self._negotiate_fut.done(): + self._app_protocol.resume_writing() + else: + super().resume_writing() + + def data_received(self, data): + if self._negotiate_fut.done(): + self._app_protocol.data_received(data) + else: + super().data_received(data) + + def eof_received(self): + if self._negotiate_fut.done() and not self._negotiate_fut.exception(): + self._app_protocol.eof_received() + super().eof_received() @asyncio.coroutine def socks_request(self, cmd): @@ -74,12 +137,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): raise OSError('getaddrinfo() returned empty list') return infos[0][0], infos[0][4][0] - def negotiate_done(self): - return self._negotiate_done + @property + def app_protocol(self): + return self._app_protocol + + @property + def app_transport(self): + return self._transport class Socks4Protocol(BaseSocksProtocol): - def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): + def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, + remote_resolve=True, loop=None, ssl=False, + server_hostname=None): proxy_auth = proxy_auth or Socks4Auth('') if not isinstance(proxy, Socks4Addr): @@ -88,7 +158,8 @@ class Socks4Protocol(BaseSocksProtocol): if not isinstance(proxy_auth, Socks4Auth): raise ValueError('Invalid proxy_auth format') - super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) + super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, + remote_resolve, loop, ssl, server_hostname) @asyncio.coroutine def socks_request(self, cmd): @@ -130,7 +201,9 @@ class Socks4Protocol(BaseSocksProtocol): class Socks5Protocol(BaseSocksProtocol): - def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): + def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, + remote_resolve=True, loop=None, ssl=False, + server_hostname=None): proxy_auth = proxy_auth or Socks5Auth('', '') if not isinstance(proxy, Socks5Addr): @@ -139,7 +212,8 @@ class Socks5Protocol(BaseSocksProtocol): if not isinstance(proxy_auth, Socks5Auth): raise ValueError('Invalid proxy_auth format') - super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) + super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, + remote_resolve, loop, ssl, server_hostname) @asyncio.coroutine def socks_request(self, cmd): diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/socks_serv.py b/tests/socks_serv.py new file mode 100644 index 0000000..87628db --- /dev/null +++ b/tests/socks_serv.py @@ -0,0 +1,25 @@ +import asyncio +import socket +import functools + + +def find_unused_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(('127.0.0.1', 0)) + port = s.getsockname()[1] + s.close() + return port + + +@asyncio.coroutine +def socks_handler(reader, writer, write_buff): + writer.write(write_buff) + + +@asyncio.coroutine +def fake_socks_srv(loop, write_buff): + port = find_unused_port() + handler = functools.partial(socks_handler, write_buff=write_buff) + srv = yield from asyncio.start_server( + handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop) + return srv, port diff --git a/tests/test_connector.py b/tests/test_connector.py index c16bec1..e614b87 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -24,7 +24,12 @@ class TestSocksConnector(unittest.TestCase): return mock.Mock(side_effect=coroutine(coro)) - def test_connect_proxy_ip(self): + @mock.patch('aiosocks.connector.create_connection') + def test_connect_proxy_ip(self, cr_conn_mock): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + cr_conn_mock.side_effect = \ + self._fake_coroutine((tr, proto)).side_effect + loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) @@ -33,21 +38,18 @@ class TestSocksConnector(unittest.TestCase): loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - proto.negotiate_done = self._fake_coroutine(True) - loop_mock.create_connection = self._fake_coroutine((tr, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(loop_mock.getaddrinfo.is_called) self.assertIs(conn._transport, tr) - self.assertTrue( - isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) - ) conn.close() - def test_connect_proxy_domain(self): + @mock.patch('aiosocks.connector.create_connection') + def test_connect_proxy_domain(self, cr_conn_mock): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + cr_conn_mock.side_effect = \ + self._fake_coroutine((tr, proto)).side_effect loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) @@ -56,60 +58,53 @@ class TestSocksConnector(unittest.TestCase): connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - proto.negotiate_done = self._fake_coroutine(True) - loop_mock.create_connection = self._fake_coroutine((tr, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(connector._resolve_host.is_called) self.assertEqual(connector._resolve_host.call_count, 1) self.assertIs(conn._transport, tr) - self.assertTrue( - isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) - ) conn.close() - def test_connect_locale_resolve(self): - loop_mock = mock.Mock() + @mock.patch('aiosocks.connector.create_connection') + def test_connect_locale_resolve(self, cr_conn_mock): + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + cr_conn_mock.side_effect = \ + self._fake_coroutine((tr, proto)).side_effect req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), - None, loop=loop_mock, remote_resolve=False) + None, loop=self.loop, remote_resolve=False) connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - proto.negotiate_done = self._fake_coroutine(True) - loop_mock.create_connection = self._fake_coroutine((tr, proto)) - conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(connector._resolve_host.is_called) self.assertEqual(connector._resolve_host.call_count, 2) - self.assertIs(conn._transport, tr) - self.assertTrue( - isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) - ) conn.close() - def test_proxy_connect_fail(self): + @mock.patch('aiosocks.connector.create_connection') + def test_proxy_connect_fail(self, cr_conn_mock): loop_mock = mock.Mock() + cr_conn_mock.side_effect = \ + self._fake_coroutine(aiosocks.SocksConnectionError()).side_effect req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), None, loop=loop_mock) loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) - loop_mock.create_connection = self._fake_coroutine(OSError()) with self.assertRaises(aiohttp.ProxyConnectionError): self.loop.run_until_complete(connector.connect(req)) - def test_proxy_negotiate_fail(self): + @mock.patch('aiosocks.connector.create_connection') + def test_proxy_negotiate_fail(self, cr_conn_mock): loop_mock = mock.Mock() + cr_conn_mock.side_effect = \ + self._fake_coroutine(aiosocks.SocksError()).side_effect req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), @@ -117,9 +112,5 @@ class TestSocksConnector(unittest.TestCase): loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) - tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') - proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) - loop_mock.create_connection = self._fake_coroutine((tr, proto)) - with self.assertRaises(aiosocks.SocksError): self.loop.run_until_complete(connector.connect(req)) diff --git a/tests/test_create_connect.py b/tests/test_create_connect.py index fb390e3..e8aca55 100644 --- a/tests/test_create_connect.py +++ b/tests/test_create_connect.py @@ -2,6 +2,7 @@ import unittest import aiosocks import asyncio from unittest import mock +from .socks_serv import fake_socks_srv try: from asyncio import ensure_future @@ -74,6 +75,15 @@ class TestCreateConnection(unittest.TestCase): self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', str(ct.exception)) + # test ssl, server_hostname + with self.assertRaises(ValueError) as ct: + conn = aiosocks.create_connection( + None, addr, auth, dst, server_hostname='python.org' + ) + self.loop.run_until_complete(conn) + self.assertIn('server_hostname is only meaningful with ssl', + str(ct.exception)) + def test_connection_fail(self): addr = aiosocks.Socks5Addr('localhost') auth = aiosocks.Socks5Auth('usr', 'pwd') @@ -88,45 +98,323 @@ class TestCreateConnection(unittest.TestCase): ) self.loop.run_until_complete(conn) - def test_negotiate_fail(self): - addr = aiosocks.Socks5Addr('localhost') - auth = aiosocks.Socks5Auth('usr', 'pwd') + +class TestCreateSocks4Connection(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_connect_success(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x00\x5a\x04W\x01\x01\x01\x01test') + ) + addr = aiosocks.Socks4Addr('127.0.0.1', port) + auth = aiosocks.Socks4Auth('usr') dst = ('python.org', 80) - transp, proto = mock.Mock(), mock.Mock() - proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) - loop_mock = mock.Mock() - loop_mock.create_connection = self._fake_coroutine((transp, proto)) + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, ('1.1.1.1', 1111)) + + data = self.loop.run_until_complete(protocol._stream_reader.read(4)) + self.assertEqual(data, b'test') + + server.close() + transport.close() + + def test_invalid_ver(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x01\x5a\x04W\x01\x01\x01\x01') + ) + addr = aiosocks.Socks4Addr('127.0.0.1', port) + auth = aiosocks.Socks4Auth('usr') + dst = ('python.org', 80) with self.assertRaises(aiosocks.SocksError) as ct: - conn = aiosocks.create_connection( - None, addr, auth, dst, loop=loop_mock + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid data', str(ct.exception)) + + server.close() + + def test_access_not_granted(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x00\x5b\x04W\x01\x01\x01\x01') + ) + addr = aiosocks.Socks4Addr('127.0.0.1', port) + auth = aiosocks.Socks4Auth('usr') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('0x5b', str(ct.exception)) + + server.close() + + +class TestCreateSocks5Connect(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_connect_success_anonymous(self): + server, port = self.loop.run_until_complete( + fake_socks_srv( + self.loop, + b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' ) - self.loop.run_until_complete(conn) - self.assertIn('Can not connect to python.org:80', + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, ('1.1.1.1', 1111)) + + data = self.loop.run_until_complete(protocol._stream_reader.read(4)) + self.assertEqual(data, b'test') + + server.close() + transport.close() + + def test_connect_success_usr_pwd(self): + server, port = self.loop.run_until_complete( + fake_socks_srv( + self.loop, + b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' + ) + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, ('1.1.1.1', 1111)) + + data = self.loop.run_until_complete(protocol._stream_reader.read(4)) + self.assertEqual(data, b'test') + + server.close() + transport.close() + + def test_auth_ver_err(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x04\x02') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid version', str(ct.exception)) + + server.close() + + def test_auth_method_rejected(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\xFF') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('authentication methods were rejected', str(ct.exception)) - def test_create_protocol(self): - addr = aiosocks.Socks5Addr('localhost') + server.close() + + def test_auth_status_invalid(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\xF0') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) auth = aiosocks.Socks5Auth('usr', 'pwd') dst = ('python.org', 80) - transp, proto = mock.Mock(), mock.Mock() - proto.negotiate_done = self._fake_coroutine(True) + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid data', str(ct.exception)) - loop_mock = mock.Mock() - loop_mock.create_connection = self._fake_coroutine((transp, proto)) + server.close() + + def test_auth_status_invalid2(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\x02\x02\x00') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid data', str(ct.exception)) - user_proto = mock.Mock() + server.close() - conn = aiosocks.create_connection( - lambda: user_proto, addr, auth, dst, loop=loop_mock + def test_auth_failed(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\x02\x01\x01') ) - fut = ensure_future(conn, loop=self.loop) - self.loop.run_until_complete(fut) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('authentication failed', str(ct.exception)) + + server.close() + + def test_cmd_ver_err(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x04\x00\x00') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid version', str(ct.exception)) + + server.close() + + def test_cmd_not_granted(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x01\x00') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('General SOCKS server failure', str(ct.exception)) + + server.close() + + def test_invalid_address_type(self): + server, port = self.loop.run_until_complete( + fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + with self.assertRaises(aiosocks.SocksError) as ct: + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + transport.close() + self.assertIn('invalid data', str(ct.exception)) + + server.close() + + def test_atype_ipv4(self): + server, port = self.loop.run_until_complete( + fake_socks_srv( + self.loop, + b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' + ) + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, ('1.1.1.1', 1111)) + + transport.close() + server.close() + + def test_atype_ipv6(self): + server, port = self.loop.run_until_complete( + fake_socks_srv( + self.loop, + b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W') + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, ('::111', 1111)) + + transport.close() + server.close() + + def test_atype_domain(self): + server, port = self.loop.run_until_complete( + fake_socks_srv( + self.loop, + b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' + ) + ) + addr = aiosocks.Socks5Addr('127.0.0.1', port) + auth = aiosocks.Socks5Auth('usr', 'pwd') + dst = ('python.org', 80) + + coro = aiosocks.create_connection( + None, addr, auth, dst, loop=self.loop) + transport, protocol = self.loop.run_until_complete(coro) + + _, addr = protocol._negotiate_fut.result() + self.assertEqual(addr, (b'python.org', 1111)) - transport, protocol = fut.result() - self.assertIs(transport, transp) - self.assertIs(protocol, user_proto) - self.assertIs(transport._protocol, user_proto) + transport.close() + server.close() diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 31e3f4c..66ce0f5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -13,13 +13,24 @@ except ImportError: ensure_future = asyncio.async -def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): +def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None): + dst = dst or ('python.org', 80) + + proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl, + loop=loop, waiter=waiter, + app_protocol_factory=ap_factory) + return proto + + +def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'', + ap_factory=None, whiter=None): addr = addr or aiosocks.Socks4Addr('localhost', 1080) auth = auth or aiosocks.Socks4Auth('user') dst = dst or ('python.org', 80) proto = aiosocks.Socks4Protocol( - proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) + proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, + loop=loop, app_protocol_factory=ap_factory, waiter=whiter) proto._transport = mock.Mock() proto.read_response = mock.Mock( side_effect=coro(mock.Mock(return_value=r))) @@ -30,13 +41,15 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): return proto -def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None): +def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, + ap_factory=None, whiter=None): addr = addr or aiosocks.Socks5Addr('localhost', 1080) auth = auth or aiosocks.Socks5Auth('user', 'pwd') dst = dst or ('python.org', 80) proto = aiosocks.Socks5Protocol( - proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) + proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, + loop=loop, app_protocol_factory=ap_factory, waiter=whiter) proto._transport = mock.Mock() if not isinstance(r, (list, tuple)): @@ -63,17 +76,19 @@ class TestBaseSocksProtocol(unittest.TestCase): def test_init(self): with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, None, loop=self.loop) + BaseSocksProtocol(None, None, None, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, 123, loop=self.loop) + BaseSocksProtocol(None, None, 123, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): - BaseSocksProtocol(None, None, ('python.org',), loop=self.loop) + BaseSocksProtocol(None, None, ('python.org',), loop=self.loop, + waiter=None, app_protocol_factory=None) def test_write_request(self): - proto = BaseSocksProtocol(None, None, ('python.org', 80), - loop=self.loop) + proto = make_base(self.loop) proto._transport = mock.Mock() proto.write_request([b'\x00', b'\x01\x02', 0x03]) @@ -82,6 +97,180 @@ class TestBaseSocksProtocol(unittest.TestCase): with self.assertRaises(ValueError): proto.write_request(['\x00']) + @mock.patch('aiosocks.protocols.ensure_future') + def test_connection_made_os_error(self, ef_mock): + os_err_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = os_err_fut + + waiter = asyncio.Future(loop=self.loop) + proto = make_base(self.loop, waiter=waiter) + proto.connection_made(mock.Mock()) + + self.assertIs(proto._negotiate_fut, os_err_fut) + + with self.assertRaises(OSError): + os_err_fut.set_exception(OSError('test')) + self.loop.run_until_complete(os_err_fut) + self.assertIn('test', str(waiter.exception())) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_connection_made_socks_err(self, ef_mock): + socks_err_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = socks_err_fut + + waiter = asyncio.Future(loop=self.loop) + proto = make_base(self.loop, waiter=waiter) + proto.connection_made(mock.Mock()) + + self.assertIs(proto._negotiate_fut, socks_err_fut) + + with self.assertRaises(aiosocks.SocksError): + socks_err_fut.set_exception(aiosocks.SocksError('test')) + self.loop.run_until_complete(socks_err_fut) + self.assertIn('Can not connect to', str(waiter.exception())) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_connection_made_without_app_proto(self, ef_mock): + success_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = success_fut + + waiter = asyncio.Future(loop=self.loop) + proto = make_base(self.loop, waiter=waiter) + proto.connection_made(mock.Mock()) + + self.assertIs(proto._negotiate_fut, success_fut) + + success_fut.set_result(True) + self.loop.run_until_complete(success_fut) + self.assertTrue(waiter.done()) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_connection_made_with_app_proto(self, ef_mock): + success_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = success_fut + + waiter = asyncio.Future(loop=self.loop) + proto = make_base(self.loop, waiter=waiter, + ap_factory=lambda: asyncio.Protocol()) + proto.connection_made(mock.Mock()) + + self.assertIs(proto._negotiate_fut, success_fut) + + success_fut.set_result(True) + self.loop.run_until_complete(success_fut) + self.assertTrue(waiter.done()) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_connection_lost(self, ef_mock): + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + app_proto = mock.Mock() + + loop_mock = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + # negotiate not completed + proto.connection_lost(True) + self.assertFalse(loop_mock.call_soon.called) + + # negotiate successfully competed + negotiate_fut.set_result(True) + proto.connection_lost(True) + self.assertTrue(loop_mock.call_soon.called) + + # negotiate failed + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + negotiate_fut.set_exception(Exception()) + proto.connection_lost(True) + self.assertTrue(loop_mock.call_soon.called) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_pause_writing(self, ef_mock): + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + app_proto = mock.Mock() + + loop_mock = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + # negotiate not completed + proto.pause_writing() + self.assertFalse(app_proto.pause_writing.called) + + # negotiate successfully competed + negotiate_fut.set_result(True) + proto.pause_writing() + self.assertTrue(app_proto.pause_writing.called) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_resume_writing(self, ef_mock): + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + app_proto = mock.Mock() + + loop_mock = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + # negotiate not completed + with self.assertRaises(AssertionError): + proto.resume_writing() + + # negotiate fail + negotiate_fut.set_exception(Exception()) + proto.resume_writing() + self.assertTrue(app_proto.resume_writing.called) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_data_received(self, ef_mock): + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + app_proto = mock.Mock() + + loop_mock = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + # negotiate not completed + proto.data_received(b'123') + self.assertFalse(app_proto.data_received.called) + + # negotiate successfully competed + negotiate_fut.set_result(True) + proto.data_received(b'123') + self.assertTrue(app_proto.data_received.called) + + @mock.patch('aiosocks.protocols.ensure_future') + def test_eof_received(self, ef_mock): + negotiate_fut = asyncio.Future(loop=self.loop) + ef_mock.return_value = negotiate_fut + app_proto = mock.Mock() + + loop_mock = mock.Mock() + + proto = make_base(loop_mock, ap_factory=lambda: app_proto) + proto.connection_made(mock.Mock()) + + # negotiate not completed + proto.eof_received() + self.assertFalse(app_proto.eof_received.called) + + # negotiate successfully competed + negotiate_fut.set_result(True) + proto.eof_received() + self.assertTrue(app_proto.eof_received.called) + class TestSocks4Protocol(unittest.TestCase): def setUp(self): @@ -97,21 +286,27 @@ class TestSocks4Protocol(unittest.TestCase): dst = ('python.org', 80) with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(None, None, dst, loop=self.loop) + aiosocks.Socks4Protocol(None, None, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): - aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop) + aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, - loop=self.loop) + loop=self.loop, waiter=None, + app_protocol_factory=None) with self.assertRaises(ValueError): aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, - loop=self.loop) + loop=self.loop, waiter=None, + app_protocol_factory=None) - aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop) - aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop) + aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) + aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) def test_request_building(self): resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' @@ -230,21 +425,27 @@ class TestSocks5Protocol(unittest.TestCase): dst = ('python.org', 80) with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(None, None, dst, loop=self.loop) + aiosocks.Socks5Protocol(None, None, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): - aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop) + aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), - auth, dst, loop=self.loop) + auth, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) with self.assertRaises(ValueError): aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), - dst, loop=self.loop) + dst, loop=self.loop, + waiter=None, app_protocol_factory=None) - aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop) - aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop) + aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) + aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop, + waiter=None, app_protocol_factory=None) def test_authenticate(self): # invalid server version