diff --git a/dnsrewriteproxy.py b/dnsrewriteproxy.py index 46ed465..81ace8e 100644 --- a/dnsrewriteproxy.py +++ b/dnsrewriteproxy.py @@ -143,10 +143,7 @@ def DnsProxy( query = parse(request_data) try: - return pack( - error(query, ERRORS.REFUSED) if query.qd[0].qtype != TYPES.A else - (await proxy(request_logger, resolve, query)) - ) + return pack(await proxy(request_logger, resolve, query)) except Exception: request_logger.exception('Failed to proxy %s', query) return pack(error(query, ERRORS.SERVFAIL)) @@ -158,6 +155,10 @@ def DnsProxy( name_str_lower = query.qd[0].name.lower().decode('idna') request_logger.info('Decoded: %s', name_str_lower) + if query.qd[0].qtype != TYPES.A: + request_logger.info('Unhandled query type: %s', query.qd[0].qtype) + return error(query, ERRORS.REFUSED) + for pattern, replace in rules: rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower) if num_matches: diff --git a/test.py b/test.py index bde52e7..8ef046b 100644 --- a/test.py +++ b/test.py @@ -63,6 +63,19 @@ class TestProxy(unittest.TestCase): self.assertEqual(type(response[0]), IPv4AddressExpiresAt) + @async_test + async def test_e2e_match_all_wrong_type(self): + resolve, clear_cache = get_resolver(3535) + self.add_async_cleanup(clear_cache) + start = DnsProxy(get_socket=get_socket(3535), rules=((r'(^.*$)', r'\1'),)) + server_task = await start() + self.add_async_cleanup(await_cancel, server_task) + + with self.assertRaises(DnsResponseCode) as cm: + await resolve('www.google.com', TYPES.AAAA) + + self.assertEqual(cm.exception.args[0], 5) + @async_test async def test_e2e_default_port_match_all(self): resolve, clear_cache = get_resolver(53)