| @@ -47,7 +47,15 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||||
| "proxy is Socks5Addr but proxy_auth is not Socks5Auth" | "proxy is Socks5Addr but proxy_auth is not Socks5Auth" | ||||
| ) | ) | ||||
| if server_hostname is not None and not ssl: | |||||
| raise ValueError('server_hostname is only meaningful with ssl') | |||||
| if server_hostname is None and ssl: | |||||
| # read details: asyncio.create_connection | |||||
| server_hostname = dst[0] | |||||
| loop = loop or asyncio.get_event_loop() | loop = loop or asyncio.get_event_loop() | ||||
| waiter = asyncio.Future(loop=loop) | |||||
| def socks_factory(): | def socks_factory(): | ||||
| if isinstance(proxy, Socks4Addr): | if isinstance(proxy, Socks4Addr): | ||||
| @@ -55,30 +63,24 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, | |||||
| else: | else: | ||||
| socks_proto = Socks5Protocol | socks_proto = Socks5Protocol | ||||
| return socks_proto( | |||||
| proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||||
| remote_resolve=remote_resolve, loop=loop) | |||||
| return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst, | |||||
| app_protocol_factory=protocol_factory, | |||||
| waiter=waiter, remote_resolve=remote_resolve, | |||||
| loop=loop, ssl=ssl, server_hostname=server_hostname) | |||||
| try: | try: | ||||
| transport, protocol = yield from loop.create_connection( | transport, protocol = yield from loop.create_connection( | ||||
| socks_factory, proxy.host, proxy.port, ssl=ssl, family=family, | |||||
| proto=proto, flags=flags, sock=sock, local_addr=local_addr, | |||||
| server_hostname=server_hostname) | |||||
| socks_factory, proxy.host, proxy.port, family=family, | |||||
| proto=proto, flags=flags, sock=sock, local_addr=local_addr) | |||||
| except OSError as exc: | except OSError as exc: | ||||
| raise SocksConnectionError( | raise SocksConnectionError( | ||||
| '[Errno %s] Can not connect to proxy %s:%d [%s]' % | '[Errno %s] Can not connect to proxy %s:%d [%s]' % | ||||
| (exc.errno, proxy.host, proxy.port, exc.strerror)) from exc | (exc.errno, proxy.host, proxy.port, exc.strerror)) from exc | ||||
| # Wait until communication with proxy server is finished | |||||
| try: | try: | ||||
| yield from protocol.negotiate_done() | |||||
| except SocksError as exc: | |||||
| raise SocksError('Can not connect to %s:%s [%s]' % | |||||
| (dst[0], dst[1], exc)) | |||||
| if protocol_factory: | |||||
| protocol = protocol_factory() | |||||
| protocol.connection_made(transport) | |||||
| transport._protocol = protocol | |||||
| yield from waiter | |||||
| except: | |||||
| transport.close() | |||||
| raise | |||||
| return transport, protocol | |||||
| return protocol.app_transport, protocol.app_protocol | |||||
| @@ -33,12 +33,16 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
| @asyncio.coroutine | @asyncio.coroutine | ||||
| def _create_connection(self, req): | def _create_connection(self, req): | ||||
| if req.ssl: | |||||
| sslcontext = self.ssl_context | |||||
| else: | |||||
| sslcontext = None | |||||
| if not self._remote_resolve: | if not self._remote_resolve: | ||||
| dst_hosts = yield from self._resolve_host(req.host, req.port) | dst_hosts = yield from self._resolve_host(req.host, req.port) | ||||
| dst = dst_hosts[0]['host'], dst_hosts[0]['port'] | dst = dst_hosts[0]['host'], dst_hosts[0]['port'] | ||||
| else: | else: | ||||
| dst = req.host, req.port | dst = req.host, req.port | ||||
| exc = None | |||||
| # if self._resolver is AsyncResolver and self._proxy.host | # if self._resolver is AsyncResolver and self._proxy.host | ||||
| # is ip address, then aiodns raise DNSError. | # is ip address, then aiodns raise DNSError. | ||||
| @@ -56,6 +60,7 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
| except ValueError: | except ValueError: | ||||
| proxy_hosts = yield from self._resolve_host(self._proxy.host, | proxy_hosts = yield from self._resolve_host(self._proxy.host, | ||||
| self._proxy.port) | self._proxy.port) | ||||
| exc = None | |||||
| for hinfo in proxy_hosts: | for hinfo in proxy_hosts: | ||||
| try: | try: | ||||
| @@ -65,8 +70,29 @@ class SocksConnector(aiohttp.TCPConnector): | |||||
| transp, proto = yield from create_connection( | transp, proto = yield from create_connection( | ||||
| self._factory, proxy, self._proxy_auth, dst, | self._factory, proxy, self._proxy_auth, dst, | ||||
| loop=self._loop, remote_resolve=self._remote_resolve, | loop=self._loop, remote_resolve=self._remote_resolve, | ||||
| ssl=None, family=hinfo['family'], proto=hinfo['proto'], | |||||
| flags=hinfo['flags'], local_addr=self._local_addr) | |||||
| ssl=sslcontext, family=hinfo['family'], | |||||
| proto=hinfo['proto'], flags=hinfo['flags'], | |||||
| local_addr=self._local_addr, | |||||
| server_hostname=req.host if sslcontext else None) | |||||
| has_cert = transp.get_extra_info('sslcontext') | |||||
| if has_cert and self._fingerprint: | |||||
| sock = transp.get_extra_info('socket') | |||||
| if not hasattr(sock, 'getpeercert'): | |||||
| # Workaround for asyncio 3.5.0 | |||||
| # Starting from 3.5.1 version | |||||
| # there is 'ssl_object' extra info in transport | |||||
| sock = transp._ssl_protocol._sslpipe.ssl_object | |||||
| # gives DER-encoded cert as a sequence of bytes (or None) | |||||
| cert = sock.getpeercert(binary_form=True) | |||||
| assert cert | |||||
| got = self._hashfunc(cert).digest() | |||||
| expected = self._fingerprint | |||||
| if got != expected: | |||||
| transp.close() | |||||
| raise aiohttp.FingerprintMismatch( | |||||
| expected, got, req.host, 80 | |||||
| ) | |||||
| return transp, proto | return transp, proto | ||||
| except (OSError, SocksError, SocksConnectionError) as e: | except (OSError, SocksError, SocksConnectionError) as e: | ||||
| @@ -17,7 +17,9 @@ except ImportError: | |||||
| class BaseSocksProtocol(asyncio.StreamReaderProtocol): | class BaseSocksProtocol(asyncio.StreamReaderProtocol): | ||||
| def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
| def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
| remote_resolve=True, loop=None, ssl=False, | |||||
| server_hostname=None): | |||||
| if not isinstance(dst, (tuple, list)) or len(dst) != 2: | if not isinstance(dst, (tuple, list)) or len(dst) != 2: | ||||
| raise ValueError( | raise ValueError( | ||||
| 'Invalid dst format, tuple("dst_host", dst_port))' | 'Invalid dst format, tuple("dst_host", dst_port))' | ||||
| @@ -30,18 +32,79 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| self._loop = loop or asyncio.get_event_loop() | self._loop = loop or asyncio.get_event_loop() | ||||
| self._transport = None | self._transport = None | ||||
| self._negotiate_done = None | |||||
| self._waiter = waiter | |||||
| self._negotiate_fut = None | |||||
| self._ssl = ssl | |||||
| self._server_hostname = server_hostname | |||||
| if app_protocol_factory: | |||||
| self._app_protocol = app_protocol_factory() | |||||
| else: | |||||
| self._app_protocol = self | |||||
| reader = asyncio.StreamReader(loop=self._loop) | reader = asyncio.StreamReader(loop=self._loop) | ||||
| super().__init__(stream_reader=reader, loop=self._loop) | super().__init__(stream_reader=reader, loop=self._loop) | ||||
| def connection_made(self, transport): | def connection_made(self, transport): | ||||
| # connection_made is called | |||||
| if self._transport: | |||||
| return | |||||
| super().connection_made(transport) | super().connection_made(transport) | ||||
| self._transport = transport | self._transport = transport | ||||
| def init_app_protocol(fut): | |||||
| exc = fut.exception() | |||||
| if exc: | |||||
| if isinstance(exc, SocksError): | |||||
| exc = SocksError('Can not connect to %s:%s. %s' % | |||||
| (self._dst_host, self._dst_port, exc)) | |||||
| self._waiter.set_exception(exc) | |||||
| else: | |||||
| if self._ssl: | |||||
| sock = self._transport.get_extra_info('socket') | |||||
| self._transport = self._loop._make_ssl_transport( | |||||
| rawsock=sock, protocol=self._app_protocol, | |||||
| sslcontext=self._ssl, server_side=False, | |||||
| server_hostname=self._server_hostname, | |||||
| waiter=self._waiter) | |||||
| else: | |||||
| self._app_protocol.connection_made(transport) | |||||
| self._waiter.set_result(True) | |||||
| req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | ||||
| self._negotiate_done = ensure_future(req_coro, loop=self._loop) | |||||
| self._negotiate_fut = ensure_future(req_coro, loop=self._loop) | |||||
| self._negotiate_fut.add_done_callback(init_app_protocol) | |||||
| def connection_lost(self, exc): | |||||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
| self._loop.call_soon(self._app_protocol.connection_lost, exc) | |||||
| super().connection_lost(exc) | |||||
| def pause_writing(self): | |||||
| if self._negotiate_fut.done(): | |||||
| self._app_protocol.pause_writing() | |||||
| else: | |||||
| super().pause_writing() | |||||
| def resume_writing(self): | |||||
| if self._negotiate_fut.done(): | |||||
| self._app_protocol.resume_writing() | |||||
| else: | |||||
| super().resume_writing() | |||||
| def data_received(self, data): | |||||
| if self._negotiate_fut.done(): | |||||
| self._app_protocol.data_received(data) | |||||
| else: | |||||
| super().data_received(data) | |||||
| def eof_received(self): | |||||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
| self._app_protocol.eof_received() | |||||
| super().eof_received() | |||||
| @asyncio.coroutine | @asyncio.coroutine | ||||
| def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
| @@ -74,12 +137,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| raise OSError('getaddrinfo() returned empty list') | raise OSError('getaddrinfo() returned empty list') | ||||
| return infos[0][0], infos[0][4][0] | return infos[0][0], infos[0][4][0] | ||||
| def negotiate_done(self): | |||||
| return self._negotiate_done | |||||
| @property | |||||
| def app_protocol(self): | |||||
| return self._app_protocol | |||||
| @property | |||||
| def app_transport(self): | |||||
| return self._transport | |||||
| class Socks4Protocol(BaseSocksProtocol): | class Socks4Protocol(BaseSocksProtocol): | ||||
| def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
| def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
| remote_resolve=True, loop=None, ssl=False, | |||||
| server_hostname=None): | |||||
| proxy_auth = proxy_auth or Socks4Auth('') | proxy_auth = proxy_auth or Socks4Auth('') | ||||
| if not isinstance(proxy, Socks4Addr): | if not isinstance(proxy, Socks4Addr): | ||||
| @@ -88,7 +158,8 @@ class Socks4Protocol(BaseSocksProtocol): | |||||
| if not isinstance(proxy_auth, Socks4Auth): | if not isinstance(proxy_auth, Socks4Auth): | ||||
| raise ValueError('Invalid proxy_auth format') | raise ValueError('Invalid proxy_auth format') | ||||
| super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||||
| super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
| remote_resolve, loop, ssl, server_hostname) | |||||
| @asyncio.coroutine | @asyncio.coroutine | ||||
| def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
| @@ -130,7 +201,9 @@ class Socks4Protocol(BaseSocksProtocol): | |||||
| class Socks5Protocol(BaseSocksProtocol): | class Socks5Protocol(BaseSocksProtocol): | ||||
| def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None): | |||||
| def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
| remote_resolve=True, loop=None, ssl=False, | |||||
| server_hostname=None): | |||||
| proxy_auth = proxy_auth or Socks5Auth('', '') | proxy_auth = proxy_auth or Socks5Auth('', '') | ||||
| if not isinstance(proxy, Socks5Addr): | if not isinstance(proxy, Socks5Addr): | ||||
| @@ -139,7 +212,8 @@ class Socks5Protocol(BaseSocksProtocol): | |||||
| if not isinstance(proxy_auth, Socks5Auth): | if not isinstance(proxy_auth, Socks5Auth): | ||||
| raise ValueError('Invalid proxy_auth format') | raise ValueError('Invalid proxy_auth format') | ||||
| super().__init__(proxy, proxy_auth, dst, remote_resolve, loop) | |||||
| super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||||
| remote_resolve, loop, ssl, server_hostname) | |||||
| @asyncio.coroutine | @asyncio.coroutine | ||||
| def socks_request(self, cmd): | def socks_request(self, cmd): | ||||
| @@ -0,0 +1,25 @@ | |||||
| import asyncio | |||||
| import socket | |||||
| import functools | |||||
| def find_unused_port(): | |||||
| s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||||
| s.bind(('127.0.0.1', 0)) | |||||
| port = s.getsockname()[1] | |||||
| s.close() | |||||
| return port | |||||
| @asyncio.coroutine | |||||
| def socks_handler(reader, writer, write_buff): | |||||
| writer.write(write_buff) | |||||
| @asyncio.coroutine | |||||
| def fake_socks_srv(loop, write_buff): | |||||
| port = find_unused_port() | |||||
| handler = functools.partial(socks_handler, write_buff=write_buff) | |||||
| srv = yield from asyncio.start_server( | |||||
| handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop) | |||||
| return srv, port | |||||
| @@ -24,7 +24,12 @@ class TestSocksConnector(unittest.TestCase): | |||||
| return mock.Mock(side_effect=coroutine(coro)) | return mock.Mock(side_effect=coroutine(coro)) | ||||
| def test_connect_proxy_ip(self): | |||||
| @mock.patch('aiosocks.connector.create_connection') | |||||
| def test_connect_proxy_ip(self, cr_conn_mock): | |||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| cr_conn_mock.side_effect = \ | |||||
| self._fake_coroutine((tr, proto)).side_effect | |||||
| loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
| req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
| @@ -33,21 +38,18 @@ class TestSocksConnector(unittest.TestCase): | |||||
| loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| proto.negotiate_done = self._fake_coroutine(True) | |||||
| loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
| conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
| self.assertTrue(loop_mock.getaddrinfo.is_called) | self.assertTrue(loop_mock.getaddrinfo.is_called) | ||||
| self.assertIs(conn._transport, tr) | self.assertIs(conn._transport, tr) | ||||
| self.assertTrue( | |||||
| isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
| ) | |||||
| conn.close() | conn.close() | ||||
| def test_connect_proxy_domain(self): | |||||
| @mock.patch('aiosocks.connector.create_connection') | |||||
| def test_connect_proxy_domain(self, cr_conn_mock): | |||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| cr_conn_mock.side_effect = \ | |||||
| self._fake_coroutine((tr, proto)).side_effect | |||||
| loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
| req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
| @@ -56,60 +58,53 @@ class TestSocksConnector(unittest.TestCase): | |||||
| connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | ||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| proto.negotiate_done = self._fake_coroutine(True) | |||||
| loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
| conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
| self.assertTrue(connector._resolve_host.is_called) | self.assertTrue(connector._resolve_host.is_called) | ||||
| self.assertEqual(connector._resolve_host.call_count, 1) | self.assertEqual(connector._resolve_host.call_count, 1) | ||||
| self.assertIs(conn._transport, tr) | self.assertIs(conn._transport, tr) | ||||
| self.assertTrue( | |||||
| isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
| ) | |||||
| conn.close() | conn.close() | ||||
| def test_connect_locale_resolve(self): | |||||
| loop_mock = mock.Mock() | |||||
| @mock.patch('aiosocks.connector.create_connection') | |||||
| def test_connect_locale_resolve(self, cr_conn_mock): | |||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| cr_conn_mock.side_effect = \ | |||||
| self._fake_coroutine((tr, proto)).side_effect | |||||
| req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
| connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), | connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), | ||||
| None, loop=loop_mock, remote_resolve=False) | |||||
| None, loop=self.loop, remote_resolve=False) | |||||
| connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) | ||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| proto.negotiate_done = self._fake_coroutine(True) | |||||
| loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
| conn = self.loop.run_until_complete(connector.connect(req)) | conn = self.loop.run_until_complete(connector.connect(req)) | ||||
| self.assertTrue(connector._resolve_host.is_called) | self.assertTrue(connector._resolve_host.is_called) | ||||
| self.assertEqual(connector._resolve_host.call_count, 2) | self.assertEqual(connector._resolve_host.call_count, 2) | ||||
| self.assertIs(conn._transport, tr) | |||||
| self.assertTrue( | |||||
| isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) | |||||
| ) | |||||
| conn.close() | conn.close() | ||||
| def test_proxy_connect_fail(self): | |||||
| @mock.patch('aiosocks.connector.create_connection') | |||||
| def test_proxy_connect_fail(self, cr_conn_mock): | |||||
| loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
| cr_conn_mock.side_effect = \ | |||||
| self._fake_coroutine(aiosocks.SocksConnectionError()).side_effect | |||||
| req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
| connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | ||||
| None, loop=loop_mock) | None, loop=loop_mock) | ||||
| loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
| loop_mock.create_connection = self._fake_coroutine(OSError()) | |||||
| with self.assertRaises(aiohttp.ProxyConnectionError): | with self.assertRaises(aiohttp.ProxyConnectionError): | ||||
| self.loop.run_until_complete(connector.connect(req)) | self.loop.run_until_complete(connector.connect(req)) | ||||
| def test_proxy_negotiate_fail(self): | |||||
| @mock.patch('aiosocks.connector.create_connection') | |||||
| def test_proxy_negotiate_fail(self, cr_conn_mock): | |||||
| loop_mock = mock.Mock() | loop_mock = mock.Mock() | ||||
| cr_conn_mock.side_effect = \ | |||||
| self._fake_coroutine(aiosocks.SocksError()).side_effect | |||||
| req = ClientRequest('GET', 'http://python.org', loop=self.loop) | req = ClientRequest('GET', 'http://python.org', loop=self.loop) | ||||
| connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), | ||||
| @@ -117,9 +112,5 @@ class TestSocksConnector(unittest.TestCase): | |||||
| loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) | ||||
| tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') | |||||
| proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||||
| loop_mock.create_connection = self._fake_coroutine((tr, proto)) | |||||
| with self.assertRaises(aiosocks.SocksError): | with self.assertRaises(aiosocks.SocksError): | ||||
| self.loop.run_until_complete(connector.connect(req)) | self.loop.run_until_complete(connector.connect(req)) | ||||
| @@ -2,6 +2,7 @@ import unittest | |||||
| import aiosocks | import aiosocks | ||||
| import asyncio | import asyncio | ||||
| from unittest import mock | from unittest import mock | ||||
| from .socks_serv import fake_socks_srv | |||||
| try: | try: | ||||
| from asyncio import ensure_future | from asyncio import ensure_future | ||||
| @@ -74,6 +75,15 @@ class TestCreateConnection(unittest.TestCase): | |||||
| self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', | self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth', | ||||
| str(ct.exception)) | str(ct.exception)) | ||||
| # test ssl, server_hostname | |||||
| with self.assertRaises(ValueError) as ct: | |||||
| conn = aiosocks.create_connection( | |||||
| None, addr, auth, dst, server_hostname='python.org' | |||||
| ) | |||||
| self.loop.run_until_complete(conn) | |||||
| self.assertIn('server_hostname is only meaningful with ssl', | |||||
| str(ct.exception)) | |||||
| def test_connection_fail(self): | def test_connection_fail(self): | ||||
| addr = aiosocks.Socks5Addr('localhost') | addr = aiosocks.Socks5Addr('localhost') | ||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | auth = aiosocks.Socks5Auth('usr', 'pwd') | ||||
| @@ -88,45 +98,323 @@ class TestCreateConnection(unittest.TestCase): | |||||
| ) | ) | ||||
| self.loop.run_until_complete(conn) | self.loop.run_until_complete(conn) | ||||
| def test_negotiate_fail(self): | |||||
| addr = aiosocks.Socks5Addr('localhost') | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| class TestCreateSocks4Connection(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.loop = asyncio.new_event_loop() | |||||
| asyncio.set_event_loop(None) | |||||
| def tearDown(self): | |||||
| self.loop.close() | |||||
| def test_connect_success(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x00\x5a\x04W\x01\x01\x01\x01test') | |||||
| ) | |||||
| addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks4Auth('usr') | |||||
| dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
| transp, proto = mock.Mock(), mock.Mock() | |||||
| proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| loop_mock = mock.Mock() | |||||
| loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
| data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
| self.assertEqual(data, b'test') | |||||
| server.close() | |||||
| transport.close() | |||||
| def test_invalid_ver(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x01\x5a\x04W\x01\x01\x01\x01') | |||||
| ) | |||||
| addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks4Auth('usr') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | with self.assertRaises(aiosocks.SocksError) as ct: | ||||
| conn = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=loop_mock | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid data', str(ct.exception)) | |||||
| server.close() | |||||
| def test_access_not_granted(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x00\x5b\x04W\x01\x01\x01\x01') | |||||
| ) | |||||
| addr = aiosocks.Socks4Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks4Auth('usr') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('0x5b', str(ct.exception)) | |||||
| server.close() | |||||
| class TestCreateSocks5Connect(unittest.TestCase): | |||||
| def setUp(self): | |||||
| self.loop = asyncio.new_event_loop() | |||||
| asyncio.set_event_loop(None) | |||||
| def tearDown(self): | |||||
| self.loop.close() | |||||
| def test_connect_success_anonymous(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv( | |||||
| self.loop, | |||||
| b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||||
| ) | ) | ||||
| self.loop.run_until_complete(conn) | |||||
| self.assertIn('Can not connect to python.org:80', | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
| data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
| self.assertEqual(data, b'test') | |||||
| server.close() | |||||
| transport.close() | |||||
| def test_connect_success_usr_pwd(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv( | |||||
| self.loop, | |||||
| b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest' | |||||
| ) | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
| data = self.loop.run_until_complete(protocol._stream_reader.read(4)) | |||||
| self.assertEqual(data, b'test') | |||||
| server.close() | |||||
| transport.close() | |||||
| def test_auth_ver_err(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x04\x02') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid version', str(ct.exception)) | |||||
| server.close() | |||||
| def test_auth_method_rejected(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\xFF') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('authentication methods were rejected', | |||||
| str(ct.exception)) | str(ct.exception)) | ||||
| def test_create_protocol(self): | |||||
| addr = aiosocks.Socks5Addr('localhost') | |||||
| server.close() | |||||
| def test_auth_status_invalid(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\xF0') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | auth = aiosocks.Socks5Auth('usr', 'pwd') | ||||
| dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
| transp, proto = mock.Mock(), mock.Mock() | |||||
| proto.negotiate_done = self._fake_coroutine(True) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid data', str(ct.exception)) | |||||
| loop_mock = mock.Mock() | |||||
| loop_mock.create_connection = self._fake_coroutine((transp, proto)) | |||||
| server.close() | |||||
| def test_auth_status_invalid2(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\x02\x02\x00') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid data', str(ct.exception)) | |||||
| user_proto = mock.Mock() | |||||
| server.close() | |||||
| conn = aiosocks.create_connection( | |||||
| lambda: user_proto, addr, auth, dst, loop=loop_mock | |||||
| def test_auth_failed(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\x02\x01\x01') | |||||
| ) | ) | ||||
| fut = ensure_future(conn, loop=self.loop) | |||||
| self.loop.run_until_complete(fut) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('authentication failed', str(ct.exception)) | |||||
| server.close() | |||||
| def test_cmd_ver_err(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x04\x00\x00') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid version', str(ct.exception)) | |||||
| server.close() | |||||
| def test_cmd_not_granted(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x01\x00') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('General SOCKS server failure', str(ct.exception)) | |||||
| server.close() | |||||
| def test_invalid_address_type(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| with self.assertRaises(aiosocks.SocksError) as ct: | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| transport.close() | |||||
| self.assertIn('invalid data', str(ct.exception)) | |||||
| server.close() | |||||
| def test_atype_ipv4(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv( | |||||
| self.loop, | |||||
| b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W' | |||||
| ) | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, ('1.1.1.1', 1111)) | |||||
| transport.close() | |||||
| server.close() | |||||
| def test_atype_ipv6(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv( | |||||
| self.loop, | |||||
| b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00' | |||||
| b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W') | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, ('::111', 1111)) | |||||
| transport.close() | |||||
| server.close() | |||||
| def test_atype_domain(self): | |||||
| server, port = self.loop.run_until_complete( | |||||
| fake_socks_srv( | |||||
| self.loop, | |||||
| b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W' | |||||
| ) | |||||
| ) | |||||
| addr = aiosocks.Socks5Addr('127.0.0.1', port) | |||||
| auth = aiosocks.Socks5Auth('usr', 'pwd') | |||||
| dst = ('python.org', 80) | |||||
| coro = aiosocks.create_connection( | |||||
| None, addr, auth, dst, loop=self.loop) | |||||
| transport, protocol = self.loop.run_until_complete(coro) | |||||
| _, addr = protocol._negotiate_fut.result() | |||||
| self.assertEqual(addr, (b'python.org', 1111)) | |||||
| transport, protocol = fut.result() | |||||
| self.assertIs(transport, transp) | |||||
| self.assertIs(protocol, user_proto) | |||||
| self.assertIs(transport._protocol, user_proto) | |||||
| transport.close() | |||||
| server.close() | |||||
| @@ -13,13 +13,24 @@ except ImportError: | |||||
| ensure_future = asyncio.async | ensure_future = asyncio.async | ||||
| def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||||
| def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None): | |||||
| dst = dst or ('python.org', 80) | |||||
| proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl, | |||||
| loop=loop, waiter=waiter, | |||||
| app_protocol_factory=ap_factory) | |||||
| return proto | |||||
| def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'', | |||||
| ap_factory=None, whiter=None): | |||||
| addr = addr or aiosocks.Socks4Addr('localhost', 1080) | addr = addr or aiosocks.Socks4Addr('localhost', 1080) | ||||
| auth = auth or aiosocks.Socks4Auth('user') | auth = auth or aiosocks.Socks4Auth('user') | ||||
| dst = dst or ('python.org', 80) | dst = dst or ('python.org', 80) | ||||
| proto = aiosocks.Socks4Protocol( | proto = aiosocks.Socks4Protocol( | ||||
| proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||||
| proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||||
| loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||||
| proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
| proto.read_response = mock.Mock( | proto.read_response = mock.Mock( | ||||
| side_effect=coro(mock.Mock(return_value=r))) | side_effect=coro(mock.Mock(return_value=r))) | ||||
| @@ -30,13 +41,15 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''): | |||||
| return proto | return proto | ||||
| def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None): | |||||
| def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None, | |||||
| ap_factory=None, whiter=None): | |||||
| addr = addr or aiosocks.Socks5Addr('localhost', 1080) | addr = addr or aiosocks.Socks5Addr('localhost', 1080) | ||||
| auth = auth or aiosocks.Socks5Auth('user', 'pwd') | auth = auth or aiosocks.Socks5Auth('user', 'pwd') | ||||
| dst = dst or ('python.org', 80) | dst = dst or ('python.org', 80) | ||||
| proto = aiosocks.Socks5Protocol( | proto = aiosocks.Socks5Protocol( | ||||
| proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop) | |||||
| proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, | |||||
| loop=loop, app_protocol_factory=ap_factory, waiter=whiter) | |||||
| proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
| if not isinstance(r, (list, tuple)): | if not isinstance(r, (list, tuple)): | ||||
| @@ -63,17 +76,19 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||||
| def test_init(self): | def test_init(self): | ||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| BaseSocksProtocol(None, None, None, loop=self.loop) | |||||
| BaseSocksProtocol(None, None, None, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| BaseSocksProtocol(None, None, 123, loop=self.loop) | |||||
| BaseSocksProtocol(None, None, 123, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| BaseSocksProtocol(None, None, ('python.org',), loop=self.loop) | |||||
| BaseSocksProtocol(None, None, ('python.org',), loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| def test_write_request(self): | def test_write_request(self): | ||||
| proto = BaseSocksProtocol(None, None, ('python.org', 80), | |||||
| loop=self.loop) | |||||
| proto = make_base(self.loop) | |||||
| proto._transport = mock.Mock() | proto._transport = mock.Mock() | ||||
| proto.write_request([b'\x00', b'\x01\x02', 0x03]) | proto.write_request([b'\x00', b'\x01\x02', 0x03]) | ||||
| @@ -82,6 +97,180 @@ class TestBaseSocksProtocol(unittest.TestCase): | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| proto.write_request(['\x00']) | proto.write_request(['\x00']) | ||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_connection_made_os_error(self, ef_mock): | |||||
| os_err_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = os_err_fut | |||||
| waiter = asyncio.Future(loop=self.loop) | |||||
| proto = make_base(self.loop, waiter=waiter) | |||||
| proto.connection_made(mock.Mock()) | |||||
| self.assertIs(proto._negotiate_fut, os_err_fut) | |||||
| with self.assertRaises(OSError): | |||||
| os_err_fut.set_exception(OSError('test')) | |||||
| self.loop.run_until_complete(os_err_fut) | |||||
| self.assertIn('test', str(waiter.exception())) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_connection_made_socks_err(self, ef_mock): | |||||
| socks_err_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = socks_err_fut | |||||
| waiter = asyncio.Future(loop=self.loop) | |||||
| proto = make_base(self.loop, waiter=waiter) | |||||
| proto.connection_made(mock.Mock()) | |||||
| self.assertIs(proto._negotiate_fut, socks_err_fut) | |||||
| with self.assertRaises(aiosocks.SocksError): | |||||
| socks_err_fut.set_exception(aiosocks.SocksError('test')) | |||||
| self.loop.run_until_complete(socks_err_fut) | |||||
| self.assertIn('Can not connect to', str(waiter.exception())) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_connection_made_without_app_proto(self, ef_mock): | |||||
| success_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = success_fut | |||||
| waiter = asyncio.Future(loop=self.loop) | |||||
| proto = make_base(self.loop, waiter=waiter) | |||||
| proto.connection_made(mock.Mock()) | |||||
| self.assertIs(proto._negotiate_fut, success_fut) | |||||
| success_fut.set_result(True) | |||||
| self.loop.run_until_complete(success_fut) | |||||
| self.assertTrue(waiter.done()) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_connection_made_with_app_proto(self, ef_mock): | |||||
| success_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = success_fut | |||||
| waiter = asyncio.Future(loop=self.loop) | |||||
| proto = make_base(self.loop, waiter=waiter, | |||||
| ap_factory=lambda: asyncio.Protocol()) | |||||
| proto.connection_made(mock.Mock()) | |||||
| self.assertIs(proto._negotiate_fut, success_fut) | |||||
| success_fut.set_result(True) | |||||
| self.loop.run_until_complete(success_fut) | |||||
| self.assertTrue(waiter.done()) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_connection_lost(self, ef_mock): | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| app_proto = mock.Mock() | |||||
| loop_mock = mock.Mock() | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| # negotiate not completed | |||||
| proto.connection_lost(True) | |||||
| self.assertFalse(loop_mock.call_soon.called) | |||||
| # negotiate successfully competed | |||||
| negotiate_fut.set_result(True) | |||||
| proto.connection_lost(True) | |||||
| self.assertTrue(loop_mock.call_soon.called) | |||||
| # negotiate failed | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| negotiate_fut.set_exception(Exception()) | |||||
| proto.connection_lost(True) | |||||
| self.assertTrue(loop_mock.call_soon.called) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_pause_writing(self, ef_mock): | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| app_proto = mock.Mock() | |||||
| loop_mock = mock.Mock() | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| # negotiate not completed | |||||
| proto.pause_writing() | |||||
| self.assertFalse(app_proto.pause_writing.called) | |||||
| # negotiate successfully competed | |||||
| negotiate_fut.set_result(True) | |||||
| proto.pause_writing() | |||||
| self.assertTrue(app_proto.pause_writing.called) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_resume_writing(self, ef_mock): | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| app_proto = mock.Mock() | |||||
| loop_mock = mock.Mock() | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| # negotiate not completed | |||||
| with self.assertRaises(AssertionError): | |||||
| proto.resume_writing() | |||||
| # negotiate fail | |||||
| negotiate_fut.set_exception(Exception()) | |||||
| proto.resume_writing() | |||||
| self.assertTrue(app_proto.resume_writing.called) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_data_received(self, ef_mock): | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| app_proto = mock.Mock() | |||||
| loop_mock = mock.Mock() | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| # negotiate not completed | |||||
| proto.data_received(b'123') | |||||
| self.assertFalse(app_proto.data_received.called) | |||||
| # negotiate successfully competed | |||||
| negotiate_fut.set_result(True) | |||||
| proto.data_received(b'123') | |||||
| self.assertTrue(app_proto.data_received.called) | |||||
| @mock.patch('aiosocks.protocols.ensure_future') | |||||
| def test_eof_received(self, ef_mock): | |||||
| negotiate_fut = asyncio.Future(loop=self.loop) | |||||
| ef_mock.return_value = negotiate_fut | |||||
| app_proto = mock.Mock() | |||||
| loop_mock = mock.Mock() | |||||
| proto = make_base(loop_mock, ap_factory=lambda: app_proto) | |||||
| proto.connection_made(mock.Mock()) | |||||
| # negotiate not completed | |||||
| proto.eof_received() | |||||
| self.assertFalse(app_proto.eof_received.called) | |||||
| # negotiate successfully competed | |||||
| negotiate_fut.set_result(True) | |||||
| proto.eof_received() | |||||
| self.assertTrue(app_proto.eof_received.called) | |||||
| class TestSocks4Protocol(unittest.TestCase): | class TestSocks4Protocol(unittest.TestCase): | ||||
| def setUp(self): | def setUp(self): | ||||
| @@ -97,21 +286,27 @@ class TestSocks4Protocol(unittest.TestCase): | |||||
| dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks4Protocol(None, None, dst, loop=self.loop) | |||||
| aiosocks.Socks4Protocol(None, None, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop) | |||||
| aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, | aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst, | ||||
| loop=self.loop) | |||||
| loop=self.loop, waiter=None, | |||||
| app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, | aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst, | ||||
| loop=self.loop) | |||||
| loop=self.loop, waiter=None, | |||||
| app_protocol_factory=None) | |||||
| aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop) | |||||
| aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop) | |||||
| aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| def test_request_building(self): | def test_request_building(self): | ||||
| resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' | resp = b'\x00\x5a\x00P\x7f\x00\x00\x01' | ||||
| @@ -230,21 +425,27 @@ class TestSocks5Protocol(unittest.TestCase): | |||||
| dst = ('python.org', 80) | dst = ('python.org', 80) | ||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks5Protocol(None, None, dst, loop=self.loop) | |||||
| aiosocks.Socks5Protocol(None, None, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop) | |||||
| aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), | aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'), | ||||
| auth, dst, loop=self.loop) | |||||
| auth, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| with self.assertRaises(ValueError): | with self.assertRaises(ValueError): | ||||
| aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), | aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'), | ||||
| dst, loop=self.loop) | |||||
| dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop) | |||||
| aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop) | |||||
| aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop, | |||||
| waiter=None, app_protocol_factory=None) | |||||
| def test_authenticate(self): | def test_authenticate(self): | ||||
| # invalid server version | # invalid server version | ||||