diff --git a/tests/test_connector.py b/tests/test_connector.py new file mode 100644 index 0000000..81fe6ad --- /dev/null +++ b/tests/test_connector.py @@ -0,0 +1,119 @@ +import unittest +import asyncio +import aiosocks +import aiohttp +from unittest import mock +from asyncio import coroutine +from aiohttp.client_reqrep import ClientRequest +from aiosocks.connector import SocksConnector + + +class TestSocksConnector(unittest.TestCase): + def setUp(self): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def _fake_coroutine(self, return_value): + def coro(*args, **kwargs): + if isinstance(return_value, Exception): + raise return_value + return return_value + + return mock.Mock(side_effect=coroutine(coro)) + + def test_connect_proxy_ip(self): + loop_mock = mock.Mock() + + req = ClientRequest('GET', 'http://python.org', loop=self.loop) + connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), + None, loop=loop_mock) + + loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) + + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + proto.negotiate_done = self._fake_coroutine(True) + loop_mock.create_connection = self._fake_coroutine((tr, proto)) + + conn = self.loop.run_until_complete(connector.connect(req)) + + self.assertTrue(loop_mock.getaddrinfo.is_called) + self.assertIs(conn._transport, tr) + self.assertTrue(isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)) + + conn.close() + + def test_connect_proxy_domain(self): + loop_mock = mock.Mock() + + req = ClientRequest('GET', 'http://python.org', loop=self.loop) + connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), + None, loop=loop_mock) + + connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) + + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + proto.negotiate_done = self._fake_coroutine(True) + loop_mock.create_connection = self._fake_coroutine((tr, proto)) + + conn = self.loop.run_until_complete(connector.connect(req)) + + self.assertTrue(connector._resolve_host.is_called) + self.assertEqual(connector._resolve_host.call_count, 1) + self.assertIs(conn._transport, tr) + self.assertTrue(isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)) + + conn.close() + + def test_connect_locale_resolve(self): + loop_mock = mock.Mock() + + req = ClientRequest('GET', 'http://python.org', loop=self.loop) + connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'), + None, loop=loop_mock, remote_resolve=False) + + connector._resolve_host = self._fake_coroutine([mock.MagicMock()]) + + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + proto.negotiate_done = self._fake_coroutine(True) + loop_mock.create_connection = self._fake_coroutine((tr, proto)) + + conn = self.loop.run_until_complete(connector.connect(req)) + + self.assertTrue(connector._resolve_host.is_called) + self.assertEqual(connector._resolve_host.call_count, 2) + self.assertIs(conn._transport, tr) + self.assertTrue(isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)) + + conn.close() + + def test_proxy_connect_fail(self): + loop_mock = mock.Mock() + + req = ClientRequest('GET', 'http://python.org', loop=self.loop) + connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), + None, loop=loop_mock) + + loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) + loop_mock.create_connection = self._fake_coroutine(OSError()) + + with self.assertRaises(aiohttp.ProxyConnectionError): + self.loop.run_until_complete(connector.connect(req)) + + def test_proxy_negotiate_fail(self): + loop_mock = mock.Mock() + + req = ClientRequest('GET', 'http://python.org', loop=self.loop) + connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'), + None, loop=loop_mock) + + loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()]) + + tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol') + proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError()) + loop_mock.create_connection = self._fake_coroutine((tr, proto)) + + with self.assertRaises(aiosocks.SocksError): + self.loop.run_until_complete(connector.connect(req))