diff --git a/aiosocks/protocols.py b/aiosocks/protocols.py index ca53747..6b1be69 100644 --- a/aiosocks/protocols.py +++ b/aiosocks/protocols.py @@ -150,7 +150,11 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): @asyncio.coroutine def read_response(self, n): - return (yield from self._stream_reader.readexactly(n)) + try: + return (yield from self._stream_reader.readexactly(n)) + except asyncio.IncompleteReadError as e: + raise InvalidServerReply( + 'Server sent fewer bytes than required (%s)' % str(e)) @asyncio.coroutine def _get_dst_addr(self): diff --git a/tests/test_protocols.py b/tests/test_protocols.py index c6b0226..c451be2 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -310,6 +310,15 @@ class TestBaseSocksProtocol(unittest.TestCase): loop=self.loop) self.assertEqual(proto.reader._limit, 15) + def test_incomplete_error(self): + proto = BaseSocksProtocol(None, None, ('python.org', 80), + None, None, reader_limit=10, + loop=self.loop) + proto._stream_reader.readexactly = fake_coroutine( + asyncio.IncompleteReadError(b'part', 5)) + with self.assertRaises(aiosocks.InvalidServerReply): + self.loop.run_until_complete(proto.read_response(4)) + class TestSocks4Protocol(unittest.TestCase): def setUp(self):