An attempt at adding UDP support to aiosocks. Untested due to lack of server support.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

243 lines
6.9 KiB

  1. import asyncio
  2. import aiohttp
  3. import contextlib
  4. import gc
  5. import os
  6. import socket
  7. import ssl
  8. import struct
  9. import threading
  10. from unittest import mock
  11. from aiohttp.server import ServerHttpProtocol
  12. try:
  13. from asyncio import ensure_future
  14. except ImportError:
  15. ensure_future = asyncio.async
  16. def fake_coroutine(return_value):
  17. def coro(*args, **kwargs):
  18. if isinstance(return_value, Exception):
  19. raise return_value
  20. return return_value
  21. return mock.Mock(side_effect=asyncio.coroutine(coro))
  22. def find_unused_port():
  23. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  24. s.bind(('127.0.0.1', 0))
  25. port = s.getsockname()[1]
  26. s.close()
  27. return port
  28. @contextlib.contextmanager
  29. def fake_socks_srv(loop, write_buff):
  30. transports = []
  31. class SocksPrimitiveProtocol(asyncio.Protocol):
  32. _transport = None
  33. def connection_made(self, transport):
  34. self._transport = transport
  35. transports.append(transport)
  36. def data_received(self, data):
  37. self._transport.write(write_buff)
  38. port = find_unused_port()
  39. def factory():
  40. return SocksPrimitiveProtocol()
  41. srv = loop.run_until_complete(
  42. loop.create_server(factory, '127.0.0.1', port))
  43. yield port
  44. for tr in transports:
  45. tr.close()
  46. srv.close()
  47. loop.run_until_complete(srv.wait_closed())
  48. gc.collect()
  49. @contextlib.contextmanager
  50. def fake_socks4_srv(loop):
  51. port = find_unused_port()
  52. transports = []
  53. futures = []
  54. class Socks4Protocol(asyncio.StreamReaderProtocol):
  55. def __init__(self, _loop):
  56. self._loop = _loop
  57. reader = asyncio.StreamReader(loop=self._loop)
  58. super().__init__(reader, client_connected_cb=self.negotiate,
  59. loop=self._loop)
  60. def connection_made(self, transport):
  61. transports.append(transport)
  62. super().connection_made(transport)
  63. @asyncio.coroutine
  64. def negotiate(self, reader, writer):
  65. writer.write(b'\x00\x5a\x04W\x01\x01\x01\x01')
  66. data = yield from reader.read(9)
  67. dst_port = struct.unpack('>H', data[2:4])[0]
  68. dst_addr = data[4:8]
  69. if data[-1] != 0x00:
  70. while True:
  71. byte = yield from reader.read(1)
  72. if byte == 0x00:
  73. break
  74. if dst_addr == b'\x00\x00\x00\x01':
  75. dst_addr = bytearray()
  76. while True:
  77. byte = yield from reader.read(1)
  78. if byte == 0x00:
  79. break
  80. dst_addr.append(byte)
  81. else:
  82. dst_addr = socket.inet_ntoa(dst_addr)
  83. cl_reader, cl_writer = yield from asyncio.open_connection(
  84. host=dst_addr, port=dst_port, loop=self._loop
  85. )
  86. transports.append(cl_writer)
  87. cl_fut = ensure_future(
  88. self.retranslator(reader, cl_writer), loop=self._loop)
  89. dst_fut = ensure_future(
  90. self.retranslator(cl_reader, writer), loop=self._loop)
  91. futures.append(cl_fut)
  92. futures.append(dst_fut)
  93. @asyncio.coroutine
  94. def retranslator(self, reader, writer):
  95. data = bytearray()
  96. while True:
  97. try:
  98. byte = yield from reader.read(1)
  99. if not byte:
  100. break
  101. data.append(byte[0])
  102. writer.write(byte)
  103. yield from writer.drain()
  104. except:
  105. break
  106. def run(_fut):
  107. thread_loop = asyncio.new_event_loop()
  108. asyncio.set_event_loop(thread_loop)
  109. srv_coroutine = thread_loop.create_server(
  110. lambda: Socks4Protocol(thread_loop), '127.0.0.1', port)
  111. srv = thread_loop.run_until_complete(srv_coroutine)
  112. waiter = asyncio.Future(loop=thread_loop)
  113. loop.call_soon_threadsafe(
  114. _fut.set_result, (thread_loop, waiter))
  115. try:
  116. thread_loop.run_until_complete(waiter)
  117. finally:
  118. # close opened transports
  119. for tr in transports:
  120. tr.close()
  121. for ft in futures:
  122. if not ft.done():
  123. ft.set_result(1)
  124. srv.close()
  125. thread_loop.stop()
  126. thread_loop.close()
  127. gc.collect()
  128. fut = asyncio.Future(loop=loop)
  129. srv_thread = threading.Thread(target=run, args=(fut,))
  130. srv_thread.start()
  131. _thread_loop, _waiter = loop.run_until_complete(fut)
  132. yield port
  133. _thread_loop.call_soon_threadsafe(_waiter.set_result, None)
  134. srv_thread.join()
  135. @contextlib.contextmanager
  136. def http_srv(loop, *, listen_addr=('127.0.0.1', 0), use_ssl=False):
  137. transports = []
  138. class TestHttpServer(ServerHttpProtocol):
  139. def connection_made(self, transport):
  140. transports.append(transport)
  141. super().connection_made(transport)
  142. @asyncio.coroutine
  143. def handle_request(self, message, payload):
  144. response = aiohttp.Response(self.writer, 200, message.version)
  145. text = b'Test message'
  146. response.add_header('Content-type', 'text/plain')
  147. response.add_header('Content-length', str(len(text)))
  148. response.send_headers()
  149. response.write(text)
  150. response.write_eof()
  151. if use_ssl:
  152. here = os.path.join(os.path.dirname(__file__), '..', 'tests')
  153. keyfile = os.path.join(here, 'sample.key')
  154. certfile = os.path.join(here, 'sample.crt')
  155. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
  156. sslcontext.load_cert_chain(certfile, keyfile)
  157. else:
  158. sslcontext = None
  159. def run(_fut):
  160. thread_loop = asyncio.new_event_loop()
  161. asyncio.set_event_loop(thread_loop)
  162. host, port = listen_addr
  163. srv_coroutine = thread_loop.create_server(
  164. lambda: TestHttpServer(), host, port, ssl=sslcontext)
  165. srv = thread_loop.run_until_complete(srv_coroutine)
  166. waiter = asyncio.Future(loop=thread_loop)
  167. loop.call_soon_threadsafe(
  168. _fut.set_result, (thread_loop, waiter,
  169. srv.sockets[0].getsockname()))
  170. try:
  171. thread_loop.run_until_complete(waiter)
  172. finally:
  173. # close opened transports
  174. for tr in transports:
  175. tr.close()
  176. srv.close()
  177. thread_loop.stop()
  178. thread_loop.close()
  179. gc.collect()
  180. fut = asyncio.Future(loop=loop)
  181. srv_thread = threading.Thread(target=run, args=(fut,))
  182. srv_thread.start()
  183. _thread_loop, _waiter, _addr = loop.run_until_complete(fut)
  184. url = '{}://{}:{}'.format(
  185. 'https' if use_ssl else 'http', *_addr)
  186. yield url
  187. _thread_loop.call_soon_threadsafe(_waiter.set_result, None)
  188. srv_thread.join()