diff options
Diffstat (limited to 'python/strip_dns64.py')
-rw-r--r-- | python/strip_dns64.py | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/python/strip_dns64.py b/python/strip_dns64.py new file mode 100644 index 0000000..1cefd9f --- /dev/null +++ b/python/strip_dns64.py @@ -0,0 +1,154 @@ +import argparse
+import ipaddress
+import socket
+import sys
+import threading
+from typing import Optional, Tuple, Union, List
+
+import dns.message
+import dns.query
+import dns.resolver
+import dns.rrset
+
+
+class DNSForwarder:
+ def __init__(self, upstream_dns: str, listen_addresses: List[Tuple[str, int]]):
+ self.upstream_dns = upstream_dns
+ self.listen_addresses = listen_addresses
+ self.nat64_prefix = ipaddress.IPv6Network("64:ff9b::/96")
+ self.sockets = []
+
+ def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket:
+ if ':' in listen_ip:
+ socket_family = socket.AF_INET6
+ else:
+ socket_family = socket.AF_INET
+
+ sock = socket.socket(socket_family, socket.SOCK_DGRAM)
+
+ if listen_ip == '::' and socket_family == socket.AF_INET6:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+
+ sock.bind((listen_ip, listen_port))
+ print(f"DNS forwarder listening on {listen_ip}:{listen_port}")
+ return sock
+
+ def handle_socket(self, sock: socket.socket) -> None:
+ while True:
+ try:
+ data, addr = sock.recvfrom(4096)
+ response = self.process_query(data)
+ if response:
+ sock.sendto(response, addr)
+ except Exception as e:
+ print(f"Error handling socket: {e}")
+
+ def start(self) -> None:
+ print(f"Forwarding to upstream DNS server: {self.upstream_dns}")
+
+ # Create sockets for all listen addresses
+ threads = []
+ for listen_ip, listen_port in self.listen_addresses:
+ try:
+ sock = self.create_socket(listen_ip, listen_port)
+ self.sockets.append(sock)
+
+ # Create a thread for each socket
+ thread = threading.Thread(target=self.handle_socket, args=(sock,))
+ thread.daemon = True
+ threads.append(thread)
+ thread.start()
+ except Exception as e:
+ print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}")
+
+ # Wait for all threads
+ try:
+ for thread in threads:
+ thread.join()
+ except KeyboardInterrupt:
+ print("\nShutting down...")
+ for sock in self.sockets:
+ sock.close()
+ sys.exit(0)
+
+ def process_query(self, query_data: bytes) -> Optional[bytes]:
+ try:
+ query = dns.message.from_wire(query_data)
+ response = dns.query.udp(query, self.upstream_dns)
+
+ filtered_response = self.filter_dns64_responses(response)
+
+ return filtered_response.to_wire()
+ except Exception as e:
+ print(f"Error processing query: {e}")
+ return None
+
+ def filter_dns64_responses(self, response: dns.message.Message) -> dns.message.Message:
+ filtered_response = dns.message.Message(response.id)
+ filtered_response.flags = response.flags
+ filtered_response.set_opcode(response.opcode())
+ filtered_response.set_rcode(response.rcode())
+
+ # Copy questions
+ filtered_response.question = response.question
+
+ for rrset in response.answer:
+ filtered_rrset = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype)
+
+ for rr in rrset:
+ if rrset.rdtype == dns.rdatatype.AAAA:
+ ip = ipaddress.IPv6Address(rr.address)
+ if not ip in self.nat64_prefix:
+ filtered_rrset.add(rr)
+ else:
+ filtered_rrset.add(rr)
+
+ if len(filtered_rrset) > 0:
+ filtered_response.answer.append(filtered_rrset)
+
+ # Copy additional and authority sections
+ filtered_response.additional = response.additional
+ filtered_response.authority = response.authority
+
+ return filtered_response
+
+
+def get_default_dns() -> str:
+ resolver = dns.resolver.Resolver()
+ return resolver.nameservers[0]
+
+
+def parse_arguments() -> Tuple[str, List[Tuple[str, int]]]:
+ parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses")
+ parser.add_argument("--upstream", help="Upstream DNS server IP address", default=get_default_dns())
+ parser.add_argument("--listen", help="IP address and port to listen on (format: IP:PORT)",
+ action='append', default=[])
+ args = parser.parse_args()
+
+ # If no listen addresses specified, use default
+ if not args.listen:
+ args.listen = [":::53"]
+
+ # Parse listen addresses
+ listen_addresses = []
+ for addr in args.listen:
+ if ':' in addr:
+ ip, port = addr.rsplit(':', 1)
+ if not ip:
+ ip = "::"
+ else:
+ ip = addr
+ port = "53"
+ listen_addresses.append((ip, int(port)))
+
+ return args.upstream, listen_addresses
+
+
+def main() -> None:
+ upstream_dns, listen_addresses = parse_arguments()
+ forwarder = DNSForwarder(upstream_dns, listen_addresses)
+ forwarder.start()
+
+
+if __name__ == "__main__":
+ main()
|