diff --git a/aiosocks/__init__.py b/aiosocks/__init__.py index 9aa27a5..35bf7fa 100644 --- a/aiosocks/__init__.py +++ b/aiosocks/__init__.py @@ -6,7 +6,7 @@ from .errors import ( from .helpers import ( SocksAddr, Socks4Addr, Socks5Addr, Socks4Auth, Socks5Auth ) -from .protocols import Socks4Protocol, Socks5Protocol +from .protocols import Socks4Protocol, Socks5Protocol, DEFAULT_LIMIT __version__ = '0.1.3' @@ -21,7 +21,7 @@ __all__ = ('Socks4Protocol', 'Socks5Protocol', 'Socks4Auth', def create_connection(protocol_factory, proxy, proxy_auth, dst, *, remote_resolve=True, loop=None, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, - server_hostname=None): + server_hostname=None, reader_limit=DEFAULT_LIMIT): assert isinstance(proxy, SocksAddr), ( 'proxy must be Socks4Addr() or Socks5Addr() tuple' ) @@ -66,7 +66,8 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *, return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst, app_protocol_factory=protocol_factory, waiter=waiter, remote_resolve=remote_resolve, - loop=loop, ssl=ssl, server_hostname=server_hostname) + loop=loop, ssl=ssl, server_hostname=server_hostname, + reader_limit=reader_limit) try: transport, protocol = yield from loop.create_connection( diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index 6bbaff5..91e2463 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -16,10 +16,14 @@ except ImportError: ensure_future = asyncio.async +DEFAULT_LIMIT = getattr(asyncio.streams, '_DEFAULT_LIMIT', 2**16) + + class BaseSocksProtocol(asyncio.StreamReaderProtocol): - def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, + def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, *, remote_resolve=True, loop=None, ssl=False, - server_hostname=None, negotiate_done_cb=None): + server_hostname=None, negotiate_done_cb=None, + reader_limit=DEFAULT_LIMIT): if not isinstance(dst, (tuple, list)) or len(dst) != 2: raise ValueError( 'Invalid dst format, tuple("dst_host", dst_port))' @@ -45,7 +49,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): else: self._app_protocol = self - reader = asyncio.StreamReader(loop=self._loop) + reader = asyncio.StreamReader(loop=self._loop, limit=reader_limit) super().__init__(stream_reader=reader, client_connected_cb=self.negotiate, loop=self._loop) @@ -189,11 +193,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): """ return self._proxy_peername + @property + def reader(self): + return self._stream_reader + + @property + def writer(self): + return self._stream_writer + class Socks4Protocol(BaseSocksProtocol): def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, remote_resolve=True, loop=None, ssl=False, - server_hostname=None): + server_hostname=None, reader_limit=DEFAULT_LIMIT): proxy_auth = proxy_auth or Socks4Auth('') if not isinstance(proxy, Socks4Addr): @@ -202,8 +214,10 @@ class Socks4Protocol(BaseSocksProtocol): if not isinstance(proxy_auth, Socks4Auth): raise ValueError('Invalid proxy_auth format') - super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, - remote_resolve, loop, ssl, server_hostname) + super().__init__(proxy, proxy_auth, dst, app_protocol_factory, + waiter, remote_resolve=remote_resolve, loop=loop, + ssl=ssl, server_hostname=server_hostname, + reader_limit=reader_limit) @asyncio.coroutine def socks_request(self, cmd): @@ -247,7 +261,7 @@ class Socks4Protocol(BaseSocksProtocol): class Socks5Protocol(BaseSocksProtocol): def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, remote_resolve=True, loop=None, ssl=False, - server_hostname=None): + server_hostname=None, reader_limit=DEFAULT_LIMIT): proxy_auth = proxy_auth or Socks5Auth('', '') if not isinstance(proxy, Socks5Addr): @@ -256,8 +270,10 @@ class Socks5Protocol(BaseSocksProtocol): if not isinstance(proxy_auth, Socks5Auth): raise ValueError('Invalid proxy_auth format') - super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter, - remote_resolve, loop, ssl, server_hostname) + super().__init__(proxy, proxy_auth, dst, app_protocol_factory, + waiter, remote_resolve=remote_resolve, loop=loop, + ssl=ssl, server_hostname=server_hostname, + reader_limit=reader_limit) @asyncio.coroutine def socks_request(self, cmd): diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 036385d..4b2d8ab 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -298,6 +298,17 @@ class TestBaseSocksProtocol(unittest.TestCase): self.assertTrue(proto._negotiate_done_cb.called) self.assertTrue(task_mock.called) + def test_reader_limit(self): + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=10, + loop=self.loop) + self.assertEqual(proto.reader._limit, 10) + + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=15, + loop=self.loop) + self.assertEqual(proto.reader._limit, 15) + class TestSocks4Protocol(unittest.TestCase): def setUp(self):