|
|
@@ -19,7 +19,7 @@ except ImportError: |
|
|
|
class BaseSocksProtocol(asyncio.StreamReaderProtocol): |
|
|
|
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, |
|
|
|
remote_resolve=True, loop=None, ssl=False, |
|
|
|
server_hostname=None): |
|
|
|
server_hostname=None, negotiate_done_cb=None): |
|
|
|
if not isinstance(dst, (tuple, list)) or len(dst) != 2: |
|
|
|
raise ValueError( |
|
|
|
'Invalid dst format, tuple("dst_host", dst_port))' |
|
|
@@ -29,13 +29,14 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): |
|
|
|
self._auth = proxy_auth |
|
|
|
self._dst_host, self._dst_port = dst |
|
|
|
self._remote_resolve = remote_resolve |
|
|
|
|
|
|
|
self._loop = loop or asyncio.get_event_loop() |
|
|
|
self._transport = None |
|
|
|
self._waiter = waiter |
|
|
|
self._negotiate_fut = None |
|
|
|
self._ssl = ssl |
|
|
|
self._server_hostname = server_hostname |
|
|
|
self._negotiate_done_cb = negotiate_done_cb |
|
|
|
self._loop = loop or asyncio.get_event_loop() |
|
|
|
|
|
|
|
self._transport = None |
|
|
|
self._negotiate_done = False |
|
|
|
|
|
|
|
if app_protocol_factory: |
|
|
|
self._app_protocol = app_protocol_factory() |
|
|
@@ -44,7 +45,41 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): |
|
|
|
|
|
|
|
reader = asyncio.StreamReader(loop=self._loop) |
|
|
|
|
|
|
|
super().__init__(stream_reader=reader, loop=self._loop) |
|
|
|
super().__init__(stream_reader=reader, |
|
|
|
client_connected_cb=self.negotiate, loop=self._loop) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def negotiate(self, reader, writer): |
|
|
|
try: |
|
|
|
req = self.socks_request(c.SOCKS_CMD_CONNECT) |
|
|
|
self._proxy_peername, self._proxy_sockname = yield from req |
|
|
|
except SocksError as exc: |
|
|
|
exc = SocksError('Can not connect to %s:%s. %s' % |
|
|
|
(self._dst_host, self._dst_port, exc)) |
|
|
|
self._loop.call_soon(self._waiter.set_exception, exc) |
|
|
|
except Exception as exc: |
|
|
|
self._loop.call_soon(self._waiter.set_exception, exc) |
|
|
|
else: |
|
|
|
self._negotiate_done = True |
|
|
|
|
|
|
|
if self._ssl: |
|
|
|
sock = self._transport.get_extra_info('socket') |
|
|
|
|
|
|
|
self._transport = self._loop._make_ssl_transport( |
|
|
|
rawsock=sock, protocol=self._app_protocol, |
|
|
|
sslcontext=self._ssl, server_side=False, |
|
|
|
server_hostname=self._server_hostname, |
|
|
|
waiter=self._waiter) |
|
|
|
else: |
|
|
|
self._loop.call_soon(self._app_protocol.connection_made, |
|
|
|
self._transport) |
|
|
|
self._loop.call_soon(self._waiter.set_result, True) |
|
|
|
|
|
|
|
if self._negotiate_done_cb is not None: |
|
|
|
res = self._negotiate_done_cb(reader, writer) |
|
|
|
|
|
|
|
if asyncio.iscoroutine(res): |
|
|
|
asyncio.Task(res, loop=self._loop) |
|
|
|
|
|
|
|
def connection_made(self, transport): |
|
|
|
# connection_made is called |
|
|
@@ -54,55 +89,31 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): |
|
|
|
super().connection_made(transport) |
|
|
|
self._transport = transport |
|
|
|
|
|
|
|
def init_app_protocol(fut): |
|
|
|
exc = fut.exception() |
|
|
|
if exc: |
|
|
|
if isinstance(exc, SocksError): |
|
|
|
exc = SocksError('Can not connect to %s:%s. %s' % |
|
|
|
(self._dst_host, self._dst_port, exc)) |
|
|
|
self._waiter.set_exception(exc) |
|
|
|
else: |
|
|
|
if self._ssl: |
|
|
|
sock = self._transport.get_extra_info('socket') |
|
|
|
|
|
|
|
self._transport = self._loop._make_ssl_transport( |
|
|
|
rawsock=sock, protocol=self._app_protocol, |
|
|
|
sslcontext=self._ssl, server_side=False, |
|
|
|
server_hostname=self._server_hostname, |
|
|
|
waiter=self._waiter) |
|
|
|
else: |
|
|
|
self._app_protocol.connection_made(transport) |
|
|
|
self._waiter.set_result(True) |
|
|
|
|
|
|
|
req_coro = self.socks_request(c.SOCKS_CMD_CONNECT) |
|
|
|
self._negotiate_fut = ensure_future(req_coro, loop=self._loop) |
|
|
|
self._negotiate_fut.add_done_callback(init_app_protocol) |
|
|
|
|
|
|
|
def connection_lost(self, exc): |
|
|
|
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): |
|
|
|
if self._negotiate_done and self._app_protocol is not self: |
|
|
|
self._loop.call_soon(self._app_protocol.connection_lost, exc) |
|
|
|
super().connection_lost(exc) |
|
|
|
|
|
|
|
def pause_writing(self): |
|
|
|
if self._negotiate_fut.done(): |
|
|
|
if self._negotiate_done and self._app_protocol is not self: |
|
|
|
self._app_protocol.pause_writing() |
|
|
|
else: |
|
|
|
super().pause_writing() |
|
|
|
|
|
|
|
def resume_writing(self): |
|
|
|
if self._negotiate_fut.done(): |
|
|
|
if self._negotiate_done and self._app_protocol is not self: |
|
|
|
self._app_protocol.resume_writing() |
|
|
|
else: |
|
|
|
super().resume_writing() |
|
|
|
|
|
|
|
def data_received(self, data): |
|
|
|
if self._negotiate_fut.done(): |
|
|
|
if self._negotiate_done and self._app_protocol is not self: |
|
|
|
self._app_protocol.data_received(data) |
|
|
|
else: |
|
|
|
super().data_received(data) |
|
|
|
|
|
|
|
def eof_received(self): |
|
|
|
if self._negotiate_fut.done() and not self._negotiate_fut.exception(): |
|
|
|
if self._negotiate_done and self._app_protocol is not self: |
|
|
|
self._app_protocol.eof_received() |
|
|
|
super().eof_received() |
|
|
|
|
|
|
@@ -120,8 +131,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): |
|
|
|
bdata += item |
|
|
|
else: |
|
|
|
raise ValueError('Unsupported item') |
|
|
|
|
|
|
|
self._transport.write(bdata) |
|
|
|
self._stream_writer.write(bdata) |
|
|
|
|
|
|
|
@asyncio.coroutine |
|
|
|
def read_response(self, n): |
|
|
|