@@ -55,7 +55,7 @@ class BaseSocksProtocol(asyncio.StreamReaderProtocol): | |||
async def negotiate(self, reader, writer): | |||
try: | |||
req = self.socks_request(c.SOCKS_CMD_CONNECT) | |||
req = self.socks_request(self.cmd) | |||
self._proxy_peername, self._proxy_sockname = await req | |||
except SocksError as exc: | |||
exc = SocksError('Can not connect to %s:%s. %s' % | |||
@@ -4,17 +4,48 @@ import socket | |||
from aiohttp.test_utils import unused_port | |||
def _asyncsockpair(): | |||
'''Create a pair of sockets that are bound to each other. | |||
The function will return a tuple of two coroutine's, that | |||
each, when await'ed upon, will return the reader/writer pair.''' | |||
socka, sockb = socket.socketpair() | |||
return asyncio.open_connection(sock=socka), \ | |||
asyncio.open_connection(sock=sockb) | |||
async def _getreaderwriter(): | |||
'''Return a reader/writer pair. Any data written | |||
to the reader can be read from the writer side. | |||
returns (reader, writer).''' | |||
socka, sockb = _asyncsockpair() | |||
ardr, awrr = await socka | |||
brdr, bwrr = await sockb | |||
# don't close, as it also closes the reader as well | |||
awrr.write_eof() | |||
return ardr, bwrr | |||
class FakeSocksSrv: | |||
def __init__(self, loop, write_buff): | |||
self._loop = loop | |||
self._write_buff = write_buff | |||
self._transports = [] | |||
self._srv = None | |||
self._pipes = None | |||
self.port = unused_port() | |||
def get_reader(self): | |||
return self._pipes[0] | |||
async def __aenter__(self): | |||
transports = self._transports | |||
write_buff = self._write_buff | |||
pipes = await _getreaderwriter() | |||
self._pipes = pipes | |||
class SocksPrimitiveProtocol(asyncio.Protocol): | |||
_transport = None | |||
@@ -24,6 +55,7 @@ class FakeSocksSrv: | |||
transports.append(transport) | |||
def data_received(self, data): | |||
pipes[1].write(data) | |||
self._transport.write(write_buff) | |||
def factory(): | |||
@@ -39,7 +71,14 @@ class FakeSocksSrv: | |||
tr.close() | |||
self._srv.close() | |||
self._pipes[1].close() | |||
await self._srv.wait_closed() | |||
try: | |||
await self._pipes[1].wait_closed() | |||
except Exception: | |||
pass | |||
class FakeSocks4Srv: | |||
@@ -11,6 +11,7 @@ from aiohttp.test_utils import make_mocked_coro | |||
from aiosocks.test_utils import FakeSocksSrv, FakeSocks4Srv | |||
from aiosocks.connector import ProxyConnector, ProxyClientRequest | |||
from aiosocks.errors import SocksConnectionError | |||
from aiosocks import constants as c | |||
from async_timeout import timeout | |||
from unittest import mock | |||
@@ -117,6 +118,7 @@ async def test_socks5_datagram_success_anonymous(): | |||
portnum = 53 | |||
dst = (dname, portnum) | |||
# Fake SOCKS server UDP relay | |||
class FakeDGramTransport(asyncio.DatagramTransport): | |||
def sendto(self, data, addr=None): | |||
# Verify correct packet was receieved | |||
@@ -142,6 +144,7 @@ async def test_socks5_datagram_success_anonymous(): | |||
sockservdgram = FakeDGramTransport() | |||
# Fake the creation of the UDP relay | |||
async def fake_cde(factory, remote_addr): | |||
assert remote_addr == ('1.1.1.1', 1111) | |||
@@ -151,10 +154,15 @@ async def test_socks5_datagram_success_anonymous(): | |||
return sockservdgram, proto | |||
# Open the UDP connection | |||
with mock.patch.object(loop, 'create_datagram_endpoint', | |||
fake_cde) as m: | |||
dgram = await aiosocks.open_datagram(addr, None, dst, loop=loop) | |||
rdr = srv.get_reader() | |||
# make sure we negotiated the correct command | |||
assert (await rdr.readexactly(5))[4] == c.SOCKS_CMD_UDP_ASSOCIATE | |||
assert dgram.proxy_sockname == ('1.1.1.1', 1111) | |||
dgram.send(b'some data') | |||