Browse Source

Fix: some socks5 severs expect fully-formed command request

main
nibrag 8 years ago
parent
commit
e8a2b7cfba
2 changed files with 30 additions and 29 deletions
  1. +6
    -8
      aiosocks/protocols.py
  2. +24
    -21
      tests/test_protocols.py

+ 6
- 8
aiosocks/protocols.py View File

@@ -284,9 +284,9 @@ class Socks5Protocol(BaseSocksProtocol):
yield from self.authenticate()

# build and send command
self.write_request([c.SOCKS_VER5, cmd, c.RSV])
resolved = yield from self.write_address(self._dst_host,
self._dst_port)
dst_addr, resolved = yield from self.build_dst_address(
self._dst_host, self._dst_port)
self.write_request([c.SOCKS_VER5, cmd, c.RSV] + dst_addr)

# read/process command response
resp = yield from self.read_response(3)
@@ -348,7 +348,7 @@ class Socks5Protocol(BaseSocksProtocol):
)

@asyncio.coroutine
def write_address(self, host, port):
def build_dst_address(self, host, port):
family_to_byte = {socket.AF_INET: c.SOCKS5_ATYP_IPv4,
socket.AF_INET6: c.SOCKS5_ATYP_IPv6}
port_bytes = struct.pack('>H', port)
@@ -359,8 +359,7 @@ class Socks5Protocol(BaseSocksProtocol):
try:
host_bytes = socket.inet_pton(family, host)
req = [family_to_byte[family], host_bytes, port_bytes]
self.write_request(req)
return host, port
return req, (host, port)
except socket.error:
pass

@@ -375,8 +374,7 @@ class Socks5Protocol(BaseSocksProtocol):
req = [family_to_byte[family], host_bytes, port_bytes]
host = socket.inet_ntop(family, host_bytes)

self.write_request(req)
return host, port
return req, (host, port)

@asyncio.coroutine
def read_address(self):


+ 24
- 21
tests/test_protocols.py View File

@@ -53,6 +53,7 @@ def make_socks5(loop, *, addr=None, auth=None, rr=True, dst=None, r=None,
proxy=addr, proxy_auth=auth, dst=dst, remote_resolve=rr,
loop=loop, app_protocol_factory=ap_factory, waiter=whiter)
proto._stream_writer = mock.Mock()
proto._stream_writer.drain = fake_coroutine(True)

if not isinstance(r, (list, tuple)):
proto.read_response = mock.Mock(
@@ -526,37 +527,40 @@ class TestSocks5Protocol(unittest.TestCase):
req = proto.authenticate()
self.loop.run_until_complete(req)

def test_wr_addr_ipv4(self):
def test_build_dst_addr_ipv4(self):
proto = make_socks5(self.loop)
req = proto.write_address('127.0.0.1', 80)
self.loop.run_until_complete(req)
c = proto.build_dst_address('127.0.0.1', 80)
dst_req, resolved = self.loop.run_until_complete(c)

proto._stream_writer.write.assert_called_with(
b'\x01\x7f\x00\x00\x01\x00P')
self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P'])
self.assertEqual(resolved, ('127.0.0.1', 80))

def test_wr_addr_ipv6(self):
def test_build_dst_addr_ipv6(self):
proto = make_socks5(self.loop)
req = proto.write_address(
c = proto.build_dst_address(
'2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80)
self.loop.run_until_complete(req)
dst_req, resolved = self.loop.run_until_complete(c)

proto._stream_writer.write.assert_called_with(
b'\x04 \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]\x00P')
self.assertEqual(dst_req, [
0x04, b' \x01\r\xb8\x11\xa3\t\xd7\x1f4\x8a.\x07\xa0v]', b'\x00P'])
self.assertEqual(resolved,
('2001:0db8:11a3:09d7:1f34:8a2e:07a0:765d', 80))

def test_wr_addr_domain_with_remote_resolve(self):
def test_build_dst_addr_domain_with_remote_resolve(self):
proto = make_socks5(self.loop)
req = proto.write_address('python.org', 80)
self.loop.run_until_complete(req)
c = proto.build_dst_address('python.org', 80)
dst_req, resolved = self.loop.run_until_complete(c)

proto._stream_writer.write.assert_called_with(b'\x03\npython.org\x00P')
self.assertEqual(dst_req, [0x03, b'\n', b'python.org', b'\x00P'])
self.assertEqual(resolved, ('python.org', 80))

def test_wr_addr_domain_with_locale_resolve(self):
def test_build_dst_addr_domain_with_locale_resolve(self):
proto = make_socks5(self.loop, rr=False)
req = proto.write_address('python.org', 80)
self.loop.run_until_complete(req)
c = proto.build_dst_address('python.org', 80)
dst_req, resolved = self.loop.run_until_complete(c)

proto._stream_writer.write.assert_called_with(
b'\x01\x7f\x00\x00\x01\x00P')
self.assertEqual(dst_req, [0x01, b'\x7f\x00\x00\x01', b'\x00P'])
self.assertEqual(resolved, ('127.0.0.1', 80))

def test_rd_addr_ipv4(self):
proto = make_socks5(
@@ -624,6 +628,5 @@ class TestSocks5Protocol(unittest.TestCase):
self.assertEqual(req.result(), (('python.org', 80), ('127.0.0.1', 80)))
proto._stream_writer.write.assert_has_calls([
mock.call(b'\x05\x02\x00\x02'),
mock.call(b'\x05\x01\x00'),
mock.call(b'\x03\npython.org\x00P')
mock.call(b'\x05\x01\x00\x03\npython.org\x00P')
])

Loading…
Cancel
Save