Browse Source

Added limit param for SocksReader

main
nibrag 8 years ago
parent
commit
e7892c500c
3 changed files with 40 additions and 12 deletions
  1. +4
    -3
      aiosocks/__init__.py
  2. +25
    -9
      aiosocks/protocols.py
  3. +11
    -0
      tests/test_protocols.py

+ 4
- 3
aiosocks/__init__.py View File

@@ -6,7 +6,7 @@ from .errors import (
from .helpers import (
SocksAddr, Socks4Addr, Socks5Addr, Socks4Auth, Socks5Auth
)
from .protocols import Socks4Protocol, Socks5Protocol
from .protocols import Socks4Protocol, Socks5Protocol, DEFAULT_LIMIT

__version__ = '0.1.3'

@@ -21,7 +21,7 @@ __all__ = ('Socks4Protocol', 'Socks5Protocol', 'Socks4Auth',
def create_connection(protocol_factory, proxy, proxy_auth, dst, *,
remote_resolve=True, loop=None, ssl=None, family=0,
proto=0, flags=0, sock=None, local_addr=None,
server_hostname=None):
server_hostname=None, reader_limit=DEFAULT_LIMIT):
assert isinstance(proxy, SocksAddr), (
'proxy must be Socks4Addr() or Socks5Addr() tuple'
)
@@ -66,7 +66,8 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *,
return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst,
app_protocol_factory=protocol_factory,
waiter=waiter, remote_resolve=remote_resolve,
loop=loop, ssl=ssl, server_hostname=server_hostname)
loop=loop, ssl=ssl, server_hostname=server_hostname,
reader_limit=reader_limit)

try:
transport, protocol = yield from loop.create_connection(


+ 25
- 9
aiosocks/protocols.py View File

@@ -16,10 +16,14 @@ except ImportError:
ensure_future = asyncio.async


DEFAULT_LIMIT = getattr(asyncio.streams, '_DEFAULT_LIMIT', 2**16)


class BaseSocksProtocol(asyncio.StreamReaderProtocol):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, *,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None, negotiate_done_cb=None):
server_hostname=None, negotiate_done_cb=None,
reader_limit=DEFAULT_LIMIT):
if not isinstance(dst, (tuple, list)) or len(dst) != 2:
raise ValueError(
'Invalid dst format, tuple("dst_host", dst_port))'
@@ -45,7 +49,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
else:
self._app_protocol = self

reader = asyncio.StreamReader(loop=self._loop)
reader = asyncio.StreamReader(loop=self._loop, limit=reader_limit)

super().__init__(stream_reader=reader,
client_connected_cb=self.negotiate, loop=self._loop)
@@ -189,11 +193,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
"""
return self._proxy_peername

@property
def reader(self):
return self._stream_reader

@property
def writer(self):
return self._stream_writer


class Socks4Protocol(BaseSocksProtocol):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
server_hostname=None, reader_limit=DEFAULT_LIMIT):
proxy_auth = proxy_auth or Socks4Auth('')

if not isinstance(proxy, Socks4Addr):
@@ -202,8 +214,10 @@ class Socks4Protocol(BaseSocksProtocol):
if not isinstance(proxy_auth, Socks4Auth):
raise ValueError('Invalid proxy_auth format')

super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve, loop, ssl, server_hostname)
super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
waiter, remote_resolve=remote_resolve, loop=loop,
ssl=ssl, server_hostname=server_hostname,
reader_limit=reader_limit)

@asyncio.coroutine
def socks_request(self, cmd):
@@ -247,7 +261,7 @@ class Socks4Protocol(BaseSocksProtocol):
class Socks5Protocol(BaseSocksProtocol):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
server_hostname=None, reader_limit=DEFAULT_LIMIT):
proxy_auth = proxy_auth or Socks5Auth('', '')

if not isinstance(proxy, Socks5Addr):
@@ -256,8 +270,10 @@ class Socks5Protocol(BaseSocksProtocol):
if not isinstance(proxy_auth, Socks5Auth):
raise ValueError('Invalid proxy_auth format')

super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve, loop, ssl, server_hostname)
super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
waiter, remote_resolve=remote_resolve, loop=loop,
ssl=ssl, server_hostname=server_hostname,
reader_limit=reader_limit)

@asyncio.coroutine
def socks_request(self, cmd):


+ 11
- 0
tests/test_protocols.py View File

@@ -298,6 +298,17 @@ class TestBaseSocksProtocol(unittest.TestCase):
self.assertTrue(proto._negotiate_done_cb.called)
self.assertTrue(task_mock.called)

def test_reader_limit(self):
proto = BaseSocksProtocol(None, None, ('python.org', 80),
None, None, reader_limit=10,
loop=self.loop)
self.assertEqual(proto.reader._limit, 10)

proto = BaseSocksProtocol(None, None, ('python.org', 80),
None, None, reader_limit=15,
loop=self.loop)
self.assertEqual(proto.reader._limit, 15)


class TestSocks4Protocol(unittest.TestCase):
def setUp(self):


Loading…
Cancel
Save