Browse Source

Added SSL support and extended test cases for aiosocks.create_connect

main
nibrag 8 years ago
parent
commit
4512eaac25
8 changed files with 717 additions and 110 deletions
  1. +19
    -17
      aiosocks/__init__.py
  2. +29
    -3
      aiosocks/connector.py
  3. +83
    -9
      aiosocks/protocols.py
  4. +0
    -0
      tests/__init__.py
  5. +25
    -0
      tests/socks_serv.py
  6. +25
    -34
      tests/test_connector.py
  7. +314
    -26
      tests/test_create_connect.py
  8. +222
    -21
      tests/test_protocol.py

+ 19
- 17
aiosocks/__init__.py View File

@@ -47,7 +47,15 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *,
"proxy is Socks5Addr but proxy_auth is not Socks5Auth"
)

if server_hostname is not None and not ssl:
raise ValueError('server_hostname is only meaningful with ssl')

if server_hostname is None and ssl:
# read details: asyncio.create_connection
server_hostname = dst[0]

loop = loop or asyncio.get_event_loop()
waiter = asyncio.Future(loop=loop)

def socks_factory():
if isinstance(proxy, Socks4Addr):
@@ -55,30 +63,24 @@ def create_connection(protocol_factory, proxy, proxy_auth, dst, *,
else:
socks_proto = Socks5Protocol

return socks_proto(
proxy=proxy, proxy_auth=proxy_auth, dst=dst,
remote_resolve=remote_resolve, loop=loop)
return socks_proto(proxy=proxy, proxy_auth=proxy_auth, dst=dst,
app_protocol_factory=protocol_factory,
waiter=waiter, remote_resolve=remote_resolve,
loop=loop, ssl=ssl, server_hostname=server_hostname)

try:
transport, protocol = yield from loop.create_connection(
socks_factory, proxy.host, proxy.port, ssl=ssl, family=family,
proto=proto, flags=flags, sock=sock, local_addr=local_addr,
server_hostname=server_hostname)
socks_factory, proxy.host, proxy.port, family=family,
proto=proto, flags=flags, sock=sock, local_addr=local_addr)
except OSError as exc:
raise SocksConnectionError(
'[Errno %s] Can not connect to proxy %s:%d [%s]' %
(exc.errno, proxy.host, proxy.port, exc.strerror)) from exc

# Wait until communication with proxy server is finished
try:
yield from protocol.negotiate_done()
except SocksError as exc:
raise SocksError('Can not connect to %s:%s [%s]' %
(dst[0], dst[1], exc))

if protocol_factory:
protocol = protocol_factory()
protocol.connection_made(transport)
transport._protocol = protocol
yield from waiter
except:
transport.close()
raise

return transport, protocol
return protocol.app_transport, protocol.app_protocol

+ 29
- 3
aiosocks/connector.py View File

@@ -33,12 +33,16 @@ class SocksConnector(aiohttp.TCPConnector):

@asyncio.coroutine
def _create_connection(self, req):
if req.ssl:
sslcontext = self.ssl_context
else:
sslcontext = None

if not self._remote_resolve:
dst_hosts = yield from self._resolve_host(req.host, req.port)
dst = dst_hosts[0]['host'], dst_hosts[0]['port']
else:
dst = req.host, req.port
exc = None

# if self._resolver is AsyncResolver and self._proxy.host
# is ip address, then aiodns raise DNSError.
@@ -56,6 +60,7 @@ class SocksConnector(aiohttp.TCPConnector):
except ValueError:
proxy_hosts = yield from self._resolve_host(self._proxy.host,
self._proxy.port)
exc = None

for hinfo in proxy_hosts:
try:
@@ -65,8 +70,29 @@ class SocksConnector(aiohttp.TCPConnector):
transp, proto = yield from create_connection(
self._factory, proxy, self._proxy_auth, dst,
loop=self._loop, remote_resolve=self._remote_resolve,
ssl=None, family=hinfo['family'], proto=hinfo['proto'],
flags=hinfo['flags'], local_addr=self._local_addr)
ssl=sslcontext, family=hinfo['family'],
proto=hinfo['proto'], flags=hinfo['flags'],
local_addr=self._local_addr,
server_hostname=req.host if sslcontext else None)

has_cert = transp.get_extra_info('sslcontext')
if has_cert and self._fingerprint:
sock = transp.get_extra_info('socket')
if not hasattr(sock, 'getpeercert'):
# Workaround for asyncio 3.5.0
# Starting from 3.5.1 version
# there is 'ssl_object' extra info in transport
sock = transp._ssl_protocol._sslpipe.ssl_object
# gives DER-encoded cert as a sequence of bytes (or None)
cert = sock.getpeercert(binary_form=True)
assert cert
got = self._hashfunc(cert).digest()
expected = self._fingerprint
if got != expected:
transp.close()
raise aiohttp.FingerprintMismatch(
expected, got, req.host, 80
)

return transp, proto
except (OSError, SocksError, SocksConnectionError) as e:


+ 83
- 9
aiosocks/protocols.py View File

@@ -17,7 +17,9 @@ except ImportError:


class BaseSocksProtocol(asyncio.StreamReaderProtocol):
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
if not isinstance(dst, (tuple, list)) or len(dst) != 2:
raise ValueError(
'Invalid dst format, tuple("dst_host", dst_port))'
@@ -30,18 +32,79 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):

self._loop = loop or asyncio.get_event_loop()
self._transport = None
self._negotiate_done = None
self._waiter = waiter
self._negotiate_fut = None
self._ssl = ssl
self._server_hostname = server_hostname

if app_protocol_factory:
self._app_protocol = app_protocol_factory()
else:
self._app_protocol = self

reader = asyncio.StreamReader(loop=self._loop)

super().__init__(stream_reader=reader, loop=self._loop)

def connection_made(self, transport):
# connection_made is called
if self._transport:
return

super().connection_made(transport)
self._transport = transport

def init_app_protocol(fut):
exc = fut.exception()
if exc:
if isinstance(exc, SocksError):
exc = SocksError('Can not connect to %s:%s. %s' %
(self._dst_host, self._dst_port, exc))
self._waiter.set_exception(exc)
else:
if self._ssl:
sock = self._transport.get_extra_info('socket')

self._transport = self._loop._make_ssl_transport(
rawsock=sock, protocol=self._app_protocol,
sslcontext=self._ssl, server_side=False,
server_hostname=self._server_hostname,
waiter=self._waiter)
else:
self._app_protocol.connection_made(transport)
self._waiter.set_result(True)

req_coro = self.socks_request(c.SOCKS_CMD_CONNECT)
self._negotiate_done = ensure_future(req_coro, loop=self._loop)
self._negotiate_fut = ensure_future(req_coro, loop=self._loop)
self._negotiate_fut.add_done_callback(init_app_protocol)

def connection_lost(self, exc):
if self._negotiate_fut.done() and not self._negotiate_fut.exception():
self._loop.call_soon(self._app_protocol.connection_lost, exc)
super().connection_lost(exc)

def pause_writing(self):
if self._negotiate_fut.done():
self._app_protocol.pause_writing()
else:
super().pause_writing()

def resume_writing(self):
if self._negotiate_fut.done():
self._app_protocol.resume_writing()
else:
super().resume_writing()

def data_received(self, data):
if self._negotiate_fut.done():
self._app_protocol.data_received(data)
else:
super().data_received(data)

def eof_received(self):
if self._negotiate_fut.done() and not self._negotiate_fut.exception():
self._app_protocol.eof_received()
super().eof_received()

@asyncio.coroutine
def socks_request(self, cmd):
@@ -74,12 +137,19 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
raise OSError('getaddrinfo() returned empty list')
return infos[0][0], infos[0][4][0]

def negotiate_done(self):
return self._negotiate_done
@property
def app_protocol(self):
return self._app_protocol

@property
def app_transport(self):
return self._transport


class Socks4Protocol(BaseSocksProtocol):
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
proxy_auth = proxy_auth or Socks4Auth('')

if not isinstance(proxy, Socks4Addr):
@@ -88,7 +158,8 @@ class Socks4Protocol(BaseSocksProtocol):
if not isinstance(proxy_auth, Socks4Auth):
raise ValueError('Invalid proxy_auth format')

super().__init__(proxy, proxy_auth, dst, remote_resolve, loop)
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve, loop, ssl, server_hostname)

@asyncio.coroutine
def socks_request(self, cmd):
@@ -130,7 +201,9 @@ class Socks4Protocol(BaseSocksProtocol):


class Socks5Protocol(BaseSocksProtocol):
def __init__(self, proxy, proxy_auth, dst, remote_resolve=True, loop=None):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
proxy_auth = proxy_auth or Socks5Auth('', '')

if not isinstance(proxy, Socks5Addr):
@@ -139,7 +212,8 @@ class Socks5Protocol(BaseSocksProtocol):
if not isinstance(proxy_auth, Socks5Auth):
raise ValueError('Invalid proxy_auth format')

super().__init__(proxy, proxy_auth, dst, remote_resolve, loop)
super().__init__(proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve, loop, ssl, server_hostname)

@asyncio.coroutine
def socks_request(self, cmd):


+ 0
- 0
tests/__init__.py View File


+ 25
- 0
tests/socks_serv.py View File

@@ -0,0 +1,25 @@
import asyncio
import socket
import functools


def find_unused_port():
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 0))
port = s.getsockname()[1]
s.close()
return port


@asyncio.coroutine
def socks_handler(reader, writer, write_buff):
writer.write(write_buff)


@asyncio.coroutine
def fake_socks_srv(loop, write_buff):
port = find_unused_port()
handler = functools.partial(socks_handler, write_buff=write_buff)
srv = yield from asyncio.start_server(
handler, '127.0.0.1', port, family=socket.AF_INET, loop=loop)
return srv, port

+ 25
- 34
tests/test_connector.py View File

@@ -24,7 +24,12 @@ class TestSocksConnector(unittest.TestCase):

return mock.Mock(side_effect=coroutine(coro))

def test_connect_proxy_ip(self):
@mock.patch('aiosocks.connector.create_connection')
def test_connect_proxy_ip(self, cr_conn_mock):
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
cr_conn_mock.side_effect = \
self._fake_coroutine((tr, proto)).side_effect

loop_mock = mock.Mock()

req = ClientRequest('GET', 'http://python.org', loop=self.loop)
@@ -33,21 +38,18 @@ class TestSocksConnector(unittest.TestCase):

loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()])

tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
proto.negotiate_done = self._fake_coroutine(True)
loop_mock.create_connection = self._fake_coroutine((tr, proto))

conn = self.loop.run_until_complete(connector.connect(req))

self.assertTrue(loop_mock.getaddrinfo.is_called)
self.assertIs(conn._transport, tr)
self.assertTrue(
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)
)

conn.close()

def test_connect_proxy_domain(self):
@mock.patch('aiosocks.connector.create_connection')
def test_connect_proxy_domain(self, cr_conn_mock):
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
cr_conn_mock.side_effect = \
self._fake_coroutine((tr, proto)).side_effect
loop_mock = mock.Mock()

req = ClientRequest('GET', 'http://python.org', loop=self.loop)
@@ -56,60 +58,53 @@ class TestSocksConnector(unittest.TestCase):

connector._resolve_host = self._fake_coroutine([mock.MagicMock()])

tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
proto.negotiate_done = self._fake_coroutine(True)
loop_mock.create_connection = self._fake_coroutine((tr, proto))

conn = self.loop.run_until_complete(connector.connect(req))

self.assertTrue(connector._resolve_host.is_called)
self.assertEqual(connector._resolve_host.call_count, 1)
self.assertIs(conn._transport, tr)
self.assertTrue(
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)
)

conn.close()

def test_connect_locale_resolve(self):
loop_mock = mock.Mock()
@mock.patch('aiosocks.connector.create_connection')
def test_connect_locale_resolve(self, cr_conn_mock):
tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
cr_conn_mock.side_effect = \
self._fake_coroutine((tr, proto)).side_effect

req = ClientRequest('GET', 'http://python.org', loop=self.loop)
connector = SocksConnector(aiosocks.Socks5Addr('proxy.example'),
None, loop=loop_mock, remote_resolve=False)
None, loop=self.loop, remote_resolve=False)

connector._resolve_host = self._fake_coroutine([mock.MagicMock()])

tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
proto.negotiate_done = self._fake_coroutine(True)
loop_mock.create_connection = self._fake_coroutine((tr, proto))

conn = self.loop.run_until_complete(connector.connect(req))

self.assertTrue(connector._resolve_host.is_called)
self.assertEqual(connector._resolve_host.call_count, 2)
self.assertIs(conn._transport, tr)
self.assertTrue(
isinstance(conn._protocol, aiohttp.parsers.StreamProtocol)
)

conn.close()

def test_proxy_connect_fail(self):
@mock.patch('aiosocks.connector.create_connection')
def test_proxy_connect_fail(self, cr_conn_mock):
loop_mock = mock.Mock()
cr_conn_mock.side_effect = \
self._fake_coroutine(aiosocks.SocksConnectionError()).side_effect

req = ClientRequest('GET', 'http://python.org', loop=self.loop)
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'),
None, loop=loop_mock)

loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()])
loop_mock.create_connection = self._fake_coroutine(OSError())

with self.assertRaises(aiohttp.ProxyConnectionError):
self.loop.run_until_complete(connector.connect(req))

def test_proxy_negotiate_fail(self):
@mock.patch('aiosocks.connector.create_connection')
def test_proxy_negotiate_fail(self, cr_conn_mock):
loop_mock = mock.Mock()
cr_conn_mock.side_effect = \
self._fake_coroutine(aiosocks.SocksError()).side_effect

req = ClientRequest('GET', 'http://python.org', loop=self.loop)
connector = SocksConnector(aiosocks.Socks5Addr('127.0.0.1'),
@@ -117,9 +112,5 @@ class TestSocksConnector(unittest.TestCase):

loop_mock.getaddrinfo = self._fake_coroutine([mock.MagicMock()])

tr, proto = mock.Mock(name='transport'), mock.Mock(name='protocol')
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError())
loop_mock.create_connection = self._fake_coroutine((tr, proto))

with self.assertRaises(aiosocks.SocksError):
self.loop.run_until_complete(connector.connect(req))

+ 314
- 26
tests/test_create_connect.py View File

@@ -2,6 +2,7 @@ import unittest
import aiosocks
import asyncio
from unittest import mock
from .socks_serv import fake_socks_srv

try:
from asyncio import ensure_future
@@ -74,6 +75,15 @@ class TestCreateConnection(unittest.TestCase):
self.assertIn('proxy is Socks4Addr but proxy_auth is not Socks4Auth',
str(ct.exception))

# test ssl, server_hostname
with self.assertRaises(ValueError) as ct:
conn = aiosocks.create_connection(
None, addr, auth, dst, server_hostname='python.org'
)
self.loop.run_until_complete(conn)
self.assertIn('server_hostname is only meaningful with ssl',
str(ct.exception))

def test_connection_fail(self):
addr = aiosocks.Socks5Addr('localhost')
auth = aiosocks.Socks5Auth('usr', 'pwd')
@@ -88,45 +98,323 @@ class TestCreateConnection(unittest.TestCase):
)
self.loop.run_until_complete(conn)

def test_negotiate_fail(self):
addr = aiosocks.Socks5Addr('localhost')
auth = aiosocks.Socks5Auth('usr', 'pwd')

class TestCreateSocks4Connection(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()

def test_connect_success(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x00\x5a\x04W\x01\x01\x01\x01test')
)
addr = aiosocks.Socks4Addr('127.0.0.1', port)
auth = aiosocks.Socks4Auth('usr')
dst = ('python.org', 80)

transp, proto = mock.Mock(), mock.Mock()
proto.negotiate_done = self._fake_coroutine(aiosocks.SocksError())
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

loop_mock = mock.Mock()
loop_mock.create_connection = self._fake_coroutine((transp, proto))
_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, ('1.1.1.1', 1111))

data = self.loop.run_until_complete(protocol._stream_reader.read(4))
self.assertEqual(data, b'test')

server.close()
transport.close()

def test_invalid_ver(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x01\x5a\x04W\x01\x01\x01\x01')
)
addr = aiosocks.Socks4Addr('127.0.0.1', port)
auth = aiosocks.Socks4Auth('usr')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
conn = aiosocks.create_connection(
None, addr, auth, dst, loop=loop_mock
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid data', str(ct.exception))

server.close()

def test_access_not_granted(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x00\x5b\x04W\x01\x01\x01\x01')
)
addr = aiosocks.Socks4Addr('127.0.0.1', port)
auth = aiosocks.Socks4Auth('usr')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('0x5b', str(ct.exception))

server.close()


class TestCreateSocks5Connect(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)

def tearDown(self):
self.loop.close()

def test_connect_success_anonymous(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(
self.loop,
b'\x05\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest'
)
self.loop.run_until_complete(conn)
self.assertIn('Can not connect to python.org:80',
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, ('1.1.1.1', 1111))

data = self.loop.run_until_complete(protocol._stream_reader.read(4))
self.assertEqual(data, b'test')

server.close()
transport.close()

def test_connect_success_usr_pwd(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(
self.loop,
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04Wtest'
)
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, ('1.1.1.1', 1111))

data = self.loop.run_until_complete(protocol._stream_reader.read(4))
self.assertEqual(data, b'test')

server.close()
transport.close()

def test_auth_ver_err(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x04\x02')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid version', str(ct.exception))

server.close()

def test_auth_method_rejected(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\xFF')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('authentication methods were rejected',
str(ct.exception))

def test_create_protocol(self):
addr = aiosocks.Socks5Addr('localhost')
server.close()

def test_auth_status_invalid(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\xF0')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

transp, proto = mock.Mock(), mock.Mock()
proto.negotiate_done = self._fake_coroutine(True)
with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid data', str(ct.exception))

loop_mock = mock.Mock()
loop_mock.create_connection = self._fake_coroutine((transp, proto))
server.close()

def test_auth_status_invalid2(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\x02\x02\x00')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid data', str(ct.exception))

user_proto = mock.Mock()
server.close()

conn = aiosocks.create_connection(
lambda: user_proto, addr, auth, dst, loop=loop_mock
def test_auth_failed(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\x02\x01\x01')
)
fut = ensure_future(conn, loop=self.loop)
self.loop.run_until_complete(fut)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('authentication failed', str(ct.exception))

server.close()

def test_cmd_ver_err(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x04\x00\x00')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid version', str(ct.exception))

server.close()

def test_cmd_not_granted(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x01\x00')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('General SOCKS server failure', str(ct.exception))

server.close()

def test_invalid_address_type(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(self.loop, b'\x05\x02\x01\x00\x05\x00\x00\xFF')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

with self.assertRaises(aiosocks.SocksError) as ct:
coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)
transport.close()
self.assertIn('invalid data', str(ct.exception))

server.close()

def test_atype_ipv4(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(
self.loop,
b'\x05\x02\x01\x00\x05\x00\x00\x01\x01\x01\x01\x01\x04W'
)
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, ('1.1.1.1', 1111))

transport.close()
server.close()

def test_atype_ipv6(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(
self.loop,
b'\x05\x02\x01\x00\x05\x00\x00\x04\x00\x00\x00\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x11\x04W')
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, ('::111', 1111))

transport.close()
server.close()

def test_atype_domain(self):
server, port = self.loop.run_until_complete(
fake_socks_srv(
self.loop,
b'\x05\x02\x01\x00\x05\x00\x00\x03\x0apython.org\x04W'
)
)
addr = aiosocks.Socks5Addr('127.0.0.1', port)
auth = aiosocks.Socks5Auth('usr', 'pwd')
dst = ('python.org', 80)

coro = aiosocks.create_connection(
None, addr, auth, dst, loop=self.loop)
transport, protocol = self.loop.run_until_complete(coro)

_, addr = protocol._negotiate_fut.result()
self.assertEqual(addr, (b'python.org', 1111))

transport, protocol = fut.result()
self.assertIs(transport, transp)
self.assertIs(protocol, user_proto)
self.assertIs(transport._protocol, user_proto)
transport.close()
server.close()

+ 222
- 21
tests/test_protocol.py View File

@@ -13,13 +13,24 @@ except ImportError:
ensure_future = asyncio.async


def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''):
def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None):
dst = dst or ('python.org', 80)

proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl,
loop=loop, waiter=waiter,
app_protocol_factory=ap_factory)
return proto


def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'',
ap_factory=None, whiter=None):
addr = addr or aiosocks.Socks4Addr('localhost', 1080)
auth = auth or aiosocks.Socks4Auth('user')
dst = dst or ('python.org', 80)

proto = aiosocks.Socks4Protocol(
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop)
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
proto._transport = mock.Mock()
proto.read_response = mock.Mock(
side_effect=coro(mock.Mock(return_value=r)))
@@ -30,13 +41,15 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b''):
return proto


def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None):
def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None,
ap_factory=None, whiter=None):
addr = addr or aiosocks.Socks5Addr('localhost', 1080)
auth = auth or aiosocks.Socks5Auth('user', 'pwd')
dst = dst or ('python.org', 80)

proto = aiosocks.Socks5Protocol(
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr, loop=loop)
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
proto._transport = mock.Mock()

if not isinstance(r, (list, tuple)):
@@ -63,17 +76,19 @@ class TestBaseSocksProtocol(unittest.TestCase):

def test_init(self):
with self.assertRaises(ValueError):
BaseSocksProtocol(None, None, None, loop=self.loop)
BaseSocksProtocol(None, None, None, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
BaseSocksProtocol(None, None, 123, loop=self.loop)
BaseSocksProtocol(None, None, 123, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop)
BaseSocksProtocol(None, None, ('python.org',), loop=self.loop,
waiter=None, app_protocol_factory=None)

def test_write_request(self):
proto = BaseSocksProtocol(None, None, ('python.org', 80),
loop=self.loop)
proto = make_base(self.loop)
proto._transport = mock.Mock()

proto.write_request([b'\x00', b'\x01\x02', 0x03])
@@ -82,6 +97,180 @@ class TestBaseSocksProtocol(unittest.TestCase):
with self.assertRaises(ValueError):
proto.write_request(['\x00'])

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_os_error(self, ef_mock):
os_err_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = os_err_fut

waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, os_err_fut)

with self.assertRaises(OSError):
os_err_fut.set_exception(OSError('test'))
self.loop.run_until_complete(os_err_fut)
self.assertIn('test', str(waiter.exception()))

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_socks_err(self, ef_mock):
socks_err_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = socks_err_fut

waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, socks_err_fut)

with self.assertRaises(aiosocks.SocksError):
socks_err_fut.set_exception(aiosocks.SocksError('test'))
self.loop.run_until_complete(socks_err_fut)
self.assertIn('Can not connect to', str(waiter.exception()))

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_without_app_proto(self, ef_mock):
success_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = success_fut

waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, success_fut)

success_fut.set_result(True)
self.loop.run_until_complete(success_fut)
self.assertTrue(waiter.done())

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_with_app_proto(self, ef_mock):
success_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = success_fut

waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter,
ap_factory=lambda: asyncio.Protocol())
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, success_fut)

success_fut.set_result(True)
self.loop.run_until_complete(success_fut)
self.assertTrue(waiter.done())

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_lost(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

loop_mock = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto.connection_lost(True)
self.assertFalse(loop_mock.call_soon.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
proto.connection_lost(True)
self.assertTrue(loop_mock.call_soon.called)

# negotiate failed
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

negotiate_fut.set_exception(Exception())
proto.connection_lost(True)
self.assertTrue(loop_mock.call_soon.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_pause_writing(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

loop_mock = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto.pause_writing()
self.assertFalse(app_proto.pause_writing.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
proto.pause_writing()
self.assertTrue(app_proto.pause_writing.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_resume_writing(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

loop_mock = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
with self.assertRaises(AssertionError):
proto.resume_writing()

# negotiate fail
negotiate_fut.set_exception(Exception())
proto.resume_writing()
self.assertTrue(app_proto.resume_writing.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_data_received(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

loop_mock = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto.data_received(b'123')
self.assertFalse(app_proto.data_received.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
proto.data_received(b'123')
self.assertTrue(app_proto.data_received.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_eof_received(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

loop_mock = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto.eof_received()
self.assertFalse(app_proto.eof_received.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
proto.eof_received()
self.assertTrue(app_proto.eof_received.called)


class TestSocks4Protocol(unittest.TestCase):
def setUp(self):
@@ -97,21 +286,27 @@ class TestSocks4Protocol(unittest.TestCase):
dst = ('python.org', 80)

with self.assertRaises(ValueError):
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop)
aiosocks.Socks4Protocol(None, None, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop)
aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst,
loop=self.loop)
loop=self.loop, waiter=None,
app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst,
loop=self.loop)
loop=self.loop, waiter=None,
app_protocol_factory=None)

aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop)
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop)
aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

def test_request_building(self):
resp = b'\x00\x5a\x00P\x7f\x00\x00\x01'
@@ -230,21 +425,27 @@ class TestSocks5Protocol(unittest.TestCase):
dst = ('python.org', 80)

with self.assertRaises(ValueError):
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop)
aiosocks.Socks5Protocol(None, None, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop)
aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'),
auth, dst, loop=self.loop)
auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

with self.assertRaises(ValueError):
aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'),
dst, loop=self.loop)
dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop)
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop)
aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

def test_authenticate(self):
# invalid server version


Loading…
Cancel
Save