| @@ -19,7 +19,7 @@ except ImportError: | |||||
| class BaseSocksProtocol(asyncio.StreamReaderProtocol): | class BaseSocksProtocol(asyncio.StreamReaderProtocol): | ||||
| def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, | ||||
| remote_resolve=True, loop=None, ssl=False, | remote_resolve=True, loop=None, ssl=False, | ||||
| server_hostname=None): | |||||
| server_hostname=None, negotiate_done_cb=None): | |||||
| if not isinstance(dst, (tuple, list)) or len(dst) != 2: | if not isinstance(dst, (tuple, list)) or len(dst) != 2: | ||||
| raise ValueError( | raise ValueError( | ||||
| 'Invalid dst format, tuple("dst_host", dst_port))' | 'Invalid dst format, tuple("dst_host", dst_port))' | ||||
| @@ -29,13 +29,14 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| self._auth = proxy_auth | self._auth = proxy_auth | ||||
| self._dst_host, self._dst_port = dst | self._dst_host, self._dst_port = dst | ||||
| self._remote_resolve = remote_resolve | self._remote_resolve = remote_resolve | ||||
| self._loop = loop or asyncio.get_event_loop() | |||||
| self._transport = None | |||||
| self._waiter = waiter | self._waiter = waiter | ||||
| self._negotiate_fut = None | |||||
| self._ssl = ssl | self._ssl = ssl | ||||
| self._server_hostname = server_hostname | self._server_hostname = server_hostname | ||||
| self._negotiate_done_cb = negotiate_done_cb | |||||
| self._loop = loop or asyncio.get_event_loop() | |||||
| self._transport = None | |||||
| self._negotiate_done = False | |||||
| if app_protocol_factory: | if app_protocol_factory: | ||||
| self._app_protocol = app_protocol_factory() | self._app_protocol = app_protocol_factory() | ||||
| @@ -44,7 +45,41 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| reader = asyncio.StreamReader(loop=self._loop) | reader = asyncio.StreamReader(loop=self._loop) | ||||
| super().__init__(stream_reader=reader, loop=self._loop) | |||||
| super().__init__(stream_reader=reader, | |||||
| client_connected_cb=self.negotiate, loop=self._loop) | |||||
| @asyncio.coroutine | |||||
| def negotiate(self, reader, writer): | |||||
| try: | |||||
| req = self.socks_request(c.SOCKS_CMD_CONNECT) | |||||
| self._proxy_peername, self._proxy_sockname = yield from req | |||||
| except SocksError as exc: | |||||
| exc = SocksError('Can not connect to %s:%s. %s' % | |||||
| (self._dst_host, self._dst_port, exc)) | |||||
| self._loop.call_soon(self._waiter.set_exception, exc) | |||||
| except Exception as exc: | |||||
| self._loop.call_soon(self._waiter.set_exception, exc) | |||||
| else: | |||||
| self._negotiate_done = True | |||||
| if self._ssl: | |||||
| sock = self._transport.get_extra_info('socket') | |||||
| self._transport = self._loop._make_ssl_transport( | |||||
| rawsock=sock, protocol=self._app_protocol, | |||||
| sslcontext=self._ssl, server_side=False, | |||||
| server_hostname=self._server_hostname, | |||||
| waiter=self._waiter) | |||||
| else: | |||||
| self._loop.call_soon(self._app_protocol.connection_made, | |||||
| self._transport) | |||||
| self._loop.call_soon(self._waiter.set_result, True) | |||||
| if self._negotiate_done_cb is not None: | |||||
| res = self._negotiate_done_cb(reader, writer) | |||||
| if asyncio.iscoroutine(res): | |||||
| asyncio.Task(res, loop=self._loop) | |||||
| def connection_made(self, transport): | def connection_made(self, transport): | ||||
| # connection_made is called | # connection_made is called | ||||
| @@ -54,55 +89,31 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| super().connection_made(transport) | super().connection_made(transport) | ||||
| self._transport = transport | self._transport = transport | ||||
| def init_app_protocol(fut): | |||||
| exc = fut.exception() | |||||
| if exc: | |||||
| if isinstance(exc, SocksError): | |||||
| exc = SocksError('Can not connect to %s:%s. %s' % | |||||
| (self._dst_host, self._dst_port, exc)) | |||||
| self._waiter.set_exception(exc) | |||||
| else: | |||||
| if self._ssl: | |||||
| sock = self._transport.get_extra_info('socket') | |||||
| self._transport = self._loop._make_ssl_transport( | |||||
| rawsock=sock, protocol=self._app_protocol, | |||||
| sslcontext=self._ssl, server_side=False, | |||||
| server_hostname=self._server_hostname, | |||||
| waiter=self._waiter) | |||||
| else: | |||||
| self._app_protocol.connection_made(transport) | |||||
| self._waiter.set_result(True) | |||||
| req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) | |||||
| self._negotiate_fut = ensure_future(req_coro, loop=self._loop) | |||||
| self._negotiate_fut.add_done_callback(init_app_protocol) | |||||
| def connection_lost(self, exc): | def connection_lost(self, exc): | ||||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
| if self._negotiate_done and self._app_protocol is not self: | |||||
| self._loop.call_soon(self._app_protocol.connection_lost, exc) | self._loop.call_soon(self._app_protocol.connection_lost, exc) | ||||
| super().connection_lost(exc) | super().connection_lost(exc) | ||||
| def pause_writing(self): | def pause_writing(self): | ||||
| if self._negotiate_fut.done(): | |||||
| if self._negotiate_done and self._app_protocol is not self: | |||||
| self._app_protocol.pause_writing() | self._app_protocol.pause_writing() | ||||
| else: | else: | ||||
| super().pause_writing() | super().pause_writing() | ||||
| def resume_writing(self): | def resume_writing(self): | ||||
| if self._negotiate_fut.done(): | |||||
| if self._negotiate_done and self._app_protocol is not self: | |||||
| self._app_protocol.resume_writing() | self._app_protocol.resume_writing() | ||||
| else: | else: | ||||
| super().resume_writing() | super().resume_writing() | ||||
| def data_received(self, data): | def data_received(self, data): | ||||
| if self._negotiate_fut.done(): | |||||
| if self._negotiate_done and self._app_protocol is not self: | |||||
| self._app_protocol.data_received(data) | self._app_protocol.data_received(data) | ||||
| else: | else: | ||||
| super().data_received(data) | super().data_received(data) | ||||
| def eof_received(self): | def eof_received(self): | ||||
| if self._negotiate_fut.done() and not self._negotiate_fut.exception(): | |||||
| if self._negotiate_done and self._app_protocol is not self: | |||||
| self._app_protocol.eof_received() | self._app_protocol.eof_received() | ||||
| super().eof_received() | super().eof_received() | ||||
| @@ -120,8 +131,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
| bdata += item | bdata += item | ||||
| else: | else: | ||||
| raise ValueError('Unsupported item') | raise ValueError('Unsupported item') | ||||
| self._transport.write(bdata) | |||||
| self._stream_writer.write(bdata) | |||||
| @asyncio.coroutine | @asyncio.coroutine | ||||
| def read_response(self, n): | def read_response(self, n): | ||||