An attempt at adding UDP support to aiosocks. It is untested due to a lack of server support.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

598 lines
21 KiB

  1. import asyncio
  2. import aiosocks
  3. import unittest
  4. import socket
  5. from unittest import mock
  6. from asyncio import coroutine as coro
  7. import aiosocks.constants as c
  8. from aiosocks.protocols import BaseSocksProtocol
  9. try:
  10. from asyncio import ensure_future
  11. except ImportError:
  12. ensure_future = asyncio.async
  13. def make_base(loop, *, dst=None, waiter=None, ap_factory=None, ssl=None):
  14. dst = dst or ('python.org', 80)
  15. proto = BaseSocksProtocol(None, None, dst=dst, ssl=ssl,
  16. loop=loop, waiter=waiter,
  17. app_protocol_factory=ap_factory)
  18. return proto
  19. def make_socks4(loop, *, addr=None, auth=None, rr=True, dst=None, r=b'',
  20. ap_factory=None, whiter=None):
  21. addr = addr or aiosocks.Socks4Addr('localhost', 1080)
  22. auth = auth or aiosocks.Socks4Auth('user')
  23. dst = dst or ('python.org', 80)
  24. proto = aiosocks.Socks4Protocol(
  25. proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
  26. loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
  27. proto._transport = mock.Mock()
  28. proto.read_response = mock.Mock(
  29. side_effect=coro(mock.Mock(return_value=r)))
  30. proto._get_dst_addr = mock.Mock(
  31. side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1')))
  32. )
  33. return proto
  34. def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None,
  35. ap_factory=None, whiter=None):
  36. addr = addr or aiosocks.Socks5Addr('localhost', 1080)
  37. auth = auth or aiosocks.Socks5Auth('user', 'pwd')
  38. dst = dst or ('python.org', 80)
  39. proto = aiosocks.Socks5Protocol(
  40. proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
  41. loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
  42. proto._transport = mock.Mock()
  43. if not isinstance(r, (list, tuple)):
  44. proto.read_response = mock.Mock(
  45. side_effect=coro(mock.Mock(return_value=r)))
  46. else:
  47. proto.read_response = mock.Mock(
  48. side_effect=coro(mock.Mock(side_effect=r)))
  49. proto._get_dst_addr = mock.Mock(
  50. side_effect=coro(mock.Mock(return_value=(socket.AF_INET, '127.0.0.1')))
  51. )
  52. return proto
  53. class TestBaseSocksProtocol(unittest.TestCase):
  54. def setUp(self):
  55. self.loop = asyncio.new_event_loop()
  56. asyncio.set_event_loop(None)
  57. def tearDown(self):
  58. self.loop.close()
  59. def test_init(self):
  60. with self.assertRaises(ValueError):
  61. BaseSocksProtocol(None, None, None, loop=self.loop,
  62. waiter=None, app_protocol_factory=None)
  63. with self.assertRaises(ValueError):
  64. BaseSocksProtocol(None, None, 123, loop=self.loop,
  65. waiter=None, app_protocol_factory=None)
  66. with self.assertRaises(ValueError):
  67. BaseSocksProtocol(None, None, ('python.org',), loop=self.loop,
  68. waiter=None, app_protocol_factory=None)
  69. def test_write_request(self):
  70. proto = make_base(self.loop)
  71. proto._transport = mock.Mock()
  72. proto.write_request([b'\x00', b'\x01\x02', 0x03])
  73. proto._transport.write.assert_called_with(b'\x00\x01\x02\x03')
  74. with self.assertRaises(ValueError):
  75. proto.write_request(['\x00'])
  76. @mock.patch('aiosocks.protocols.ensure_future')
  77. def test_connection_made_os_error(self, ef_mock):
  78. os_err_fut = asyncio.Future(loop=self.loop)
  79. ef_mock.return_value = os_err_fut
  80. waiter = asyncio.Future(loop=self.loop)
  81. proto = make_base(self.loop, waiter=waiter)
  82. proto.connection_made(mock.Mock())
  83. self.assertIs(proto._negotiate_fut, os_err_fut)
  84. with self.assertRaises(OSError):
  85. os_err_fut.set_exception(OSError('test'))
  86. self.loop.run_until_complete(os_err_fut)
  87. self.assertIn('test', str(waiter.exception()))
  88. @mock.patch('aiosocks.protocols.ensure_future')
  89. def test_connection_made_socks_err(self, ef_mock):
  90. socks_err_fut = asyncio.Future(loop=self.loop)
  91. ef_mock.return_value = socks_err_fut
  92. waiter = asyncio.Future(loop=self.loop)
  93. proto = make_base(self.loop, waiter=waiter)
  94. proto.connection_made(mock.Mock())
  95. self.assertIs(proto._negotiate_fut, socks_err_fut)
  96. with self.assertRaises(aiosocks.SocksError):
  97. socks_err_fut.set_exception(aiosocks.SocksError('test'))
  98. self.loop.run_until_complete(socks_err_fut)
  99. self.assertIn('Can not connect to', str(waiter.exception()))
  100. @mock.patch('aiosocks.protocols.ensure_future')
  101. def test_connection_made_without_app_proto(self, ef_mock):
  102. success_fut = asyncio.Future(loop=self.loop)
  103. ef_mock.return_value = success_fut
  104. waiter = asyncio.Future(loop=self.loop)
  105. proto = make_base(self.loop, waiter=waiter)
  106. proto.connection_made(mock.Mock())
  107. self.assertIs(proto._negotiate_fut, success_fut)
  108. success_fut.set_result(True)
  109. self.loop.run_until_complete(success_fut)
  110. self.assertTrue(waiter.done())
  111. @mock.patch('aiosocks.protocols.ensure_future')
  112. def test_connection_made_with_app_proto(self, ef_mock):
  113. success_fut = asyncio.Future(loop=self.loop)
  114. ef_mock.return_value = success_fut
  115. waiter = asyncio.Future(loop=self.loop)
  116. proto = make_base(self.loop, waiter=waiter,
  117. ap_factory=lambda: asyncio.Protocol())
  118. proto.connection_made(mock.Mock())
  119. self.assertIs(proto._negotiate_fut, success_fut)
  120. success_fut.set_result(True)
  121. self.loop.run_until_complete(success_fut)
  122. self.assertTrue(waiter.done())
  123. @mock.patch('aiosocks.protocols.ensure_future')
  124. def test_connection_lost(self, ef_mock):
  125. negotiate_fut = asyncio.Future(loop=self.loop)
  126. ef_mock.return_value = negotiate_fut
  127. app_proto = mock.Mock()
  128. loop_mock = mock.Mock()
  129. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  130. proto.connection_made(mock.Mock())
  131. # negotiate not completed
  132. proto.connection_lost(True)
  133. self.assertFalse(loop_mock.call_soon.called)
  134. # negotiate successfully competed
  135. negotiate_fut.set_result(True)
  136. proto.connection_lost(True)
  137. self.assertTrue(loop_mock.call_soon.called)
  138. # negotiate failed
  139. negotiate_fut = asyncio.Future(loop=self.loop)
  140. ef_mock.return_value = negotiate_fut
  141. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  142. proto.connection_made(mock.Mock())
  143. negotiate_fut.set_exception(Exception())
  144. proto.connection_lost(True)
  145. self.assertTrue(loop_mock.call_soon.called)
  146. @mock.patch('aiosocks.protocols.ensure_future')
  147. def test_pause_writing(self, ef_mock):
  148. negotiate_fut = asyncio.Future(loop=self.loop)
  149. ef_mock.return_value = negotiate_fut
  150. app_proto = mock.Mock()
  151. loop_mock = mock.Mock()
  152. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  153. proto.connection_made(mock.Mock())
  154. # negotiate not completed
  155. proto.pause_writing()
  156. self.assertFalse(app_proto.pause_writing.called)
  157. # negotiate successfully competed
  158. negotiate_fut.set_result(True)
  159. proto.pause_writing()
  160. self.assertTrue(app_proto.pause_writing.called)
  161. @mock.patch('aiosocks.protocols.ensure_future')
  162. def test_resume_writing(self, ef_mock):
  163. negotiate_fut = asyncio.Future(loop=self.loop)
  164. ef_mock.return_value = negotiate_fut
  165. app_proto = mock.Mock()
  166. loop_mock = mock.Mock()
  167. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  168. proto.connection_made(mock.Mock())
  169. # negotiate not completed
  170. with self.assertRaises(AssertionError):
  171. proto.resume_writing()
  172. # negotiate fail
  173. negotiate_fut.set_exception(Exception())
  174. proto.resume_writing()
  175. self.assertTrue(app_proto.resume_writing.called)
  176. @mock.patch('aiosocks.protocols.ensure_future')
  177. def test_data_received(self, ef_mock):
  178. negotiate_fut = asyncio.Future(loop=self.loop)
  179. ef_mock.return_value = negotiate_fut
  180. app_proto = mock.Mock()
  181. loop_mock = mock.Mock()
  182. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  183. proto.connection_made(mock.Mock())
  184. # negotiate not completed
  185. proto.data_received(b'123')
  186. self.assertFalse(app_proto.data_received.called)
  187. # negotiate successfully competed
  188. negotiate_fut.set_result(True)
  189. proto.data_received(b'123')
  190. self.assertTrue(app_proto.data_received.called)
  191. @mock.patch('aiosocks.protocols.ensure_future')
  192. def test_eof_received(self, ef_mock):
  193. negotiate_fut = asyncio.Future(loop=self.loop)
  194. ef_mock.return_value = negotiate_fut
  195. app_proto = mock.Mock()
  196. loop_mock = mock.Mock()
  197. proto = make_base(loop_mock, ap_factory=lambda: app_proto)
  198. proto.connection_made(mock.Mock())
  199. # negotiate not completed
  200. proto.eof_received()
  201. self.assertFalse(app_proto.eof_received.called)
  202. # negotiate successfully competed
  203. negotiate_fut.set_result(True)
  204. proto.eof_received()
  205. self.assertTrue(app_proto.eof_received.called)
  206. class TestSocks4Protocol(unittest.TestCase):
  207. def setUp(self):
  208. self.loop = asyncio.new_event_loop()
  209. asyncio.set_event_loop(None)
  210. def tearDown(self):
  211. self.loop.close()
  212. def test_init(self):
  213. addr = aiosocks.Socks4Addr('localhost', 1080)
  214. auth = aiosocks.Socks4Auth('user')
  215. dst = ('python.org', 80)
  216. with self.assertRaises(ValueError):
  217. aiosocks.Socks4Protocol(None, None, dst, loop=self.loop,
  218. waiter=None, app_protocol_factory=None)
  219. with self.assertRaises(ValueError):
  220. aiosocks.Socks4Protocol(None, auth, dst, loop=self.loop,
  221. waiter=None, app_protocol_factory=None)
  222. with self.assertRaises(ValueError):
  223. aiosocks.Socks4Protocol(aiosocks.Socks5Addr('host'), auth, dst,
  224. loop=self.loop, waiter=None,
  225. app_protocol_factory=None)
  226. with self.assertRaises(ValueError):
  227. aiosocks.Socks4Protocol(addr, aiosocks.Socks5Auth('l', 'p'), dst,
  228. loop=self.loop, waiter=None,
  229. app_protocol_factory=None)
  230. aiosocks.Socks4Protocol(addr, None, dst, loop=self.loop,
  231. waiter=None, app_protocol_factory=None)
  232. aiosocks.Socks4Protocol(addr, auth, dst, loop=self.loop,
  233. waiter=None, app_protocol_factory=None)
  234. def test_request_building(self):
  235. resp = b'\x00\x5a\x00P\x7f\x00\x00\x01'
  236. # dst = domain, remote resolve = true
  237. proto = make_socks4(self.loop, dst=('python.org', 80), r=resp)
  238. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  239. self.loop.run_until_complete(req)
  240. proto._transport.write.assert_called_with(
  241. b'\x04\x01\x00P\x00\x00\x00\x01user\x00python.org\x00'
  242. )
  243. # dst = domain, remote resolve = false
  244. proto = make_socks4(self.loop, dst=('python.org', 80),
  245. rr=False, r=resp)
  246. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  247. self.loop.run_until_complete(req)
  248. proto._transport.write.assert_called_with(
  249. b'\x04\x01\x00P\x7f\x00\x00\x01user\x00'
  250. )
  251. # dst = ip, remote resolve = true
  252. proto = make_socks4(self.loop, dst=('127.0.0.1', 8800), r=resp)
  253. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  254. self.loop.run_until_complete(req)
  255. proto._transport.write.assert_called_with(
  256. b'\x04\x01"`\x7f\x00\x00\x01user\x00'
  257. )
  258. # dst = ip, remote resolve = false
  259. proto = make_socks4(self.loop, dst=('127.0.0.1', 8800),
  260. rr=False, r=resp)
  261. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  262. self.loop.run_until_complete(req)
  263. proto._transport.write.assert_called_with(
  264. b'\x04\x01"`\x7f\x00\x00\x01user\x00'
  265. )
  266. # dst = domain, without user
  267. proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''),
  268. dst=('python.org', 80), r=resp)
  269. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  270. self.loop.run_until_complete(req)
  271. proto._transport.write.assert_called_with(
  272. b'\x04\x01\x00P\x00\x00\x00\x01\x00python.org\x00'
  273. )
  274. # dst = ip, without user
  275. proto = make_socks4(self.loop, auth=aiosocks.Socks4Auth(''),
  276. dst=('127.0.0.1', 8800), r=resp)
  277. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  278. self.loop.run_until_complete(req)
  279. proto._transport.write.assert_called_with(
  280. b'\x04\x01"`\x7f\x00\x00\x01\x00'
  281. )
  282. def test_response_handling(self):
  283. valid_resp = b'\x00\x5a\x00P\x7f\x00\x00\x01'
  284. invalid_data_resp = b'\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF'
  285. socks_err_resp = b'\x00\x5b\x00P\x7f\x00\x00\x01'
  286. socks_err_unk_resp = b'\x00\x5e\x00P\x7f\x00\x00\x01'
  287. # valid result
  288. proto = make_socks4(self.loop, r=valid_resp)
  289. req = ensure_future(
  290. proto.socks_request(c.SOCKS_CMD_CONNECT), loop=self.loop)
  291. self.loop.run_until_complete(req)
  292. self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80)))
  293. # invalid server reply
  294. proto = make_socks4(self.loop, r=invalid_data_resp)
  295. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  296. with self.assertRaises(aiosocks.InvalidServerReply):
  297. self.loop.run_until_complete(req)
  298. # socks server sent error
  299. proto = make_socks4(self.loop, r=socks_err_resp)
  300. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  301. with self.assertRaises(aiosocks.SocksError) as cm:
  302. self.loop.run_until_complete(req)
  303. self.assertTrue('0x5b' in str(cm.exception))
  304. # socks server send unknown error
  305. proto = make_socks4(self.loop, r=socks_err_unk_resp)
  306. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  307. with self.assertRaises(aiosocks.SocksError) as cm:
  308. self.loop.run_until_complete(req)
  309. self.assertTrue('Unknown error' in str(cm.exception))
  310. class TestSocks5Protocol(unittest.TestCase):
  311. def setUp(self):
  312. self.loop = asyncio.new_event_loop()
  313. asyncio.set_event_loop(None)
  314. def tearDown(self):
  315. self.loop.close()
  316. def test_init(self):
  317. addr = aiosocks.Socks5Addr('localhost', 1080)
  318. auth = aiosocks.Socks5Auth('user', 'pwd')
  319. dst = ('python.org', 80)
  320. with self.assertRaises(ValueError):
  321. aiosocks.Socks5Protocol(None, None, dst, loop=self.loop,
  322. waiter=None, app_protocol_factory=None)
  323. with self.assertRaises(ValueError):
  324. aiosocks.Socks5Protocol(None, auth, dst, loop=self.loop,
  325. waiter=None, app_protocol_factory=None)
  326. with self.assertRaises(ValueError):
  327. aiosocks.Socks5Protocol(aiosocks.Socks4Addr('host'),
  328. auth, dst, loop=self.loop,
  329. waiter=None, app_protocol_factory=None)
  330. with self.assertRaises(ValueError):
  331. aiosocks.Socks5Protocol(addr, aiosocks.Socks4Auth('l'),
  332. dst, loop=self.loop,
  333. waiter=None, app_protocol_factory=None)
  334. aiosocks.Socks5Protocol(addr, None, dst, loop=self.loop,
  335. waiter=None, app_protocol_factory=None)
  336. aiosocks.Socks5Protocol(addr, auth, dst, loop=self.loop,
  337. waiter=None, app_protocol_factory=None)
  338. def test_authenticate(self):
  339. # invalid server version
  340. proto = make_socks5(self.loop, r=b'\x00\x00')
  341. req = proto.authenticate()
  342. with self.assertRaises(aiosocks.InvalidServerVersion):
  343. self.loop.run_until_complete(req)
  344. # anonymous auth granted
  345. proto = make_socks5(self.loop, r=b'\x05\x00')
  346. req = proto.authenticate()
  347. self.loop.run_until_complete(req)
  348. # no acceptable auth methods
  349. proto = make_socks5(self.loop, r=b'\x05\xFF')
  350. req = proto.authenticate()
  351. with self.assertRaises(aiosocks.NoAcceptableAuthMethods):
  352. self.loop.run_until_complete(req)
  353. # unsupported auth method
  354. proto = make_socks5(self.loop, r=b'\x05\xF0')
  355. req = proto.authenticate()
  356. with self.assertRaises(aiosocks.InvalidServerReply):
  357. self.loop.run_until_complete(req)
  358. # auth: username, pwd
  359. # access granted
  360. proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x00',))
  361. req = proto.authenticate()
  362. self.loop.run_until_complete(req)
  363. proto._transport.write.assert_has_calls([
  364. mock.call(b'\x05\x02\x00\x02'),
  365. mock.call(b'\x01\x04user\x03pwd')
  366. ])
  367. # invalid reply
  368. proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x00\x00',))
  369. req = proto.authenticate()
  370. with self.assertRaises(aiosocks.InvalidServerReply):
  371. self.loop.run_until_complete(req)
  372. # access denied
  373. proto = make_socks5(self.loop, r=(b'\x05\x02', b'\x01\x01',))
  374. req = proto.authenticate()
  375. with self.assertRaises(aiosocks.LoginAuthenticationFailed):
  376. self.loop.run_until_complete(req)
  377. def test_write_address(self):
  378. # ipv4
  379. proto = make_socks5(self.loop)
  380. req = proto.write_address('127.0.0.1', 80)
  381. self.loop.run_until_complete(req)
  382. proto._transport.write.assert_called_with(b'\x01\x7f\x00\x00\x01\x00P')
  383. # ipv6
  384. proto = make_socks5(self.loop)
  385. req = proto.write_address(
  386. '2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)
  387. self.loop.run_until_complete(req)
  388. proto._transport.write.assert_called_with(
  389. b'\x04 \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]\x00P')
  390. # domain, remote_resolve = true
  391. proto = make_socks5(self.loop)
  392. req = proto.write_address('python.org', 80)
  393. self.loop.run_until_complete(req)
  394. proto._transport.write.assert_called_with(b'\x03\npython.org\x00P')
  395. # domain, remote resolve = false
  396. proto = make_socks5(self.loop, rr=False)
  397. req = proto.write_address('python.org', 80)
  398. self.loop.run_until_complete(req)
  399. proto._transport.write.assert_called_with(b'\x01\x7f\x00\x00\x01\x00P')
  400. def test_read_address(self):
  401. # ipv4
  402. proto = make_socks5(
  403. self.loop, r=[b'\x01', b'\x7f\x00\x00\x01', b'\x00P'])
  404. req = ensure_future(proto.read_address(), loop=self.loop)
  405. self.loop.run_until_complete(req)
  406. self.assertEqual(req.result(), ('127.0.0.1', 80))
  407. # ipv6
  408. resp = [
  409. b'\x04',
  410. b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]',
  411. b'\x00P'
  412. ]
  413. proto = make_socks5(self.loop, r=resp)
  414. req = ensure_future(proto.read_address(), loop=self.loop)
  415. self.loop.run_until_complete(req)
  416. self.assertEqual(
  417. req.result(), ('2001:db8:11a3:9d7:1f34:8a2e:7a0:765d', 80))
  418. # domain
  419. proto = make_socks5(
  420. self.loop, r=[b'\x03', b'\n', b'python.org', b'\x00P'])
  421. req = ensure_future(proto.read_address(), loop=self.loop)
  422. self.loop.run_until_complete(req)
  423. self.assertEqual(req.result(), (b'python.org', 80))
  424. def test_socks_request(self):
  425. # invalid version
  426. proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x04\x00\x00'])
  427. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  428. with self.assertRaises(aiosocks.InvalidServerVersion):
  429. self.loop.run_until_complete(req)
  430. # socks error
  431. proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\x02\x00'])
  432. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  433. with self.assertRaises(aiosocks.SocksError) as ct:
  434. self.loop.run_until_complete(req)
  435. self.assertTrue(
  436. 'Connection not allowed by ruleset' in str(ct.exception))
  437. # socks unknown error
  438. proto = make_socks5(self.loop, r=[b'\x05\x00', b'\x05\xFF\x00'])
  439. req = proto.socks_request(c.SOCKS_CMD_CONNECT)
  440. with self.assertRaises(aiosocks.SocksError) as ct:
  441. self.loop.run_until_complete(req)
  442. self.assertTrue('Unknown error' in str(ct.exception))
  443. # cmd granted
  444. resp = [b'\x05\x00',
  445. b'\x05\x00\x00',
  446. b'\x01', b'\x7f\x00\x00\x01',
  447. b'\x00P']
  448. proto = make_socks5(self.loop, r=resp)
  449. req = ensure_future(proto.socks_request(c.SOCKS_CMD_CONNECT),
  450. loop=self.loop)
  451. self.loop.run_until_complete(req)
  452. self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80)))
  453. proto._transport.write.assert_has_calls([
  454. mock.call(b'\x05\x02\x00\x02'),
  455. mock.call(b'\x05\x01\x00'),
  456. mock.call(b'\x03\npython.org\x00P')
  457. ])