| @@ -19,7 +19,7 @@ except ImportError: | |||
| class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
| def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | |||
| remote_resolve=True, loop=None, ssl=False, | |||
| server_hostname=None): | |||
| server_hostname=None, negotiate_done_cb=None): | |||
| if not isinstance(dst, (tuple, list)) or len(dst) != 2: | |||
| raise ValueError( | |||
| 'Invalid dst format, tuple("dst_host", dst_port))' | |||
| @@ -29,13 +29,14 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
| self._auth = proxy_auth | |||
| self._dst_host, self._dst_port = dst | |||
| self._remote_resolve = remote_resolve | |||
| self._loop = loop or asyncio.get_event_loop() | |||
| self._transport = None | |||
| self._waiter = waiter | |||
| self._negotiate_fut = None | |||
| self._ssl = ssl | |||
| self._server_hostname = server_hostname | |||
| self._negotiate_done_cb = negotiate_done_cb | |||
| self._loop = loop or asyncio.get_event_loop() | |||
| self._transport = None | |||
| self._negotiate_done = False | |||
| if app_protocol_factory: | |||
| self._app_protocol = app_protocol_factory() | |||
| @@ -44,7 +45,41 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
| reader = asyncio.StreamReader(loop=self._loop) | |||
| super().__init__(stream_reader=reader, loop=self._loop) | |||
| super().__init__(stream_reader=reader, | |||
| client_connected_cb=self.negotiate, loop=self._loop) | |||
| @asyncio.coroutine | |||
| def negotiate(self, reader, writer): | |||
| try: | |||
| req = self.socks_request(c.SOCKS_CMD_CONNECT) | |||
| self._proxy_peername, self._proxy_sockname = yield from req | |||
| except SocksError as exc: | |||
| exc = SocksError('Can not connect to %s:%s. %s' % | |||
| (self._dst_host, self._dst_port, exc)) | |||
| self._loop.call_soon(self._waiter.set_exception, exc) | |||
| except Exception as exc: | |||
| self._loop.call_soon(self._waiter.set_exception, exc) | |||
| else: | |||
| self._negotiate_done = True | |||
| if self._ssl: | |||
| sock = self._transport.get_extra_info('socket') | |||
| self._transport = self._loop._make_ssl_transport( | |||
| rawsock=sock, protocol=self._app_protocol, | |||
| sslcontext=self._ssl, server_side=False, | |||
| server_hostname=self._server_hostname, | |||
| waiter=self._waiter) | |||
| else: | |||
| self._loop.call_soon(self._app_protocol.connection_made, | |||
| self._transport) | |||
| self._loop.call_soon(self._waiter.set_result, True) | |||
| if self._negotiate_done_cb is not None: | |||
| res = self._negotiate_done_cb(reader, writer) | |||
| if asyncio.iscoroutine(res): | |||
| asyncio.Task(res, loop=self._loop) | |||
| def connection_made(self, transport): | |||
| # connection_made is called | |||
| @@ -54,55 +89,31 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
| super().connection_made(transport) | |||
| self._transport = transport | |||
| def init_app_protocol(fut): | |||
| exc = fut.exception() | |||
| if exc: | |||
| if isinstance(exc, SocksError): | |||
| exc = SocksError('Can not connect to %s:%s. %s' % | |||
| (self._dst_host, self._dst_port, exc)) | |||
| self._waiter.set_exception(exc) | |||
| else: | |||
| if self._ssl: | |||
| sock = self._transport.get_extra_info('socket') | |||
| self._transport = self._loop._make_ssl_transport( | |||
| rawsock=sock, protocol=self._app_protocol, | |||
| sslcontext=self._ssl, server_side=False, | |||
| server_hostname=self._server_hostname, | |||
| waiter=self._waiter) | |||
| else: | |||
| self._app_protocol.connection_made(transport) | |||
| self._waiter.set_result(True) | |||
| req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | |||
| self._negotiate_fut = ensure_future(req_coro, loop=self._loop) | |||
| self._negotiate_fut.add_done_callback(init_app_protocol) | |||
| def connection_lost(self, exc): | |||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||
| if self._negotiate_done and self._app_protocol is not self: | |||
| self._loop.call_soon(self._app_protocol.connection_lost, exc) | |||
| super().connection_lost(exc) | |||
| def pause_writing(self): | |||
| if self._negotiate_fut.done(): | |||
| if self._negotiate_done and self._app_protocol is not self: | |||
| self._app_protocol.pause_writing() | |||
| else: | |||
| super().pause_writing() | |||
| def resume_writing(self): | |||
| if self._negotiate_fut.done(): | |||
| if self._negotiate_done and self._app_protocol is not self: | |||
| self._app_protocol.resume_writing() | |||
| else: | |||
| super().resume_writing() | |||
| def data_received(self, data): | |||
| if self._negotiate_fut.done(): | |||
| if self._negotiate_done and self._app_protocol is not self: | |||
| self._app_protocol.data_received(data) | |||
| else: | |||
| super().data_received(data) | |||
| def eof_received(self): | |||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||
| if self._negotiate_done and self._app_protocol is not self: | |||
| self._app_protocol.eof_received() | |||
| super().eof_received() | |||
| @@ -120,8 +131,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
| bdata += item | |||
| else: | |||
| raise ValueError('Unsupported item') | |||
| self._transport.write(bdata) | |||
| self._stream_writer.write(bdata) | |||
| @asyncio.coroutine | |||
| def read_response(self, n): | |||