An attempt at adding UDP support to aiosocks. Untested due to a lack of server support.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

619 lines
20 KiB

  1. import asyncio
  2. import aiosocks
  3. import pytest
  4. import socket
  5. import ssl as ssllib
  6. from unittest import mock
  7. from asyncio import coroutine as coro
  8. from aiohttp.test_utils import make_mocked_coro
  9. import aiosocks.constants as c
  10. from aiosocks.protocols import BaseSocksProtocol
  11. def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None):
  12. dst = dst or ('python.org', 80)
  13. proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl,
  14. loop=loop, waiter=waiter,
  15. app_protocol_factory=ap_factory)
  16. return proto
  17. def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'',
  18. ap_factory=None, whiter=None):
  19. addr = addr or aiosocks.Socks4Addr('localhost', 1080)
  20. auth = auth or aiosocks.Socks4Auth('user')
  21. dst = dst or ('python.org', 80)
  22. proto = aiosocks.Socks4Protocol(
  23. proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
  24. loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
  25. proto._stream_writer = mock.Mock()
  26. proto.read_response = mock.Mock(
  27. side_effect=coro(mock.Mock(return_value=r)))
  28. proto._get_dst_addr = mock.Mock(
  29. side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1')))
  30. )
  31. return proto
  32. def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None,
  33. ap_factory=None, whiter=None):
  34. addr = addr or aiosocks.Socks5Addr('localhost', 1080)
  35. auth = auth or aiosocks.Socks5Auth('user', 'pwd')
  36. dst = dst or ('python.org', 80)
  37. proto = aiosocks.Socks5Protocol(
  38. proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
  39. loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
  40. proto._stream_writer = mock.Mock()
  41. proto._stream_writer.drain = make_mocked_coro(True)
  42. if not isinstance(r, (list, tuple)):
  43. proto.read_response = mock.Mock(
  44. side_effect=coro(mock.Mock(return_value=r)))
  45. else:
  46. proto.read_response = mock.Mock(
  47. side_effect=coro(mock.Mock(side_effect=r)))
  48. proto._get_dst_addr = mock.Mock(
  49. side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1')))
  50. )
  51. return proto
  52. def test_base_ctor(loop):
  53. with pytest.raises(ValueError):
  54. BaseSocksProtocol(None, None, None, loop=loop,
  55. waiter=None, app_protocol_factory=None)
  56. with pytest.raises(ValueError):
  57. BaseSocksProtocol(None, None, 123, loop=loop,
  58. waiter=None, app_protocol_factory=None)
  59. with pytest.raises(ValueError):
  60. BaseSocksProtocol(None, None, ('python.org',), loop=loop,
  61. waiter=None, app_protocol_factory=None)
  62. def test_base_write_request(loop):
  63. proto = make_base(loop)
  64. proto._stream_writer = mock.Mock()
  65. proto.write_request([b'\x00', b'\x01\x02', 0x03])
  66. proto._stream_writer.write.assert_called_with(b'\x00\x01\x02\x03')
  67. with pytest.raises(ValueError):
  68. proto.write_request(['\x00'])
  69. async def test_base_negotiate_os_error(loop):
  70. waiter = asyncio.Future(loop=loop)
  71. proto = make_base(loop, waiter=waiter)
  72. proto.socks_request = make_mocked_coro(raise_exception=OSError('test'))
  73. await proto.negotiate(None, None)
  74. with pytest.raises(OSError) as ct:
  75. await waiter
  76. assert 'test' in str(ct)
  77. async def test_base_negotiate_socks_err(loop):
  78. waiter = asyncio.Future(loop=loop)
  79. proto = make_base(loop, waiter=waiter)
  80. proto.socks_request = make_mocked_coro(
  81. raise_exception=aiosocks.SocksError('test'))
  82. await proto.negotiate(None, None)
  83. with pytest.raises(aiosocks.SocksError) as ct:
  84. await waiter
  85. assert 'Can not connect to' in str(ct)
  86. async def test_base_negotiate_without_app_proto(loop):
  87. waiter = asyncio.Future(loop=loop)
  88. proto = make_base(loop, waiter=waiter)
  89. proto.socks_request = make_mocked_coro((None, None))
  90. proto._transport = True
  91. await proto.negotiate(None, None)
  92. await waiter
  93. assert waiter.done()
  94. async def test_base_negotiate_with_app_proto(loop):
  95. waiter = asyncio.Future(loop=loop)
  96. proto = make_base(loop, waiter=waiter,
  97. ap_factory=lambda: asyncio.Protocol())
  98. proto.socks_request = make_mocked_coro((None, None))
  99. await proto.negotiate(None, None)
  100. await waiter
  101. assert waiter.done()
  102. def test_base_connection_lost():
  103. loop_mock = mock.Mock()
  104. app_proto = mock.Mock()
  105. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  106. # negotiate not completed
  107. proto._negotiate_done = False
  108. proto.connection_lost(True)
  109. assert not loop_mock.call_soon.called
  110. # negotiate successfully competed
  111. loop_mock.reset_mock()
  112. proto._negotiate_done = True
  113. proto.connection_lost(True)
  114. assert loop_mock.call_soon.called
  115. # don't call connect_lost, if app_protocol == self
  116. # otherwise recursion
  117. loop_mock.reset_mock()
  118. proto = make_base(loop_mock, ap_factory=None)
  119. proto._negotiate_done = True
  120. proto.connection_lost(True)
  121. assert not loop_mock.call_soon.called
  122. def test_base_pause_writing():
  123. loop_mock = mock.Mock()
  124. app_proto = mock.Mock()
  125. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  126. # negotiate not completed
  127. proto._negotiate_done = False
  128. proto.pause_writing()
  129. assert not proto._app_protocol.pause_writing.called
  130. # negotiate successfully competed
  131. app_proto.reset_mock()
  132. proto._negotiate_done = True
  133. proto.pause_writing()
  134. assert proto._app_protocol.pause_writing.called
  135. # don't call pause_writing, if app_protocol == self
  136. # otherwise recursion
  137. app_proto.reset_mock()
  138. proto = make_base(loop_mock)
  139. proto._negotiate_done = True
  140. proto.pause_writing()
  141. def test_base_resume_writing():
  142. loop_mock = mock.Mock()
  143. app_proto = mock.Mock()
  144. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  145. # negotiate not completed
  146. proto._negotiate_done = False
  147. # negotiate not completed
  148. with pytest.raises(AssertionError):
  149. proto.resume_writing()
  150. assert not proto._app_protocol.resume_writing.called
  151. # negotiate successfully competed
  152. loop_mock.reset_mock()
  153. proto._negotiate_done = True
  154. proto.resume_writing()
  155. assert proto._app_protocol.resume_writing.called
  156. # don't call resume_writing, if app_protocol == self
  157. # otherwise recursion
  158. loop_mock.reset_mock()
  159. proto = make_base(loop_mock)
  160. proto._negotiate_done = True
  161. with pytest.raises(AssertionError):
  162. proto.resume_writing()
  163. def test_base_data_received():
  164. loop_mock = mock.Mock()
  165. app_proto = mock.Mock()
  166. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  167. # negotiate not completed
  168. proto._negotiate_done = False
  169. proto.data_received(b'123')
  170. assert not proto._app_protocol.data_received.called
  171. # negotiate successfully competed
  172. app_proto.reset_mock()
  173. proto._negotiate_done = True
  174. proto.data_received(b'123')
  175. assert proto._app_protocol.data_received.called
  176. # don't call data_received, if app_protocol == self
  177. # otherwise recursion
  178. loop_mock.reset_mock()
  179. proto = make_base(loop_mock)
  180. proto._negotiate_done = True
  181. proto.data_received(b'123')
  182. def test_base_eof_received():
  183. loop_mock = mock.Mock()
  184. app_proto = mock.Mock()
  185. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  186. # negotiate not completed
  187. proto._negotiate_done = False
  188. proto.eof_received()
  189. assert not proto._app_protocol.eof_received.called
  190. # negotiate successfully competed
  191. app_proto.reset_mock()
  192. proto._negotiate_done = True
  193. proto.eof_received()
  194. assert proto._app_protocol.eof_received.called
  195. # don't call pause_writing, if app_protocol == self
  196. # otherwise recursion
  197. app_proto.reset_mock()
  198. proto = make_base(loop_mock)
  199. proto._negotiate_done = True
  200. proto.eof_received()
  201. async def test_base_make_ssl_proto():
  202. loop_mock = mock.Mock()
  203. app_proto = mock.Mock()
  204. ssl_context = ssllib.create_default_context()
  205. proto = make_base(loop_mock,
  206. ap_factory=lambda: app_proto, ssl=ssl_context)
  207. proto.socks_request = make_mocked_coro((None, None))
  208. proto._transport = mock.Mock()
  209. await proto.negotiate(None, None)
  210. mtr = loop_mock._make_ssl_transport
  211. assert mtr.called
  212. assert mtr.call_args[1]['sslcontext'] is ssl_context
  213. async def test_base_func_negotiate_cb_call():
  214. loop_mock = mock.Mock()
  215. waiter = mock.Mock()
  216. proto = make_base(loop_mock, waiter=waiter)
  217. proto.socks_request = make_mocked_coro((None, None))
  218. proto._negotiate_done_cb = mock.Mock()
  219. with mock.patch('aiosocks.protocols.asyncio.Task') as task_mock:
  220. await proto.negotiate(None, None)
  221. assert proto._negotiate_done_cb.called
  222. assert not task_mock.called
  223. async def test_base_coro_negotiate_cb_call():
  224. loop_mock = mock.Mock()
  225. waiter = mock.Mock()
  226. proto = make_base(loop_mock, waiter=waiter)
  227. proto.socks_request = make_mocked_coro((None, None))
  228. proto._negotiate_done_cb = make_mocked_coro(None)
  229. with mock.patch('aiosocks.protocols.asyncio.Task') as task_mock:
  230. await proto.negotiate(None, None)
  231. assert proto._negotiate_done_cb.called
  232. assert task_mock.called
  233. async def test_base_reader_limit(loop):
  234. proto = BaseSocksProtocol(None, None, ('python.org', 80),
  235. None, None, reader_limit=10, loop=loop)
  236. assert proto.reader._limit == 10
  237. proto = BaseSocksProtocol(None, None, ('python.org', 80),
  238. None, None, reader_limit=15, loop=loop)
  239. assert proto.reader._limit == 15
  240. async def test_base_incomplete_error(loop):
  241. proto = BaseSocksProtocol(None, None, ('python.org', 80),
  242. None, None, reader_limit=10, loop=loop)
  243. proto._stream_reader.readexactly = make_mocked_coro(
  244. raise_exception=asyncio.IncompleteReadError(b'part', 5))
  245. with pytest.raises(aiosocks.InvalidServerReply):
  246. await proto.read_response(4)
  247. def test_socks4_ctor(loop):
  248. addr = aiosocks.Socks4Addr('localhost', 1080)
  249. auth = aiosocks.Socks4Auth('user')
  250. dst = ('python.org', 80)
  251. with pytest.raises(ValueError):
  252. aiosocks.Socks4Protocol(None, None, dst, loop=loop,
  253. waiter=None, app_protocol_factory=None)
  254. with pytest.raises(ValueError):
  255. aiosocks.Socks4Protocol(None, auth, dst, loop=loop,
  256. waiter=None, app_protocol_factory=None)
  257. with pytest.raises(ValueError):
  258. aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst,
  259. loop=loop, waiter=None,
  260. app_protocol_factory=None)
  261. with pytest.raises(ValueError):
  262. aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst,
  263. loop=loop, waiter=None,
  264. app_protocol_factory=None)
  265. aiosocks.Socks4Protocol(addr, None, dst, loop=loop,
  266. waiter=None, app_protocol_factory=None)
  267. aiosocks.Socks4Protocol(addr, auth, dst, loop=loop,
  268. waiter=None, app_protocol_factory=None)
  269. async def test_socks4_dst_domain_with_remote_resolve(loop):
  270. proto = make_socks4(loop, dst=('python.org', 80),
  271. r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  272. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  273. proto._stream_writer.write.assert_called_with(
  274. b'\x04\x01\x00P\x00\x00\x00\x01user\x00python.org\x00')
  275. async def test_socks4_dst_domain_with_local_resolve(loop):
  276. proto = make_socks4(loop, dst=('python.org', 80),
  277. rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  278. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  279. proto._stream_writer.write.assert_called_with(
  280. b'\x04\x01\x00P\x7f\x00\x00\x01user\x00')
  281. async def test_socks4_dst_ip_with_remote_resolve(loop):
  282. proto = make_socks4(loop, dst=('127.0.0.1', 8800),
  283. r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  284. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  285. proto._stream_writer.write.assert_called_with(
  286. b'\x04\x01"`\x7f\x00\x00\x01user\x00')
  287. async def test_socks4_dst_ip_with_locale_resolve(loop):
  288. proto = make_socks4(loop, dst=('127.0.0.1', 8800),
  289. rr=False, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  290. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  291. proto._stream_writer.write.assert_called_with(
  292. b'\x04\x01"`\x7f\x00\x00\x01user\x00')
  293. async def test_socks4_dst_domain_without_user(loop):
  294. proto = make_socks4(loop, auth=aiosocks.Socks4Auth(''),
  295. dst=('python.org', 80),
  296. r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  297. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  298. proto._stream_writer.write.assert_called_with(
  299. b'\x04\x01\x00P\x00\x00\x00\x01\x00python.org\x00')
  300. async def test_socks4_dst_ip_without_user(loop):
  301. proto = make_socks4(loop, auth=aiosocks.Socks4Auth(''),
  302. dst=('127.0.0.1', 8800),
  303. r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  304. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  305. proto._stream_writer.write.assert_called_with(
  306. b'\x04\x01"`\x7f\x00\x00\x01\x00')
  307. async def test_socks4_valid_resp_handling(loop):
  308. proto = make_socks4(loop, r=b'\x00\x5a\x00P\x7f\x00\x00\x01')
  309. r = await proto.socks_request(c.SOCKS_CMD_CONNECT)
  310. assert r == (('python.org', 80), ('127.0.0.1', 80))
  311. async def test_socks4_invalid_reply_resp_handling(loop):
  312. proto = make_socks4(loop, r=b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF')
  313. with pytest.raises(aiosocks.InvalidServerReply):
  314. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  315. async def test_socks_err_resp_handling(loop):
  316. proto = make_socks4(loop, r=b'\x00\x5b\x00P\x7f\x00\x00\x01')
  317. with pytest.raises(aiosocks.SocksError) as cm:
  318. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  319. assert '0x5b' in str(cm)
  320. async def test_socks4_unknown_err_resp_handling(loop):
  321. proto = make_socks4(loop, r=b'\x00\x5e\x00P\x7f\x00\x00\x01')
  322. with pytest.raises(aiosocks.SocksError) as cm:
  323. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  324. assert 'Unknown error' in str(cm)
  325. def test_socks5_ctor(loop):
  326. addr = aiosocks.Socks5Addr('localhost', 1080)
  327. auth = aiosocks.Socks5Auth('user', 'pwd')
  328. dst = ('python.org', 80)
  329. with pytest.raises(ValueError):
  330. aiosocks.Socks5Protocol(None, None, dst, loop=loop,
  331. waiter=None, app_protocol_factory=None)
  332. with pytest.raises(ValueError):
  333. aiosocks.Socks5Protocol(None, auth, dst, loop=loop,
  334. waiter=None, app_protocol_factory=None)
  335. with pytest.raises(ValueError):
  336. aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'),
  337. auth, dst, loop=loop,
  338. waiter=None, app_protocol_factory=None)
  339. with pytest.raises(ValueError):
  340. aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'),
  341. dst, loop=loop,
  342. waiter=None, app_protocol_factory=None)
  343. aiosocks.Socks5Protocol(addr, None, dst, loop=loop,
  344. waiter=None, app_protocol_factory=None)
  345. aiosocks.Socks5Protocol(addr, auth, dst, loop=loop,
  346. waiter=None, app_protocol_factory=None)
  347. async def test_socks5_auth_inv_srv_ver(loop):
  348. proto = make_socks5(loop, r=b'\x00\x00')
  349. with pytest.raises(aiosocks.InvalidServerVersion):
  350. await proto.authenticate()
  351. async def test_socks5_auth_no_acceptable_auth_methods(loop):
  352. proto = make_socks5(loop, r=b'\x05\xFF')
  353. with pytest.raises(aiosocks.NoAcceptableAuthMethods):
  354. await proto.authenticate()
  355. async def test_socks5_auth_unsupported_auth_method(loop):
  356. proto = make_socks5(loop, r=b'\x05\xF0')
  357. with pytest.raises(aiosocks.InvalidServerReply):
  358. await proto.authenticate()
  359. async def test_socks5_auth_usr_pwd_granted(loop):
  360. proto = make_socks5(loop, r=(b'\x05\x02', b'\x01\x00',))
  361. await proto.authenticate()
  362. proto._stream_writer.write.assert_has_calls([
  363. mock.call(b'\x05\x02\x00\x02'),
  364. mock.call(b'\x01\x04user\x03pwd')
  365. ])
  366. async def test_socks5_auth_invalid_reply(loop):
  367. proto = make_socks5(loop, r=(b'\x05\x02', b'\x00\x00',))
  368. with pytest.raises(aiosocks.InvalidServerReply):
  369. await proto.authenticate()
  370. async def test_socks5_auth_access_denied(loop):
  371. proto = make_socks5(loop, r=(b'\x05\x02', b'\x01\x01',))
  372. with pytest.raises(aiosocks.LoginAuthenticationFailed):
  373. await proto.authenticate()
  374. async def test_socks5_auth_anonymous_granted(loop):
  375. proto = make_socks5(loop, r=b'\x05\x00')
  376. await proto.authenticate()
  377. async def test_socks5_build_dst_addr_ipv4(loop):
  378. proto = make_socks5(loop)
  379. dst_req, resolved = await proto.build_dst_address('127.0.0.1', 80)
  380. assert dst_req == [0x01, b'\x7f\x00\x00\x01', b'\x00P']
  381. assert resolved == ('127.0.0.1', 80)
  382. async def test_socks5_build_dst_addr_ipv6(loop):
  383. proto = make_socks5(loop)
  384. dst_req, resolved = await proto.build_dst_address(
  385. '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)
  386. assert dst_req == [
  387. 0x04, b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', b'\x00P']
  388. assert resolved == ('2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)
  389. async def test_socks5_build_dst_addr_domain_with_remote_resolve(loop):
  390. proto = make_socks5(loop)
  391. dst_req, resolved = await proto.build_dst_address('python.org', 80)
  392. assert dst_req == [0x03, b'\n', b'python.org', b'\x00P']
  393. assert resolved == ('python.org', 80)
  394. async def test_socks5_build_dst_addr_domain_with_locale_resolve(loop):
  395. proto = make_socks5(loop, rr=False)
  396. dst_req, resolved = await proto.build_dst_address('python.org', 80)
  397. assert dst_req == [0x01, b'\x7f\x00\x00\x01', b'\x00P']
  398. assert resolved == ('127.0.0.1', 80)
  399. async def test_socks5_rd_addr_ipv4(loop):
  400. proto = make_socks5(loop, r=[b'\x01', b'\x7f\x00\x00\x01', b'\x00P'])
  401. r = await proto.read_address()
  402. assert r == ('127.0.0.1', 80)
  403. async def test_socks5_rd_addr_ipv6(loop):
  404. resp = [
  405. b'\x04',
  406. b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]',
  407. b'\x00P'
  408. ]
  409. proto = make_socks5(loop, r=resp)
  410. r = await proto.read_address()
  411. assert r == ('2001:db8:11a3:9d7:1f34:8a2e:7a0:765d', 80)
  412. async def test_socks5_rd_addr_domain(loop):
  413. proto = make_socks5(loop, r=[b'\x03', b'\n', b'python.org', b'\x00P'])
  414. r = await proto.read_address()
  415. assert r == (b'python.org', 80)
  416. async def test_socks5_socks_req_inv_ver(loop):
  417. proto = make_socks5(loop, r=[b'\x05\x00', b'\x04\x00\x00'])
  418. with pytest.raises(aiosocks.InvalidServerVersion):
  419. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  420. async def test_socks5_socks_req_socks_srv_err(loop):
  421. proto = make_socks5(loop, r=[b'\x05\x00', b'\x05\x02\x00'])
  422. with pytest.raises(aiosocks.SocksError) as ct:
  423. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  424. assert 'Connection not allowed by ruleset' in str(ct)
  425. async def test_socks5_socks_req_unknown_err(loop):
  426. proto = make_socks5(loop, r=[b'\x05\x00', b'\x05\xFF\x00'])
  427. with pytest.raises(aiosocks.SocksError) as ct:
  428. await proto.socks_request(c.SOCKS_CMD_CONNECT)
  429. assert 'Unknown error' in str(ct)
  430. async def test_socks_req_cmd_granted(loop):
  431. # cmd granted
  432. resp = [b'\x05\x00',
  433. b'\x05\x00\x00',
  434. b'\x01', b'\x7f\x00\x00\x01',
  435. b'\x00P']
  436. proto = make_socks5(loop, r=resp)
  437. r = await proto.socks_request(c.SOCKS_CMD_CONNECT)
  438. assert r == (('python.org', 80), ('127.0.0.1', 80))
  439. proto._stream_writer.write.assert_has_calls([
  440. mock.call(b'\x05\x02\x00\x02'),
  441. mock.call(b'\x05\x01\x00\x03\npython.org\x00P')
  442. ])