An attempt at adding UDP support to aiosocks. Untested due to lack of server support.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

398 lines
14 KiB

  1. import asyncio
  2. import socket
  3. import struct
  4. from . import constants as c
  5. from .helpers import (
  6. Socks4Addr, Socks5Addr, Socks5Auth, Socks4Auth
  7. )
  8. from .errors import (
  9. SocksError, NoAcceptableAuthMethods, LoginAuthenticationFailed,
  10. InvalidServerReply, InvalidServerVersion
  11. )
  12. try:
  13. from asyncio import ensure_future
  14. except ImportError:
  15. ensure_future = asyncio.async
  16. DEFAULT_LIMIT = getattr(asyncio.streams, '_DEFAULT_LIMIT', 2**16)
  17. class BaseSocksProtocol(asyncio.StreamReaderProtocol):
  18. def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, *,
  19. remote_resolve=True, loop=None, ssl=False,
  20. server_hostname=None, negotiate_done_cb=None,
  21. reader_limit=DEFAULT_LIMIT):
  22. if not isinstance(dst, (tuple, list)) or len(dst) != 2:
  23. raise ValueError(
  24. 'Invalid dst format, tuple("dst_host", dst_port))'
  25. )
  26. self._proxy = proxy
  27. self._auth = proxy_auth
  28. self._dst_host, self._dst_port = dst
  29. self._remote_resolve = remote_resolve
  30. self._waiter = waiter
  31. self._ssl = ssl
  32. self._server_hostname = server_hostname
  33. self._negotiate_done_cb = negotiate_done_cb
  34. self._loop = loop or asyncio.get_event_loop()
  35. self._transport = None
  36. self._negotiate_done = False
  37. self._proxy_peername = None
  38. self._proxy_sockname = None
  39. if app_protocol_factory:
  40. self._app_protocol = app_protocol_factory()
  41. else:
  42. self._app_protocol = self
  43. reader = asyncio.StreamReader(loop=self._loop, limit=reader_limit)
  44. super().__init__(stream_reader=reader,
  45. client_connected_cb=self.negotiate, loop=self._loop)
  46. @asyncio.coroutine
  47. def negotiate(self, reader, writer):
  48. try:
  49. req = self.socks_request(c.SOCKS_CMD_CONNECT)
  50. self._proxy_peername, self._proxy_sockname = yield from req
  51. except SocksError as exc:
  52. exc = SocksError('Can not connect to %s:%s. %s' %
  53. (self._dst_host, self._dst_port, exc))
  54. self._loop.call_soon(self._waiter.set_exception, exc)
  55. except Exception as exc:
  56. self._loop.call_soon(self._waiter.set_exception, exc)
  57. else:
  58. self._negotiate_done = True
  59. if self._ssl:
  60. # Creating a ssl transport needs to be reworked.
  61. # See details: http://bugs.python.org/issue23749
  62. sock = self._transport.get_extra_info('socket')
  63. # temporary fix:
  64. self._transport.pause_reading()
  65. self._transport._closing = True
  66. self._transport._sock = None
  67. self._transport._protocol = None
  68. self._transport._loop = None
  69. self._transport = self._loop._make_ssl_transport(
  70. rawsock=sock, protocol=self._app_protocol,
  71. sslcontext=self._ssl, server_side=False,
  72. server_hostname=self._server_hostname,
  73. waiter=self._waiter)
  74. else:
  75. self._loop.call_soon(self._app_protocol.connection_made,
  76. self._transport)
  77. self._loop.call_soon(self._waiter.set_result, True)
  78. if self._negotiate_done_cb is not None:
  79. res = self._negotiate_done_cb(reader, writer)
  80. if asyncio.iscoroutine(res):
  81. asyncio.Task(res, loop=self._loop)
  82. def connection_made(self, transport):
  83. # connection_made is called
  84. if self._transport:
  85. return
  86. super().connection_made(transport)
  87. self._transport = transport
  88. def connection_lost(self, exc):
  89. if self._negotiate_done and self._app_protocol is not self:
  90. self._loop.call_soon(self._app_protocol.connection_lost, exc)
  91. super().connection_lost(exc)
  92. def pause_writing(self):
  93. if self._negotiate_done and self._app_protocol is not self:
  94. self._app_protocol.pause_writing()
  95. else:
  96. super().pause_writing()
  97. def resume_writing(self):
  98. if self._negotiate_done and self._app_protocol is not self:
  99. self._app_protocol.resume_writing()
  100. else:
  101. super().resume_writing()
  102. def data_received(self, data):
  103. if self._negotiate_done and self._app_protocol is not self:
  104. self._app_protocol.data_received(data)
  105. else:
  106. super().data_received(data)
  107. def eof_received(self):
  108. if self._negotiate_done and self._app_protocol is not self:
  109. self._app_protocol.eof_received()
  110. super().eof_received()
  111. @asyncio.coroutine
  112. def socks_request(self, cmd):
  113. raise NotImplementedError
  114. def write_request(self, request):
  115. bdata = bytearray()
  116. for item in request:
  117. if isinstance(item, int):
  118. bdata.append(item)
  119. elif isinstance(item, (bytearray, bytes)):
  120. bdata += item
  121. else:
  122. raise ValueError('Unsupported item')
  123. self._stream_writer.write(bdata)
  124. @asyncio.coroutine
  125. def read_response(self, n):
  126. return (yield from self._stream_reader.read(n))
  127. @asyncio.coroutine
  128. def _get_dst_addr(self):
  129. infos = yield from self._loop.getaddrinfo(
  130. self._dst_host, self._dst_port, family=socket.AF_UNSPEC,
  131. type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP,
  132. flags=socket.AI_ADDRCONFIG)
  133. if not infos:
  134. raise OSError('getaddrinfo() returned empty list')
  135. return infos[0][0], infos[0][4][0]
  136. @property
  137. def app_protocol(self):
  138. return self._app_protocol
  139. @property
  140. def app_transport(self):
  141. return self._transport
  142. @property
  143. def proxy_sockname(self):
  144. """
  145. Returns the bound IP address and port number at the proxy.
  146. """
  147. return self._proxy_sockname
  148. @property
  149. def proxy_peername(self):
  150. """
  151. Returns the IP and port number of the proxy.
  152. """
  153. sock = self._transport.get_extra_info('socket')
  154. return sock.peername if sock else None
  155. @property
  156. def peername(self):
  157. """
  158. Returns the IP address and port number of the destination
  159. machine (note: get_proxy_peername returns the proxy)
  160. """
  161. return self._proxy_peername
  162. @property
  163. def reader(self):
  164. return self._stream_reader
  165. @property
  166. def writer(self):
  167. return self._stream_writer
  168. class Socks4Protocol(BaseSocksProtocol):
  169. def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
  170. remote_resolve=True, loop=None, ssl=False,
  171. server_hostname=None, negotiate_done_cb=None,
  172. reader_limit=DEFAULT_LIMIT):
  173. proxy_auth = proxy_auth or Socks4Auth('')
  174. if not isinstance(proxy, Socks4Addr):
  175. raise ValueError('Invalid proxy format')
  176. if not isinstance(proxy_auth, Socks4Auth):
  177. raise ValueError('Invalid proxy_auth format')
  178. super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
  179. waiter, remote_resolve=remote_resolve, loop=loop,
  180. ssl=ssl, server_hostname=server_hostname,
  181. reader_limit=reader_limit,
  182. negotiate_done_cb=negotiate_done_cb)
  183. @asyncio.coroutine
  184. def socks_request(self, cmd):
  185. # prepare destination addr/port
  186. host, port = self._dst_host, self._dst_port
  187. port_bytes = struct.pack(b'>H', port)
  188. include_hostname = False
  189. try:
  190. host_bytes = socket.inet_aton(host)
  191. except socket.error:
  192. if self._remote_resolve:
  193. host_bytes = bytes([c.NULL, c.NULL, c.NULL, 0x01])
  194. include_hostname = True
  195. else:
  196. # it's not an IP number, so it's probably a DNS name.
  197. family, host = yield from self._get_dst_addr()
  198. host_bytes = socket.inet_aton(host)
  199. # build and send connect command
  200. req = [c.SOCKS_VER4, cmd, port_bytes,
  201. host_bytes, self._auth.login, c.NULL]
  202. if include_hostname:
  203. req += [self._dst_host.encode('idna'), c.NULL]
  204. self.write_request(req)
  205. # read/process result
  206. resp = yield from self.read_response(8)
  207. if resp[0] != c.NULL:
  208. raise InvalidServerReply('SOCKS4 proxy server sent invalid data')
  209. if resp[1] != c.SOCKS4_GRANTED:
  210. error = c.SOCKS4_ERRORS.get(resp[1], 'Unknown error')
  211. raise SocksError('[Errno {0:#04x}]: {1}'.format(resp[1], error))
  212. binded = socket.inet_ntoa(resp[4:]), struct.unpack('>H', resp[2:4])[0]
  213. return (host, port), binded
  214. class Socks5Protocol(BaseSocksProtocol):
  215. def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
  216. remote_resolve=True, loop=None, ssl=False,
  217. server_hostname=None, negotiate_done_cb=None,
  218. reader_limit=DEFAULT_LIMIT):
  219. proxy_auth = proxy_auth or Socks5Auth('', '')
  220. if not isinstance(proxy, Socks5Addr):
  221. raise ValueError('Invalid proxy format')
  222. if not isinstance(proxy_auth, Socks5Auth):
  223. raise ValueError('Invalid proxy_auth format')
  224. super().__init__(proxy, proxy_auth, dst, app_protocol_factory,
  225. waiter, remote_resolve=remote_resolve, loop=loop,
  226. ssl=ssl, server_hostname=server_hostname,
  227. reader_limit=reader_limit,
  228. negotiate_done_cb=negotiate_done_cb)
  229. @asyncio.coroutine
  230. def socks_request(self, cmd):
  231. yield from self.authenticate()
  232. # build and send command
  233. dst_addr, resolved = yield from self.build_dst_address(
  234. self._dst_host, self._dst_port)
  235. self.write_request([c.SOCKS_VER5, cmd, c.RSV] + dst_addr)
  236. # read/process command response
  237. resp = yield from self.read_response(3)
  238. if resp[0] != c.SOCKS_VER5:
  239. raise InvalidServerVersion(
  240. 'SOCKS5 proxy server sent invalid version'
  241. )
  242. if resp[1] != c.SOCKS5_GRANTED:
  243. error = c.SOCKS5_ERRORS.get(resp[1], 'Unknown error')
  244. raise SocksError('[Errno {0:#04x}]: {1}'.format(resp[1], error))
  245. binded = yield from self.read_address()
  246. return resolved, binded
  247. @asyncio.coroutine
  248. def authenticate(self):
  249. # send available auth methods
  250. if self._auth.login and self._auth.password:
  251. req = [c.SOCKS_VER5, 0x02,
  252. c.SOCKS5_AUTH_ANONYMOUS, c.SOCKS5_AUTH_UNAME_PWD]
  253. else:
  254. req = [c.SOCKS_VER5, 0x01, c.SOCKS5_AUTH_ANONYMOUS]
  255. self.write_request(req)
  256. # read/process response and send auth data if necessary
  257. chosen_auth = yield from self.read_response(2)
  258. if chosen_auth[0] != c.SOCKS_VER5:
  259. raise InvalidServerVersion(
  260. 'SOCKS5 proxy server sent invalid version'
  261. )
  262. if chosen_auth[1] == c.SOCKS5_AUTH_UNAME_PWD:
  263. req = [0x01, chr(len(self._auth.login)).encode(), self._auth.login,
  264. chr(len(self._auth.password)).encode(), self._auth.password]
  265. self.write_request(req)
  266. auth_status = yield from self.read_response(2)
  267. if auth_status[0] != 0x01:
  268. raise InvalidServerReply(
  269. 'SOCKS5 proxy server sent invalid data'
  270. )
  271. if auth_status[1] != c.SOCKS5_GRANTED:
  272. raise LoginAuthenticationFailed(
  273. "SOCKS5 authentication failed"
  274. )
  275. # offered auth methods rejected
  276. elif chosen_auth[1] != c.SOCKS5_AUTH_ANONYMOUS:
  277. if chosen_auth[1] == c.SOCKS5_AUTH_NO_ACCEPTABLE_METHODS:
  278. raise NoAcceptableAuthMethods(
  279. 'All offered SOCKS5 authentication methods were rejected'
  280. )
  281. else:
  282. raise InvalidServerReply(
  283. 'SOCKS5 proxy server sent invalid data'
  284. )
  285. @asyncio.coroutine
  286. def build_dst_address(self, host, port):
  287. family_to_byte = {socket.AF_INET: c.SOCKS5_ATYP_IPv4,
  288. socket.AF_INET6: c.SOCKS5_ATYP_IPv6}
  289. port_bytes = struct.pack('>H', port)
  290. # if the given destination address is an IP address, we will
  291. # use the IP address request even if remote resolving was specified.
  292. for family in (socket.AF_INET, socket.AF_INET6):
  293. try:
  294. host_bytes = socket.inet_pton(family, host)
  295. req = [family_to_byte[family], host_bytes, port_bytes]
  296. return req, (host, port)
  297. except socket.error:
  298. pass
  299. # it's not an IP number, so it's probably a DNS name.
  300. if self._remote_resolve:
  301. host_bytes = host.encode('idna')
  302. req = [c.SOCKS5_ATYP_DOMAIN, chr(len(host_bytes)).encode(),
  303. host_bytes, port_bytes]
  304. else:
  305. family, host_bytes = yield from self._get_dst_addr()
  306. host_bytes = socket.inet_pton(family, host_bytes)
  307. req = [family_to_byte[family], host_bytes, port_bytes]
  308. host = socket.inet_ntop(family, host_bytes)
  309. return req, (host, port)
  310. @asyncio.coroutine
  311. def read_address(self):
  312. atype = yield from self.read_response(1)
  313. if atype[0] == c.SOCKS5_ATYP_IPv4:
  314. addr = socket.inet_ntoa((yield from self.read_response(4)))
  315. elif atype[0] == c.SOCKS5_ATYP_DOMAIN:
  316. length = yield from self.read_response(1)
  317. addr = yield from self.read_response(ord(length))
  318. elif atype[0] == c.SOCKS5_ATYP_IPv6:
  319. addr = yield from self.read_response(16)
  320. addr = socket.inet_ntop(socket.AF_INET6, addr)
  321. else:
  322. raise InvalidServerReply('SOCKS5 proxy server sent invalid data')
  323. port = yield from self.read_response(2)
  324. port = struct.unpack('>H', port)[0]
  325. return addr, port