
(feat) Initial behaviour

Branch: main
Michal Charemza, 5 years ago
Parent commit: d6833401f4
5 changed files with 425 additions and 0 deletions:

1. .pre-commit-config.yaml (+48, -0)
2. README.md (+16, -0)
3. dnsrewriteproxy.py (+282, -0)
4. setup.py (+31, -0)
5. test.py (+48, -0)

.pre-commit-config.yaml (+48, -0)

@@ -0,0 +1,48 @@
repos:
- repo: git://github.com/pre-commit/pre-commit-hooks
  rev: v2.4.0
  hooks:
  - id: flake8
    args:
    - --max-line-length=99
  - id: check-ast
  - id: check-case-conflict
  - id: debug-statements
  - id: double-quote-string-fixer
  - id: end-of-file-fixer
- repo: https://github.com/pre-commit/mirrors-autopep8
  rev: v1.4.4
  hooks:
  - id: autopep8
    args:
    - --in-place
    - --max-line-length=99
# We run pylint from local env, to ensure modules can be found
- repo: local
  hooks:
  - id: pylint
    name: pylint
    entry: env PYTHONPATH=app python3 -m pylint.__main__
    language: system
    types: [python]
    args:
    - --disable=broad-except
    - --disable=duplicate-code
    - --disable=invalid-name
    - --disable=missing-docstring
    - --disable=no-self-use
    - --disable=possibly-unused-variable
    - --disable=protected-access
    - --disable=too-few-public-methods
    - --disable=too-many-arguments
    - --disable=too-many-branches
    - --disable=too-many-lines
    - --disable=too-many-locals
    - --disable=too-many-public-methods
    - --disable=too-many-statements
    - --disable=try-except-raise
    - --include-naming-hint=yes
    - --max-args=10
    - --max-line-length=99
    - --max-locals=25
    - --max-returns=10
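
With this configuration in place, all of the hooks can be run locally with `pre-commit run --all-files`, assuming the pre-commit tool itself is installed.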

README.md (+16, -0)

@@ -1,2 +1,18 @@
# dns-rewrite-proxy

A DNS proxy server that conditionally rewrites and filters A record requests


## Usage

```python
from dnsrewriteproxy import DnsProxy

start = DnsProxy()

# Once started, the proxy accepts UDP requests on port 53
stop = await start()

# Stopped
await stop()
```
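
For routing queries to a specific upstream server, the module also ships a `get_resolver_with_upstream` helper. A minimal sketch of wiring it in, assuming a hypothetical upstream at 10.0.0.2 (and noting that binding port 53 typically requires elevated privileges):

```python
import asyncio

from dnsrewriteproxy import DnsProxy, get_resolver_with_upstream


async def main():
    # 10.0.0.2 is a hypothetical upstream resolver address
    start = DnsProxy(get_resolver=lambda: get_resolver_with_upstream('10.0.0.2'))
    stop = await start()
    try:
        await asyncio.sleep(3600)  # serve requests for an hour
    finally:
        await stop()

asyncio.run(main())
```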

dnsrewriteproxy.py (+282, -0)

@@ -0,0 +1,282 @@
# We run a custom nameserver to respond to A record requests
#
# - Conditionally proxy some requests to the VPC-DNS server [dnsmasq can do this]
# - Respond with a CNAME to certain records [dnsmasq cannot do this]
#
# We _could_ wrap dnsmasq with some custom code, so it polls a file to work
# around its CNAME limitation, and we asynchronously update that file, but
# since we're already hitting its limits, we suspect it's overall less hacky
# to just write a DNS server that responds to requests in the way we want

from asyncio import (
    CancelledError,
    Future,
    Queue,
    create_task,
    get_running_loop,
)
from enum import (
    IntEnum,
)
import logging
import socket

from aiodnsresolver import (
    RESPONSE,
    TYPES,
    DnsRecordDoesNotExist,
    DnsResponseCode,
    Message,
    Resolver,
    ResourceRecord,
    pack,
    parse,
)


def get_socket_default():
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 53))
    return sock


def get_resolver_default():
    return Resolver()


def get_resolver_with_upstream(upstream):
    async def get_nameservers(_, __):
        for _ in range(0, 5):
            yield (0.5, (upstream, 53))

    return Resolver(get_nameservers=get_nameservers)


def get_logger_default():
    return logging.getLogger('dnsrewriteproxy')


def DnsProxy(
        get_resolver=get_resolver_default, get_logger=get_logger_default,
        get_socket=get_socket_default,
        num_workers=1000, downstream_queue_maxsize=10000, upstream_queue_maxsize=10000,
):

    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

        def __str__(self):
            return self.name

    loop = get_running_loop()
    logger = get_logger()

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers

    async def server_worker(sock, resolve):
        downstream_queue = Queue(maxsize=downstream_queue_maxsize)
        upstream_queue = Queue(maxsize=upstream_queue_maxsize)

        # It would "usually" be ok to send downstream from multiple tasks, but
        # if the socket has a full buffer, it would raise a BlockingIOError,
        # and we would need to attach a writer. We can only attach one writer
        # per underlying file, and since we have a single socket, we have a
        # single file. So we send downstream from a single task
        downstream_worker_task = create_task(downstream_worker(sock, downstream_queue))

        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently, and add responses to downstream_queue
        upstream_worker_tasks = [
            create_task(upstream_worker(resolve, upstream_queue, downstream_queue))
            for _ in range(0, num_workers)]

        try:
            while True:
                request_data, addr = await recvfrom(loop, sock, 512)
                await upstream_queue.put((request_data, addr))
        finally:
            # Finish upstream requests, which can add to the downstream
            # queue
            await upstream_queue.join()
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()

            # Ensure we have sent the responses downstream
            await downstream_queue.join()
            downstream_worker_task.cancel()

            # Wait for the tasks to really be finished
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except Exception:
                    pass
            try:
                await downstream_worker_task
            except Exception:
                pass

    async def upstream_worker(resolve, upstream_queue, downstream_queue):
        while True:
            request_data, addr = await upstream_queue.get()

            try:
                response_data = await get_response_data(resolve, request_data)
            except Exception:
                logger.exception('Exception from get_response_data %s', addr)
                upstream_queue.task_done()
                continue

            await downstream_queue.put((response_data, addr))
            upstream_queue.task_done()

    async def downstream_worker(sock, downstream_queue):
        while True:
            response_data, addr = await downstream_queue.get()
            await sendto(loop, sock, response_data, addr)
            downstream_queue.task_done()

    async def get_response_data(resolve, request_data):
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)

        if not query.qd:
            return pack(error(query, ERRORS.REFUSED))

        try:
            return pack(
                error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else
                (await proxy(resolve, query))
            )
        except Exception:
            logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(resolve, query):
        name_bytes = query.qd[0].name
        name_str = query.qd[0].name.decode('idna')

        try:
            ip_addresses = await resolve(name_str, TYPES.A)
        except DnsRecordDoesNotExist:
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            return error(query, dns_response_code_error.args[0])

        response_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=5, rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=response_records, ns=(), ar=(),
        )

    async def start():
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()

        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()
        server_worker_task = create_task(server_worker(sock, resolve))

        async def stop():
            server_worker_task.cancel()
            try:
                await server_worker_task
            except CancelledError:
                pass

            sock.close()
            await clear_cache()

        return stop

    return start


def error(query, rcode):
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=(), ns=(), ar=(),
    )


async def recvfrom(loop, sock, max_bytes):
    try:
        return sock.recvfrom(max_bytes)
    except BlockingIOError:
        pass

    def reader():
        try:
            (data, addr) = sock.recvfrom(max_bytes)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_reader(fileno)
            if not result.done():
                result.set_result((data, addr))

    fileno = sock.fileno()
    result = Future()
    loop.add_reader(fileno, reader)

    try:
        return await result
    finally:
        loop.remove_reader(fileno)


async def sendto(loop, sock, data, addr):
    # In our case, the UDP responses will always be 512 bytes or less.
    # Even if sendto sent only some of the data, there is no way for the
    # other end to reconstruct the order of the rest, so we don't include
    # any logic to send the remaining data. Since it's UDP, the client
    # already has to have retry logic

    try:
        return sock.sendto(data, addr)
    except BlockingIOError:
        pass

    def writer():
        try:
            num_bytes = sock.sendto(data, addr)
        except BlockingIOError:
            pass
        except BaseException as exception:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_exception(exception)
        else:
            loop.remove_writer(fileno)
            if not result.done():
                result.set_result(num_bytes)

    fileno = sock.fileno()
    result = Future()
    loop.add_writer(fileno, writer)

    try:
        return await result
    finally:
        loop.remove_writer(fileno)
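
Although internal to the module, `recvfrom` and `sendto` are self-contained non-blocking UDP helpers. A minimal sketch of the same pattern in isolation, assuming a hypothetical free port 3636, purely for illustration:

```python
import asyncio
import socket

from dnsrewriteproxy import recvfrom, sendto


async def echo_once():
    # Receive a single datagram and echo it back to the sender,
    # using the add_reader/add_writer pattern from above
    loop = asyncio.get_running_loop()
    sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    sock.setblocking(False)
    sock.bind(('', 3636))
    try:
        data, addr = await recvfrom(loop, sock, 512)
        await sendto(loop, sock, data, addr)
    finally:
        sock.close()

asyncio.run(echo_once())
```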

setup.py (+31, -0)

@@ -0,0 +1,31 @@
import setuptools


def long_description():
    with open('README.md', 'r') as file:
        return file.read()


setuptools.setup(
    name='dnsrewriteproxy',
    version='0.0.0',
    author='Department for International Trade',
    author_email='webops@digital.trade.gov.uk',
    description='A DNS proxy server that conditionally rewrites and filters A record requests',
    long_description=long_description(),
    long_description_content_type='text/markdown',
    url='https://github.com/uktrade/dnsrewriteproxy',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Topic :: Internet :: Name Service (DNS)',
    ],
    python_requires='>=3.7.0',
    py_modules=[
        'dnsrewriteproxy',
    ],
    install_requires=[
        'aiodnsresolver>=0.0.149',
    ],
    test_suite='test',
)
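
Since `test_suite='test'` is declared, the tests in the next file can be run with `python setup.py test`, or directly with `python -m unittest test`.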

test.py (+48, -0)

@@ -0,0 +1,48 @@
import asyncio
import socket
import unittest


from aiodnsresolver import (
    TYPES,
    Resolver,
    IPv4AddressExpiresAt,
)
from dnsrewriteproxy import (
    DnsProxy,
)


def async_test(func):
    def wrapper(*args, **kwargs):
        future = func(*args, **kwargs)
        loop = asyncio.get_event_loop()
        loop.run_until_complete(future)
    return wrapper


class TestProxy(unittest.TestCase):
    def add_async_cleanup(self, coroutine):
        self.addCleanup(asyncio.get_running_loop().run_until_complete, coroutine())

    @async_test
    async def test_e2e(self):

        def get_socket():
            sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.bind(('', 3535))
            return sock

        async def get_nameservers(_, __):
            for _ in range(0, 5):
                yield (0.5, ('127.0.0.1', 3535))

        resolve, clear_cache = Resolver(get_nameservers=get_nameservers)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket)
        stop = await start()
        self.add_async_cleanup(stop)

        response = await resolve('www.google.com', TYPES.A)
        self.assertTrue(isinstance(response[0], IPv4AddressExpiresAt))
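
One plausible follow-up test, not part of this commit: an NXDOMAIN from the upstream resolver should come back through the proxy, which the client-side resolver would surface as `DnsRecordDoesNotExist` (the same aiodnsresolver exception the proxy itself catches). A sketch of such a method on `TestProxy`, assuming `DnsRecordDoesNotExist` is added to the aiodnsresolver import and that `does-not-exist.example.com` has no A record:

```python
    @async_test
    async def test_nxdomain_propagates(self):
        # Same setup as test_e2e: proxy on port 3535, client resolver
        # pointed at it
        def get_socket():
            sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
            sock.setblocking(False)
            sock.bind(('', 3535))
            return sock

        async def get_nameservers(_, __):
            for _ in range(0, 5):
                yield (0.5, ('127.0.0.1', 3535))

        resolve, clear_cache = Resolver(get_nameservers=get_nameservers)
        self.add_async_cleanup(clear_cache)
        start = DnsProxy(get_socket=get_socket)
        stop = await start()
        self.add_async_cleanup(stop)

        with self.assertRaises(DnsRecordDoesNotExist):
            await resolve('does-not-exist.example.com', TYPES.A)
```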
