diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index bd7cbcf..c5ff3a7 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -19,7 +19,7 @@ except ImportError: class BaseSocksProtocol(asyncio.StreamReaderProtocol): def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, remote_resolve=True, loop=None, ssl=False, - server_hostname=None): + server_hostname=None, negotiate_done_cb=None): if not isinstance(dst, (tuple, list)) or len(dst) != 2: raise ValueError( 'Invalid dst format, tuple("dst_host", dst_port))' @@ -29,13 +29,14 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): self._auth = proxy_auth self._dst_host, self._dst_port = dst self._remote_resolve = remote_resolve - - self._loop = loop or asyncio.get_event_loop() - self._transport = None self._waiter = waiter - self._negotiate_fut = None self._ssl = ssl self._server_hostname = server_hostname + self._negotiate_done_cb = negotiate_done_cb + self._loop = loop or asyncio.get_event_loop() + + self._transport = None + self._negotiate_done = False if app_protocol_factory: self._app_protocol = app_protocol_factory() @@ -44,7 +45,41 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): reader = asyncio.StreamReader(loop=self._loop) - super().__init__(stream_reader=reader, loop=self._loop) + super().__init__(stream_reader=reader, + client_connected_cb=self.negotiate, loop=self._loop) + + @asyncio.coroutine + def negotiate(self, reader, writer): + try: + req = self.socks_request(c.SOCKS_CMD_CONNECT) + self._proxy_peername, self._proxy_sockname = yield from req + except SocksError as exc: + exc = SocksError('Can not connect to %s:%s. %s' % + (self._dst_host, self._dst_port, exc)) + self._loop.call_soon(self._waiter.set_exception, exc) + except Exception as exc: + self._loop.call_soon(self._waiter.set_exception, exc) + else: + self._negotiate_done = True + + if self._ssl: + sock = self._transport.get_extra_info('socket') + + self._transport = self._loop._make_ssl_transport( + rawsock=sock, protocol=self._app_protocol, + sslcontext=self._ssl, server_side=False, + server_hostname=self._server_hostname, + waiter=self._waiter) + else: + self._loop.call_soon(self._app_protocol.connection_made, + self._transport) + self._loop.call_soon(self._waiter.set_result, True) + + if self._negotiate_done_cb is not None: + res = self._negotiate_done_cb(reader, writer) + + if asyncio.iscoroutine(res): + asyncio.Task(res, loop=self._loop) def connection_made(self, transport): # connection_made is called @@ -54,55 +89,31 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): super().connection_made(transport) self._transport = transport - def init_app_protocol(fut): - exc = fut.exception() - if exc: - if isinstance(exc, SocksError): - exc = SocksError('Can not connect to %s:%s. %s' % - (self._dst_host, self._dst_port, exc)) - self._waiter.set_exception(exc) - else: - if self._ssl: - sock = self._transport.get_extra_info('socket') - - self._transport = self._loop._make_ssl_transport( - rawsock=sock, protocol=self._app_protocol, - sslcontext=self._ssl, server_side=False, - server_hostname=self._server_hostname, - waiter=self._waiter) - else: - self._app_protocol.connection_made(transport) - self._waiter.set_result(True) - - req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) - self._negotiate_fut = ensure_future(req_coro, loop=self._loop) - self._negotiate_fut.add_done_callback(init_app_protocol) - def connection_lost(self, exc): - if self._negotiate_fut.done() and not self._negotiate_fut.exception(): + if self._negotiate_done and self._app_protocol is not self: self._loop.call_soon(self._app_protocol.connection_lost, exc) super().connection_lost(exc) def pause_writing(self): - if self._negotiate_fut.done(): + if self._negotiate_done and self._app_protocol is not self: self._app_protocol.pause_writing() else: super().pause_writing() def resume_writing(self): - if self._negotiate_fut.done(): + if self._negotiate_done and self._app_protocol is not self: self._app_protocol.resume_writing() else: super().resume_writing() def data_received(self, data): - if self._negotiate_fut.done(): + if self._negotiate_done and self._app_protocol is not self: self._app_protocol.data_received(data) else: super().data_received(data) def eof_received(self): - if self._negotiate_fut.done() and not self._negotiate_fut.exception(): + if self._negotiate_done and self._app_protocol is not self: self._app_protocol.eof_received() super().eof_received() @@ -120,8 +131,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): bdata += item else: raise ValueError('Unsupported item') - - self._transport.write(bdata) + self._stream_writer.write(bdata) @asyncio.coroutine def read_response(self, n):