Browse Source

Use helpers.fake_coroutine. Added test cases for a new way to negitiate with socks server

main
nibrag 8 years ago
parent
commit
9c0894c045
1 changed files with 174 additions and 169 deletions
  1. +174
    -169
      tests/test_protocols.py

tests/test_protocol.py → tests/test_protocols.py View File

@@ -6,6 +6,7 @@ from unittest import mock
from asyncio import coroutine as coro
import aiosocks.constants as c
from aiosocks.protocols import BaseSocksProtocol
from .helpers import fake_coroutine

try:
from asyncio import ensure_future
@@ -31,7 +32,7 @@ def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'',
proto = aiosocks.Socks4Protocol(
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
proto._transport = mock.Mock()
proto._stream_writer = mock.Mock()
proto.read_response = mock.Mock(
side_effect=coro(mock.Mock(return_value=r)))
proto._get_dst_addr = mock.Mock(
@@ -50,7 +51,7 @@ def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None,
proto = aiosocks.Socks5Protocol(
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
proto._transport = mock.Mock()
proto._stream_writer = mock.Mock()

if not isinstance(r, (list, tuple)):
proto.read_response = mock.Mock(
@@ -89,187 +90,197 @@ class TestBaseSocksProtocol(unittest.TestCase):

def test_write_request(self):
proto = make_base(self.loop)
proto._transport = mock.Mock()
proto._stream_writer = mock.Mock()

proto.write_request([b'\x00', b'\x01\x02', 0x03])
proto._transport.write.assert_called_with(b'\x00\x01\x02\x03')
proto._stream_writer.write.assert_called_with(b'\x00\x01\x02\x03')

with self.assertRaises(ValueError):
proto.write_request(['\x00'])

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_os_error(self, ef_mock):
os_err_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = os_err_fut

def test_negotiate_os_error(self):
waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, os_err_fut)
proto.socks_request = fake_coroutine(OSError('test'))

with self.assertRaises(OSError):
os_err_fut.set_exception(OSError('test'))
self.loop.run_until_complete(os_err_fut)
self.loop.run_until_complete(proto.negotiate(None, None))
self.assertIn('test', str(waiter.exception()))

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_socks_err(self, ef_mock):
socks_err_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = socks_err_fut

def test_negotiate_socks_err(self):
waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, socks_err_fut)
proto.socks_request = fake_coroutine(aiosocks.SocksError('test'))

with self.assertRaises(aiosocks.SocksError):
socks_err_fut.set_exception(aiosocks.SocksError('test'))
self.loop.run_until_complete(socks_err_fut)
self.loop.run_until_complete(proto.negotiate(None, None))
self.assertIn('Can not connect to', str(waiter.exception()))

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_without_app_proto(self, ef_mock):
success_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = success_fut

def test_negotiate_without_app_proto(self):
waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter)
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, success_fut)
proto.socks_request = fake_coroutine((None, None))
proto._transport = True

success_fut.set_result(True)
self.loop.run_until_complete(success_fut)
self.loop.run_until_complete(proto.negotiate(None, None))
self.assertTrue(waiter.done())

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_made_with_app_proto(self, ef_mock):
success_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = success_fut

def test_negotiate_with_app_proto(self):
waiter = asyncio.Future(loop=self.loop)
proto = make_base(self.loop, waiter=waiter,
ap_factory=lambda: asyncio.Protocol())
proto.connection_made(mock.Mock())

self.assertIs(proto._negotiate_fut, success_fut)
proto.socks_request = fake_coroutine((None, None))

success_fut.set_result(True)
self.loop.run_until_complete(success_fut)
self.loop.run_until_complete(proto.negotiate(None, None))
self.assertTrue(waiter.done())

@mock.patch('aiosocks.protocols.ensure_future')
def test_connection_lost(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()

def test_connection_lost(self):
loop_mock = mock.Mock()
app_proto = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto._negotiate_done = False
proto.connection_lost(True)
self.assertFalse(loop_mock.call_soon.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
loop_mock.reset_mock()
proto._negotiate_done = True
proto.connection_lost(True)
self.assertTrue(loop_mock.call_soon.called)

# negotiate failed
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

negotiate_fut.set_exception(Exception())
# don't call connect_lost, if app_protocol == self
# otherwise recursion
loop_mock.reset_mock()
proto = make_base(loop_mock, ap_factory=None)
proto._negotiate_done = True
proto.connection_lost(True)
self.assertTrue(loop_mock.call_soon.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_pause_writing(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()
self.assertFalse(loop_mock.call_soon.called)

def test_pause_writing(self):
loop_mock = mock.Mock()
app_proto = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto._negotiate_done = False
proto.pause_writing()
self.assertFalse(app_proto.pause_writing.called)
self.assertFalse(proto._app_protocol.pause_writing.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
app_proto.reset_mock()
proto._negotiate_done = True
proto.pause_writing()
self.assertTrue(app_proto.pause_writing.called)
self.assertTrue(proto._app_protocol.pause_writing.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_resume_writing(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()
# don't call pause_writing, if app_protocol == self
# otherwise recursion
app_proto.reset_mock()
proto = make_base(loop_mock)
proto._negotiate_done = True
proto.pause_writing()

def test_resume_writing(self):
loop_mock = mock.Mock()
app_proto = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto._negotiate_done = False
# negotiate not completed
with self.assertRaises(AssertionError):
proto.resume_writing()
self.assertFalse(proto._app_protocol.resume_writing.called)

# negotiate fail
negotiate_fut.set_exception(Exception())
# negotiate successfully competed
loop_mock.reset_mock()
proto._negotiate_done = True
proto.resume_writing()
self.assertTrue(app_proto.resume_writing.called)
self.assertTrue(proto._app_protocol.resume_writing.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_data_received(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()
# don't call resume_writing, if app_protocol == self
# otherwise recursion
loop_mock.reset_mock()
proto = make_base(loop_mock)
proto._negotiate_done = True
with self.assertRaises(AssertionError):
proto.resume_writing()

def test_data_received(self):
loop_mock = mock.Mock()
app_proto = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto._negotiate_done = False
proto.data_received(b'123')
self.assertFalse(app_proto.data_received.called)
self.assertFalse(proto._app_protocol.data_received.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
app_proto.reset_mock()
proto._negotiate_done = True
proto.data_received(b'123')
self.assertTrue(app_proto.data_received.called)
self.assertTrue(proto._app_protocol.data_received.called)

@mock.patch('aiosocks.protocols.ensure_future')
def test_eof_received(self, ef_mock):
negotiate_fut = asyncio.Future(loop=self.loop)
ef_mock.return_value = negotiate_fut
app_proto = mock.Mock()
# don't call data_received, if app_protocol == self
# otherwise recursion
loop_mock.reset_mock()
proto = make_base(loop_mock)
proto._negotiate_done = True
proto.data_received(b'123')

def test_eof_received(self):
loop_mock = mock.Mock()
app_proto = mock.Mock()

proto = make_base(loop_mock, ap_factory=lambda: app_proto)
proto.connection_made(mock.Mock())

# negotiate not completed
proto._negotiate_done = False
proto.eof_received()
self.assertFalse(app_proto.eof_received.called)
self.assertFalse(proto._app_protocol.eof_received.called)

# negotiate successfully competed
negotiate_fut.set_result(True)
app_proto.reset_mock()
proto._negotiate_done = True
proto.eof_received()
self.assertTrue(proto._app_protocol.eof_received.called)

# don't call pause_writing, if app_protocol == self
# otherwise recursion
app_proto.reset_mock()
proto = make_base(loop_mock)
proto._negotiate_done = True
proto.eof_received()
self.assertTrue(app_proto.eof_received.called)

@mock.patch('aiosocks.protocols.asyncio.Task')
def test_func_negotiate_cb_call(self, task_mock):
loop_mock = mock.Mock()
waiter = mock.Mock()

proto = make_base(loop_mock, waiter=waiter)
proto.socks_request = fake_coroutine((None, None))
proto._negotiate_done_cb = mock.Mock()

self.loop.run_until_complete(proto.negotiate(None, None))
self.assertTrue(proto._negotiate_done_cb.called)
self.assertFalse(task_mock.called)

@mock.patch('aiosocks.protocols.asyncio.Task')
def test_coro_negotiate_cb_call(self, task_mock):
loop_mock = mock.Mock()
waiter = mock.Mock()

proto = make_base(loop_mock, waiter=waiter)
proto.socks_request = fake_coroutine((None, None))
proto._negotiate_done_cb = fake_coroutine(None)

self.loop.run_until_complete(proto.negotiate(None, None))
self.assertTrue(proto._negotiate_done_cb.called)
self.assertTrue(task_mock.called)


class TestSocks4Protocol(unittest.TestCase):
@@ -308,92 +319,89 @@ class TestSocks4Protocol(unittest.TestCase):
aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

def test_request_building(self):
resp = b'\x00\x5a\x00P\x7f\x00\x00\x01'

# dst = domain, remote resolve = true
proto = make_socks4(self.loop, dst=('python.org', 80), r=resp)
def test_dst_domain_with_remote_resolve(self):
proto = make_socks4(self.loop, dst=('python.org', 80),
r=b'\x00\x5a\x00P\x7f\x00\x00\x01')

req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01\x00P\x00\x00\x00\x01user\x00python.org\x00'
)

# dst = domain, remote resolve = false
def test_dst_domain_with_local_resolve(self):
proto = make_socks4(self.loop, dst=('python.org', 80),
rr=False, r=resp)
rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')

req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01\x00P\x7f\x00\x00\x01user\x00'
)

# dst = ip, remote resolve = true
proto = make_socks4(self.loop, dst=('127.0.0.1', 8800), r=resp)
def test_dst_ip_with_remote_resolve(self):
proto = make_socks4(self.loop, dst=('127.0.0.1', 8800),
r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01"`\x7f\x00\x00\x01user\x00'
)

# dst = ip, remote resolve = false
def test_dst_ip_with_locale_resolve(self):
proto = make_socks4(self.loop, dst=('127.0.0.1', 8800),
rr=False, r=resp)
rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')

req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01"`\x7f\x00\x00\x01user\x00'
)

# dst = domain, without user
def test_dst_domain_without_user(self):
proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''),
dst=('python.org', 80), r=resp)
dst=('python.org', 80),
r=b'\x00\x5a\x00P\x7f\x00\x00\x01')

req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01\x00P\x00\x00\x00\x01\x00python.org\x00'
)

# dst = ip, without user
def test_dst_ip_without_user(self):
proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''),
dst=('127.0.0.1', 8800), r=resp)
dst=('127.0.0.1', 8800),
r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
req = proto.socks_request(c.SOCKS_CMD_CONNECT)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04\x01"`\x7f\x00\x00\x01\x00'
)

def test_response_handling(self):
valid_resp = b'\x00\x5a\x00P\x7f\x00\x00\x01'
invalid_data_resp = b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'
socks_err_resp = b'\x00\x5b\x00P\x7f\x00\x00\x01'
socks_err_unk_resp = b'\x00\x5e\x00P\x7f\x00\x00\x01'

# valid result
proto = make_socks4(self.loop, r=valid_resp)
def test_valid_resp_handling(self):
proto = make_socks4(self.loop, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
req = ensure_future(
proto.socks_request(c.SOCKS_CMD_CONNECT), loop=self.loop)
self.loop.run_until_complete(req)

self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80)))

# invalid server reply
proto = make_socks4(self.loop, r=invalid_data_resp)
def test_invalid_reply_resp_handling(self):
proto = make_socks4(self.loop, r=b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF')
req = proto.socks_request(c.SOCKS_CMD_CONNECT)

with self.assertRaises(aiosocks.InvalidServerReply):
self.loop.run_until_complete(req)

# socks server sent error
proto = make_socks4(self.loop, r=socks_err_resp)
def test_socks_err_resp_handling(self):
proto = make_socks4(self.loop, r=b'\x00\x5b\x00P\x7f\x00\x00\x01')
req = proto.socks_request(c.SOCKS_CMD_CONNECT)

with self.assertRaises(aiosocks.SocksError) as cm:
@@ -401,8 +409,8 @@ class TestSocks4Protocol(unittest.TestCase):

self.assertTrue('0x5b' in str(cm.exception))

# socks server send unknown error
proto = make_socks4(self.loop, r=socks_err_unk_resp)
def test_unknown_err_resp_handling(self):
proto = make_socks4(self.loop, r=b'\x00\x5e\x00P\x7f\x00\x00\x01')
req = proto.socks_request(c.SOCKS_CMD_CONNECT)

with self.assertRaises(aiosocks.SocksError) as cm:
@@ -447,86 +455,83 @@ class TestSocks5Protocol(unittest.TestCase):
aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop,
waiter=None, app_protocol_factory=None)

def test_authenticate(self):
# invalid server version
def test_auth_inv_srv_ver(self):
proto = make_socks5(self.loop, r=b'\x00\x00')
req = proto.authenticate()

with self.assertRaises(aiosocks.InvalidServerVersion):
self.loop.run_until_complete(req)

# anonymous auth granted
proto = make_socks5(self.loop, r=b'\x05\x00')
req = proto.authenticate()
self.loop.run_until_complete(req)

# no acceptable auth methods
def test_auth_no_acceptable_auth_methods(self):
proto = make_socks5(self.loop, r=b'\x05\xFF')
req = proto.authenticate()
with self.assertRaises(aiosocks.NoAcceptableAuthMethods):
self.loop.run_until_complete(req)

# unsupported auth method
def test_auth_unsupported_auth_method(self):
proto = make_socks5(self.loop, r=b'\x05\xF0')
req = proto.authenticate()
with self.assertRaises(aiosocks.InvalidServerReply):
self.loop.run_until_complete(req)

# auth: username, pwd
# access granted
def test_auth_usr_pwd_granted(self):
proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x00',))
req = proto.authenticate()
self.loop.run_until_complete(req)
proto._transport.write.assert_has_calls([
self.loop.run_until_complete(proto.authenticate())
proto._stream_writer.write.assert_has_calls([
mock.call(b'\x05\x02\x00\x02'),
mock.call(b'\x01\x04user\x03pwd')
])

# invalid reply
def test_auth_invalid_reply(self):
proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x00\x00',))
req = proto.authenticate()
with self.assertRaises(aiosocks.InvalidServerReply):
self.loop.run_until_complete(req)

# access denied
def test_auth_access_denied(self):
proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x01',))
req = proto.authenticate()
with self.assertRaises(aiosocks.LoginAuthenticationFailed):
self.loop.run_until_complete(req)

def test_write_address(self):
# ipv4
def test_auth_anonymous_granted(self):
proto = make_socks5(self.loop, r=b'\x05\x00')
req = proto.authenticate()
self.loop.run_until_complete(req)

def test_wr_addr_ipv4(self):
proto = make_socks5(self.loop)
req = proto.write_address('127.0.0.1', 80)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(b'\x01\x7f\x00\x00\x01\x00P')
proto._stream_writer.write.assert_called_with(
b'\x01\x7f\x00\x00\x01\x00P')

# ipv6
def test_wr_addr_ipv6(self):
proto = make_socks5(self.loop)
req = proto.write_address(
'2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(
proto._stream_writer.write.assert_called_with(
b'\x04 \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]\x00P')

# domain, remote_resolve = true
def test_wr_addr_domain_with_remote_resolve(self):
proto = make_socks5(self.loop)
req = proto.write_address('python.org', 80)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(b'\x03\npython.org\x00P')
proto._stream_writer.write.assert_called_with(b'\x03\npython.org\x00P')

# domain, remote resolve = false
def test_wr_addr_domain_with_locale_resolve(self):
proto = make_socks5(self.loop, rr=False)
req = proto.write_address('python.org', 80)
self.loop.run_until_complete(req)

proto._transport.write.assert_called_with(b'\x01\x7f\x00\x00\x01\x00P')
proto._stream_writer.write.assert_called_with(
b'\x01\x7f\x00\x00\x01\x00P')

def test_read_address(self):
# ipv4
def test_rd_addr_ipv4(self):
proto = make_socks5(
self.loop, r=[b'\x01', b'\x7f\x00\x00\x01', b'\x00P'])
req = ensure_future(proto.read_address(), loop=self.loop)
@@ -534,7 +539,7 @@ class TestSocks5Protocol(unittest.TestCase):

self.assertEqual(req.result(), ('127.0.0.1', 80))

# ipv6
def test_rd_addr_ipv6(self):
resp = [
b'\x04',
b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]',
@@ -547,7 +552,7 @@ class TestSocks5Protocol(unittest.TestCase):
self.assertEqual(
req.result(), ('2001:db8:11a3:9d7:1f34:8a2e:7a0:765d', 80))

# domain
def test_rd_addr_domain(self):
proto = make_socks5(
self.loop, r=[b'\x03', b'\n', b'python.org', b'\x00P'])
req = ensure_future(proto.read_address(), loop=self.loop)
@@ -555,14 +560,13 @@ class TestSocks5Protocol(unittest.TestCase):

self.assertEqual(req.result(), (b'python.org', 80))

def test_socks_request(self):
# invalid version
def test_socks_req_inv_ver(self):
proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x04\x00\x00'])
req = proto.socks_request(c.SOCKS_CMD_CONNECT)
with self.assertRaises(aiosocks.InvalidServerVersion):
self.loop.run_until_complete(req)

# socks error
def test_socks_req_socks_srv_err(self):
proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\x02\x00'])
req = proto.socks_request(c.SOCKS_CMD_CONNECT)
with self.assertRaises(aiosocks.SocksError) as ct:
@@ -571,7 +575,7 @@ class TestSocks5Protocol(unittest.TestCase):
self.assertTrue(
'Connection not allowed by ruleset' in str(ct.exception))

# socks unknown error
def test_socks_req_unknown_err(self):
proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\xFF\x00'])
req = proto.socks_request(c.SOCKS_CMD_CONNECT)
with self.assertRaises(aiosocks.SocksError) as ct:
@@ -579,6 +583,7 @@ class TestSocks5Protocol(unittest.TestCase):

self.assertTrue('Unknown error' in str(ct.exception))

def test_socks_req_cmd_granted(self):
# cmd granted
resp = [b'\x05\x00',
b'\x05\x00\x00',
@@ -590,7 +595,7 @@ class TestSocks5Protocol(unittest.TestCase):
self.loop.run_until_complete(req)

self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80)))
proto._transport.write.assert_has_calls([
proto._stream_writer.write.assert_has_calls([
mock.call(b'\x05\x02\x00\x02'),
mock.call(b'\x05\x01\x00'),
mock.call(b'\x03\npython.org\x00P')

Loading…
Cancel
Save