import unittest import asyncio import aiosocks import aiohttp from unittest import mock from asyncio import coroutine from aiohttp.client_reqrep import ClientRequest from aiosocks.connector import SocksConnector class TestSocksConnector(unittest.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(None) def tearDown(self): self.loop.close() def _fake_coroutine(self, return_value): def coro(*args, **kwargs): if isinstance(return_value, Exception): raise return_value return return_value return mock.Mock(side_effect=coroutine(coro)) def test_connect_proxy_ip(self): loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), None, loop=loop_mock) loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') proto.negotiate_done = self._fake_coroutine(True) loop_mock.create_connection = self._fake_coroutine((tr, proto)) conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(loop_mock.getaddrinfo.is_called) self.assertIs(conn._transport, tr) self.assertTrue( isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) ) conn.close() def test_connect_proxy_domain(self): loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), None, loop=loop_mock) connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') proto.negotiate_done = self._fake_coroutine(True) loop_mock.create_connection = self._fake_coroutine((tr, proto)) conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(connector._resolve_host.is_called) self.assertEqual(connector._resolve_host.call_count, 1) self.assertIs(conn._transport, tr) self.assertTrue( isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) ) conn.close() def test_connect_locale_resolve(self): loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), None, loop=loop_mock, remote_resolve=False) connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') proto.negotiate_done = self._fake_coroutine(True) loop_mock.create_connection = self._fake_coroutine((tr, proto)) conn = self.loop.run_until_complete(connector.connect(req)) self.assertTrue(connector._resolve_host.is_called) self.assertEqual(connector._resolve_host.call_count, 2) self.assertIs(conn._transport, tr) self.assertTrue( isinstance(conn._protocol, aiohttp.parsers.StreamProtocol) ) conn.close() def test_proxy_connect_fail(self): loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), None, loop=loop_mock) loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) loop_mock.create_connection = self._fake_coroutine(OSError()) with self.assertRaises(aiohttp.ProxyConnectionError): self.loop.run_until_complete(connector.connect(req)) def test_proxy_negotiate_fail(self): loop_mock = mock.Mock() req = ClientRequest('GET', 'http://python.org', loop=self.loop) connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), None, loop=loop_mock) loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) loop_mock.create_connection = self._fake_coroutine((tr, proto)) with self.assertRaises(aiosocks.SocksError): self.loop.run_until_complete(connector.connect(req))