Browse Source

Use client_connected_cb for negotiating with socks server. Added negotiate_done_cb callback

main
nibrag 8 years ago
parent
commit
dd6c30c8be
1 changed files with 47 additions and 37 deletions
  1. +47
    -37
      aiosocks/protocols.py

+ 47
- 37
aiosocks/protocols.py View File

@@ -19,7 +19,7 @@ except ImportError:
class BaseSocksProtocol(asyncio.StreamReaderProtocol): class BaseSocksProtocol(asyncio.StreamReaderProtocol):
def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter, def __init__(self, proxy, proxy_auth, dst, app_protocol_factory, waiter,
remote_resolve=True, loop=None, ssl=False, remote_resolve=True, loop=None, ssl=False,
server_hostname=None):
server_hostname=None, negotiate_done_cb=None):
if not isinstance(dst, (tuple, list)) or len(dst) != 2: if not isinstance(dst, (tuple, list)) or len(dst) != 2:
raise ValueError( raise ValueError(
'Invalid dst format, tuple("dst_host", dst_port))' 'Invalid dst format, tuple("dst_host", dst_port))'
@@ -29,13 +29,14 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
self._auth = proxy_auth self._auth = proxy_auth
self._dst_host, self._dst_port = dst self._dst_host, self._dst_port = dst
self._remote_resolve = remote_resolve self._remote_resolve = remote_resolve

self._loop = loop or asyncio.get_event_loop()
self._transport = None
self._waiter = waiter self._waiter = waiter
self._negotiate_fut = None
self._ssl = ssl self._ssl = ssl
self._server_hostname = server_hostname self._server_hostname = server_hostname
self._negotiate_done_cb = negotiate_done_cb
self._loop = loop or asyncio.get_event_loop()

self._transport = None
self._negotiate_done = False


if app_protocol_factory: if app_protocol_factory:
self._app_protocol = app_protocol_factory() self._app_protocol = app_protocol_factory()
@@ -44,7 +45,41 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):


reader = asyncio.StreamReader(loop=self._loop) reader = asyncio.StreamReader(loop=self._loop)


super().__init__(stream_reader=reader, loop=self._loop)
super().__init__(stream_reader=reader,
client_connected_cb=self.negotiate, loop=self._loop)

@asyncio.coroutine
def negotiate(self, reader, writer):
try:
req = self.socks_request(c.SOCKS_CMD_CONNECT)
self._proxy_peername, self._proxy_sockname = yield from req
except SocksError as exc:
exc = SocksError('Can not connect to %s:%s. %s' %
(self._dst_host, self._dst_port, exc))
self._loop.call_soon(self._waiter.set_exception, exc)
except Exception as exc:
self._loop.call_soon(self._waiter.set_exception, exc)
else:
self._negotiate_done = True

if self._ssl:
sock = self._transport.get_extra_info('socket')

self._transport = self._loop._make_ssl_transport(
rawsock=sock, protocol=self._app_protocol,
sslcontext=self._ssl, server_side=False,
server_hostname=self._server_hostname,
waiter=self._waiter)
else:
self._loop.call_soon(self._app_protocol.connection_made,
self._transport)
self._loop.call_soon(self._waiter.set_result, True)

if self._negotiate_done_cb is not None:
res = self._negotiate_done_cb(reader, writer)

if asyncio.iscoroutine(res):
asyncio.Task(res, loop=self._loop)


def connection_made(self, transport): def connection_made(self, transport):
# connection_made is called # connection_made is called
@@ -54,55 +89,31 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
super().connection_made(transport) super().connection_made(transport)
self._transport = transport self._transport = transport


def init_app_protocol(fut):
exc = fut.exception()
if exc:
if isinstance(exc, SocksError):
exc = SocksError('Can not connect to %s:%s. %s' %
(self._dst_host, self._dst_port, exc))
self._waiter.set_exception(exc)
else:
if self._ssl:
sock = self._transport.get_extra_info('socket')

self._transport = self._loop._make_ssl_transport(
rawsock=sock, protocol=self._app_protocol,
sslcontext=self._ssl, server_side=False,
server_hostname=self._server_hostname,
waiter=self._waiter)
else:
self._app_protocol.connection_made(transport)
self._waiter.set_result(True)

req_coro = self.socks_request(c.SOCKS_CMD_CONNECT)
self._negotiate_fut = ensure_future(req_coro, loop=self._loop)
self._negotiate_fut.add_done_callback(init_app_protocol)

def connection_lost(self, exc): def connection_lost(self, exc):
if self._negotiate_fut.done() and not self._negotiate_fut.exception():
if self._negotiate_done and self._app_protocol is not self:
self._loop.call_soon(self._app_protocol.connection_lost, exc) self._loop.call_soon(self._app_protocol.connection_lost, exc)
super().connection_lost(exc) super().connection_lost(exc)


def pause_writing(self): def pause_writing(self):
if self._negotiate_fut.done():
if self._negotiate_done and self._app_protocol is not self:
self._app_protocol.pause_writing() self._app_protocol.pause_writing()
else: else:
super().pause_writing() super().pause_writing()


def resume_writing(self): def resume_writing(self):
if self._negotiate_fut.done():
if self._negotiate_done and self._app_protocol is not self:
self._app_protocol.resume_writing() self._app_protocol.resume_writing()
else: else:
super().resume_writing() super().resume_writing()


def data_received(self, data): def data_received(self, data):
if self._negotiate_fut.done():
if self._negotiate_done and self._app_protocol is not self:
self._app_protocol.data_received(data) self._app_protocol.data_received(data)
else: else:
super().data_received(data) super().data_received(data)


def eof_received(self): def eof_received(self):
if self._negotiate_fut.done() and not self._negotiate_fut.exception():
if self._negotiate_done and self._app_protocol is not self:
self._app_protocol.eof_received() self._app_protocol.eof_received()
super().eof_received() super().eof_received()


@@ -120,8 +131,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol):
bdata += item bdata += item
else: else:
raise ValueError('Unsupported item') raise ValueError('Unsupported item')

self._transport.write(bdata)
self._stream_writer.write(bdata)


@asyncio.coroutine @asyncio.coroutine
def read_response(self, n): def read_response(self, n):


Loading…
Cancel
Save