diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..e91116e
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,48 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v2.4.0
+  hooks:
+  - id: flake8
+    args:
+    - --max-line-length=99
+  - id: check-ast
+  - id: check-case-conflict
+  - id: debug-statements
+  - id: double-quote-string-fixer
+  - id: end-of-file-fixer
+- repo: https://github.com/pre-commit/mirrors-autopep8
+  rev: v1.4.4
+  hooks:
+  - id: autopep8
+    args:
+    - --in-place
+    - --max-line-length=99
+# We run pylint from local env, to ensure modules can be found
+- repo: local
+  hooks:
+  - id: pylint
+    name: pylint
+    entry: env PYTHONPATH=app python3 -m pylint.__main__
+    language: system
+    types: [python]
+    args:
+    - --disable=broad-except
+    - --disable=duplicate-code
+    - --disable=invalid-name
+    - --disable=missing-docstring
+    - --disable=no-self-use
+    - --disable=possibly-unused-variable
+    - --disable=protected-access
+    - --disable=too-few-public-methods
+    - --disable=too-many-arguments
+    - --disable=too-many-branches
+    - --disable=too-many-lines
+    - --disable=too-many-locals
+    - --disable=too-many-public-methods
+    - --disable=too-many-statements
+    - --disable=try-except-raise
+    - --include-naming-hint=yes
+    - --max-args=10
+    - --max-line-length=99
+    - --max-locals=25
+    - --max-returns=10
diff --git a/README.md b/README.md
index 1a5e1fa..1b25b01 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,18 @@
 # dns-rewrite-proxy
+
+A DNS proxy server that conditionally rewrites and filters A record requests
+
+
+## Usage
+
+```python
+from dnsrewriteproxy import DnsProxy
+
+start = DnsProxy()
+
+# Start the proxy: it then accepts UDP requests on port 53
+stop = await start()
+
+# Stop the proxy
+await stop()
+```
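
The bare `await`s in the README snippet assume an already-running event loop, and `DnsProxy()` itself calls `get_running_loop()`, so a plain script has to drive start/stop from inside a coroutine. A minimal sketch, assuming Python 3.7+ and privileges to bind port 53:

```python
import asyncio

from dnsrewriteproxy import DnsProxy


async def main():
    # DnsProxy() must be called while the loop is running, since it
    # calls get_running_loop() internally
    start = DnsProxy()
    stop = await start()
    try:
        # Serve requests; any real shutdown condition would go here
        await asyncio.sleep(3600)
    finally:
        await stop()

asyncio.run(main())
```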
diff --git a/dnsrewriteproxy.py b/dnsrewriteproxy.py
new file mode 100644
index 0000000..5358f53
--- /dev/null
+++ b/dnsrewriteproxy.py
@@ -0,0 +1,282 @@
+# We run a custom nameserver to respond to A record requests
+#
+# - Conditionally proxy some requests to the VPC DNS server [dnsmasq can do this]
+# - Respond with a CNAME to certain records [dnsmasq cannot do this]
+#
+# We _could_ wrap dnsmasq with some custom code, so it polls a file to work
+# around its CNAME limitation, and we asynchronously update that file, but
+# since we're already hitting its limits, we suspect it's overall less hacky
+# to just write a DNS server that responds to requests in the way we want
+
+from asyncio import (
+    CancelledError,
+    Future,
+    Queue,
+    create_task,
+    get_running_loop,
+)
+from enum import (
+    IntEnum,
+)
+import logging
+import socket
+
+from aiodnsresolver import (
+    RESPONSE,
+    TYPES,
+    DnsRecordDoesNotExist,
+    DnsResponseCode,
+    Message,
+    Resolver,
+    ResourceRecord,
+    pack,
+    parse,
+)
+
+
+def get_socket_default():
+    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
+    sock.setblocking(False)
+    sock.bind(('', 53))
+    return sock
+
+
+def get_resolver_default():
+    return Resolver()
+
+
+def get_resolver_with_upstream(upstream):
+    async def get_nameservers(_, __):
+        for _ in range(0, 5):
+            yield (0.5, (upstream, 53))
+
+    return Resolver(get_nameservers=get_nameservers)
+
+
+def get_logger_default():
+    return logging.getLogger('dnsrewriteproxy')
+
+
+def DnsProxy(
+        get_resolver=get_resolver_default, get_logger=get_logger_default,
+        get_socket=get_socket_default,
+        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
+):
+
+    class ERRORS(IntEnum):
+        FORMERR = 1
+        SERVFAIL = 2
+        NXDOMAIN = 3
+        REFUSED = 5
+
+        def __str__(self):
+            return self.name
+
+    loop = get_running_loop()
+    logger = get_logger()
+
+    # The "main" task of the server: it receives incoming requests and puts
+    # them on a queue that the proxy workers then fetch from and process
+
+    async def server_worker(sock, resolve):
+        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
+        upstream_queue = Queue(maxsize=upstream_queue_maxsize)
+
+        # It would "usually" be ok to send downstream from multiple tasks, but
+        # if the socket has a full buffer, it would raise a BlockingIOError,
+        # and we would need to attach a reader. We can only attach one reader
+        # per underlying file, and since we have a single socket, we have a
+        # single file. So we send downstream from a single task
+        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))
+
+        # We have multiple upstream workers to be able to send multiple
+        # requests upstream concurrently, and add responses to downstream_queue
+        upstream_worker_tasks = [
+            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
+            for _ in range(0, num_workers)]
+
+        try:
+            while True:
+                request_data, addr = await recvfrom(loop, sock, 512)
+                await upstream_queue.put((request_data, addr))
+        finally:
+            # Finish upstream requests, which can add to the downstream queue
+            await upstream_queue.join()
+            for upstream_task in upstream_worker_tasks:
+                upstream_task.cancel()
+
+            # Ensure we have sent the responses downstream
+            await downstream_queue.join()
+            downstream_worker_task.cancel()
+
+            # Wait for the tasks to really be finished. CancelledError is
+            # caught explicitly, since from Python 3.8 it is no longer a
+            # subclass of Exception
+            for upstream_task in upstream_worker_tasks:
+                try:
+                    await upstream_task
+                except (CancelledError, Exception):
+                    pass
+            try:
+                await downstream_worker_task
+            except (CancelledError, Exception):
+                pass
+
+    async def upstream_worker(resolve, upstream_queue, downstream_queue):
+        while True:
+            request_data, addr = await upstream_queue.get()
+
+            try:
+                response_data = await get_response_data(resolve, request_data)
+            except Exception:
+                logger.exception('Exception from get_response_data %s', addr)
+                upstream_queue.task_done()
+                continue
+
+            await downstream_queue.put((response_data, addr))
+            upstream_queue.task_done()
+
+    async def downstream_worker(sock, downstream_queue):
+        while True:
+            response_data, addr = await downstream_queue.get()
+            await sendto(loop, sock, response_data, addr)
+            downstream_queue.task_done()
+
+    async def get_response_data(resolve, request_data):
+        # This may raise an exception, which is handled at a higher level.
+        # We can't [and I suspect shouldn't try to] return an error to the
+        # client, since we're not able to extract the QID, so the client won't
+        # be able to match it with an outgoing request
+        query = parse(request_data)
+
+        if not query.qd:
+            return pack(error(query, ERRORS.REFUSED))
+
+        try:
+            return pack(
+                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
+                (await proxy(resolve, query))
+            )
+        except Exception:
+            logger.exception('Failed to proxy %s', query)
+            return pack(error(query, ERRORS.SERVFAIL))
+
+    async def proxy(resolve, query):
+        name_bytes = query.qd[0].name
+        name_str = query.qd[0].name.decode('idna')
+
+        try:
+            ip_addresses = await resolve(name_str, TYPES.A)
+        except DnsRecordDoesNotExist:
+            return error(query, ERRORS.NXDOMAIN)
+        except DnsResponseCode as dns_response_code_error:
+            return error(query, dns_response_code_error.args[0])
+
+        response_records = tuple(
+            ResourceRecord(name=name_bytes, qtype=TYPES.A,
+                           qclass=1, ttl=5, rdata=ip_address.packed)
+            for ip_address in ip_addresses
+        )
+        return Message(
+            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
+            qd=query.qd, an=response_records, ns=(), ar=(),
+        )
+
+    async def start():
+        # The socket is created synchronously and passed to the server worker,
+        # so if there is an error creating it, this function will raise an
+        # exception. If no exception is raised, we are indeed listening
+        sock = get_socket()
+
+        # The resolver is also created synchronously, since it can parse
+        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
+        # something goes wrong with that
+        resolve, clear_cache = get_resolver()
+        server_worker_task = create_task(server_worker(sock, resolve))
+
+        async def stop():
+            server_worker_task.cancel()
+            try:
+                await server_worker_task
+            except CancelledError:
+                pass
+
+            sock.close()
+            await clear_cache()
+
+        return stop
+
+    return start
+
+
+def error(query, rcode):
+    return Message(
+        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
+        qd=query.qd, an=(), ns=(), ar=(),
+    )
+
+
+async def recvfrom(loop, sock, max_bytes):
+    try:
+        return sock.recvfrom(max_bytes)
+    except BlockingIOError:
+        pass
+
+    def reader():
+        try:
+            (data, addr) = sock.recvfrom(max_bytes)
+        except BlockingIOError:
+            pass
+        except BaseException as exception:
+            loop.remove_reader(fileno)
+            if not result.done():
+                result.set_exception(exception)
+        else:
+            loop.remove_reader(fileno)
+            if not result.done():
+                result.set_result((data, addr))
+
+    fileno = sock.fileno()
+    result = Future()
+    loop.add_reader(fileno, reader)
+
+    try:
+        return await result
+    finally:
+        loop.remove_reader(fileno)
+
+
+async def sendto(loop, sock, data, addr):
+    # In our case, the UDP responses will always be 512 bytes or less.
+    # Even if sendto sent only some of the data, there would be no way for
+    # the other end to reconstruct the message, so we don't include any logic
+    # to send the rest of the data. Since it's UDP, the client already has to
+    # have retry logic
+
+    try:
+        return sock.sendto(data, addr)
+    except BlockingIOError:
+        pass
+
+    def writer():
+        try:
+            num_bytes = sock.sendto(data, addr)
+        except BlockingIOError:
+            pass
+        except BaseException as exception:
+            loop.remove_writer(fileno)
+            if not result.done():
+                result.set_exception(exception)
+        else:
+            loop.remove_writer(fileno)
+            if not result.done():
+                result.set_result(num_bytes)
+
+    fileno = sock.fileno()
+    result = Future()
+    loop.add_writer(fileno, writer)
+
+    try:
+        return await result
+    finally:
+        loop.remove_writer(fileno)
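
The factory exposes the resolver, socket, logger, worker count and both queue bounds as injectable parameters, so behaviour is customized by passing functions rather than subclassing. A hedged sketch of a tuned construction, with `10.0.0.2` standing in for a real upstream nameserver (not part of this change):

```python
import asyncio

from dnsrewriteproxy import DnsProxy, get_resolver_with_upstream


async def main():
    # 10.0.0.2 is a stand-in upstream address; get_resolver must return
    # the (resolve, clear_cache) pair that Resolver() produces
    start = DnsProxy(
        get_resolver=lambda: get_resolver_with_upstream('10.0.0.2'),
        num_workers=50,                 # fewer concurrent upstream resolutions
        downstream_queue_maxsize=1000,  # bound memory under load
        upstream_queue_maxsize=1000,
    )
    stop = await start()
    try:
        await asyncio.sleep(3600)
    finally:
        await stop()

asyncio.run(main())
```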
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..894d214
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,31 @@
+import setuptools
+
+
+def long_description():
+    with open('README.md', 'r') as file:
+        return file.read()
+
+
+setuptools.setup(
+    name='dnsrewriteproxy',
+    version='0.0.0',
+    author='Department for International Trade',
+    author_email='webops@digital.trade.gov.uk',
+    description='A DNS proxy server that conditionally rewrites and filters A record requests',
+    long_description=long_description(),
+    long_description_content_type='text/markdown',
+    url='https://github.com/uktrade/dnsrewriteproxy',
+    classifiers=[
+        'Programming Language :: Python :: 3',
+        'License :: OSI Approved :: MIT License',
+        'Topic :: Internet :: Name Service (DNS)',
+    ],
+    python_requires='>=3.7.0',
+    py_modules=[
+        'dnsrewriteproxy',
+    ],
+    install_requires=[
+        'aiodnsresolver>=0.0.149',
+    ],
+    test_suite='test',
+)
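
Since `recvfrom` and `sendto` in dnsrewriteproxy.py are plain module-level coroutines, the try-synchronously-then-`add_reader`/`add_writer` pattern they implement can be exercised on its own. A minimal one-datagram UDP echo sketch, with 3535 as an arbitrary unprivileged port:

```python
import asyncio
import socket

from dnsrewriteproxy import recvfrom, sendto


async def echo_one_datagram():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 3535))
    loop = asyncio.get_running_loop()
    try:
        # Waits (cooperatively) until a datagram arrives, then echoes it back
        data, addr = await recvfrom(loop, sock, 512)
        await sendto(loop, sock, data, addr)
    finally:
        sock.close()

asyncio.run(echo_one_datagram())
```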
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..a84dfcb
--- /dev/null
+++ b/test.py
@@ -0,0 +1,48 @@
+import asyncio
+import socket
+import unittest
+
+
+from aiodnsresolver import (
+    TYPES,
+    Resolver,
+    IPv4AddressExpiresAt,
+)
+from dnsrewriteproxy import (
+    DnsProxy,
+)
+
+
+def async_test(func):
+    def wrapper(*args, **kwargs):
+        future = func(*args, **kwargs)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(future)
+    return wrapper
+
+
+class TestProxy(unittest.TestCase):
+    def add_async_cleanup(self, coroutine):
+        self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())
+
+    @async_test
+    async def test_e2e(self):
+
+        def get_socket():
+            sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
+            sock.setblocking(False)
+            sock.bind(('', 3535))
+            return sock
+
+        async def get_nameservers(_, __):
+            for _ in range(0, 5):
+                yield (0.5, ('127.0.0.1', 3535))
+
+        resolve, clear_cache = Resolver(get_nameservers=get_nameservers)
+        self.add_async_cleanup(clear_cache)
+        start = DnsProxy(get_socket=get_socket)
+        stop = await start()
+        self.add_async_cleanup(stop)
+
+        response = await resolve('www.google.com', TYPES.A)
+        self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
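
The e2e test doubles as usage documentation: it moves the proxy's socket onto unprivileged port 3535 and points a second aiodnsresolver `Resolver` at it over loopback. The same wiring works as a quick manual check outside unittest; a sketch, reusing the test's arbitrary port and name:

```python
import asyncio
import socket

from aiodnsresolver import TYPES, Resolver
from dnsrewriteproxy import DnsProxy


async def main():
    def get_socket():
        sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
        sock.setblocking(False)
        sock.bind(('', 3535))
        return sock

    async def get_nameservers(_, __):
        for _ in range(0, 5):
            yield (0.5, ('127.0.0.1', 3535))

    start = DnsProxy(get_socket=get_socket)
    stop = await start()
    resolve, clear_cache = Resolver(get_nameservers=get_nameservers)
    try:
        # Resolves via the local proxy, which in turn uses the system resolver
        print(await resolve('www.google.com', TYPES.A))
    finally:
        await clear_cache()
        await stop()

asyncio.run(main())
```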