@@ -55,7 +55,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||||
async def negotiate(self, reader, writer): | async def negotiate(self, reader, writer): | ||||
try: | try: | ||||
req = self.socks_request(c.SOCKS_CMD_CONNECT) | |||||
req = self.socks_request(self.cmd) | |||||
self._proxy_peername, self._proxy_sockname = await req | self._proxy_peername, self._proxy_sockname = await req | ||||
except SocksError as exc: | except SocksError as exc: | ||||
exc = SocksError('Can not connect to %s:%s. %s' % | exc = SocksError('Can not connect to %s:%s. %s' % | ||||
@@ -4,17 +4,48 @@ import socket | |||||
from aiohttp.test_utils import unused_port | from aiohttp.test_utils import unused_port | ||||
def _asyncsockpair(): | |||||
'''Create a pair of sockets that are bound to each other. | |||||
The function will return a tuple of two coroutine's, that | |||||
each, when await'ed upon, will return the reader/writer pair.''' | |||||
socka, sockb = socket.socketpair() | |||||
return asyncio.open_connection(sock=socka), \ | |||||
asyncio.open_connection(sock=sockb) | |||||
async def _getreaderwriter(): | |||||
'''Return a reader/writer pair. Any data written | |||||
to the reader can be read from the writer side. | |||||
returns (reader, writer).''' | |||||
socka, sockb = _asyncsockpair() | |||||
ardr, awrr = await socka | |||||
brdr, bwrr = await sockb | |||||
# don't close, as it also closes the reader as well | |||||
awrr.write_eof() | |||||
return ardr, bwrr | |||||
class FakeSocksSrv: | class FakeSocksSrv: | ||||
def __init__(self, loop, write_buff): | def __init__(self, loop, write_buff): | ||||
self._loop = loop | self._loop = loop | ||||
self._write_buff = write_buff | self._write_buff = write_buff | ||||
self._transports = [] | self._transports = [] | ||||
self._srv = None | self._srv = None | ||||
self._pipes = None | |||||
self.port = unused_port() | self.port = unused_port() | ||||
def get_reader(self): | |||||
return self._pipes[0] | |||||
async def __aenter__(self): | async def __aenter__(self): | ||||
transports = self._transports | transports = self._transports | ||||
write_buff = self._write_buff | write_buff = self._write_buff | ||||
pipes = await _getreaderwriter() | |||||
self._pipes = pipes | |||||
class SocksPrimitiveProtocol(asyncio.Protocol): | class SocksPrimitiveProtocol(asyncio.Protocol): | ||||
_transport = None | _transport = None | ||||
@@ -24,6 +55,7 @@ class FakeSocksSrv: | |||||
transports.append(transport) | transports.append(transport) | ||||
def data_received(self, data): | def data_received(self, data): | ||||
pipes[1].write(data) | |||||
self._transport.write(write_buff) | self._transport.write(write_buff) | ||||
def factory(): | def factory(): | ||||
@@ -39,7 +71,14 @@ class FakeSocksSrv: | |||||
tr.close() | tr.close() | ||||
self._srv.close() | self._srv.close() | ||||
self._pipes[1].close() | |||||
await self._srv.wait_closed() | await self._srv.wait_closed() | ||||
try: | |||||
await self._pipes[1].wait_closed() | |||||
except Exception: | |||||
pass | |||||
class FakeSocks4Srv: | class FakeSocks4Srv: | ||||
@@ -11,6 +11,7 @@ from aiohttp.test_utils import make_mocked_coro | |||||
from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv | from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv | ||||
from aiosocks.connector import ProxyConnector, ProxyClientRequest | from aiosocks.connector import ProxyConnector, ProxyClientRequest | ||||
from aiosocks.errors import SocksConnectionError | from aiosocks.errors import SocksConnectionError | ||||
from aiosocks import constants as c | |||||
from async_timeout import timeout | from async_timeout import timeout | ||||
from unittest import mock | from unittest import mock | ||||
@@ -117,6 +118,7 @@ async def test_socks5_datagram_success_anonymous(): | |||||
portnum = 53 | portnum = 53 | ||||
dst = (dname, portnum) | dst = (dname, portnum) | ||||
# Fake SOCKS server UDP relay | |||||
class FakeDGramTransport(asyncio.DatagramTransport): | class FakeDGramTransport(asyncio.DatagramTransport): | ||||
def sendto(self, data, addr=None): | def sendto(self, data, addr=None): | ||||
# Verify correct packet was receieved | # Verify correct packet was receieved | ||||
@@ -142,6 +144,7 @@ async def test_socks5_datagram_success_anonymous(): | |||||
sockservdgram = FakeDGramTransport() | sockservdgram = FakeDGramTransport() | ||||
# Fake the creation of the UDP relay | |||||
async def fake_cde(factory, remote_addr): | async def fake_cde(factory, remote_addr): | ||||
assert remote_addr == ('1.1.1.1', 1111) | assert remote_addr == ('1.1.1.1', 1111) | ||||
@@ -151,10 +154,15 @@ async def test_socks5_datagram_success_anonymous(): | |||||
return sockservdgram, proto | return sockservdgram, proto | ||||
# Open the UDP connection | |||||
with mock.patch.object(loop, 'create_datagram_endpoint', | with mock.patch.object(loop, 'create_datagram_endpoint', | ||||
fake_cde) as m: | fake_cde) as m: | ||||
dgram = await aiosocks.open_datagram(addr, None, dst, loop=loop) | dgram = await aiosocks.open_datagram(addr, None, dst, loop=loop) | ||||
rdr = srv.get_reader() | |||||
# make sure we negotiated the correct command | |||||
assert (await rdr.readexactly(5))[4] == c.SOCKS_CMD_UDP_ASSOCIATE | |||||
assert dgram.proxy_sockname == ('1.1.1.1', 1111) | assert dgram.proxy_sockname == ('1.1.1.1', 1111) | ||||
dgram.send(b'some data') | dgram.send(b'some data') | ||||