aboutsummaryrefslogtreecommitdiff
path: root/python/strip_dns64.py
blob: 1cefd9f2393f40bf799969add789970d8a564b68 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
import ipaddress
import socket
import sys
import threading
from typing import Optional, Tuple, Union, List

import dns.message
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver
import dns.rrset


class DNSForwarder:
    """UDP DNS forwarder that strips DNS64-synthesized AAAA records.

    Each query received on a listen socket is relayed to ``upstream_dns``;
    AAAA records whose address falls inside ``nat64_prefix`` are removed from
    the answer section before the reply is returned to the client.

    NOTE(review): only UDP sockets are opened, so a truncated (TC) upstream
    reply is relayed as-is and the client's TCP retry has nowhere to go.
    """

    def __init__(
        self,
        upstream_dns: str,
        listen_addresses: List[Tuple[str, int]],
        nat64_prefix: Union[str, ipaddress.IPv6Network] = "64:ff9b::/96",
        timeout: Optional[float] = 5.0,
    ):
        """Store the forwarding configuration; no sockets are opened here.

        Args:
            upstream_dns: IP address of the resolver queries are relayed to.
            listen_addresses: ``(ip, port)`` pairs to bind UDP sockets on.
            nat64_prefix: IPv6 prefix whose AAAA answers are stripped. The
                default is the well-known prefix 64:ff9b::/96; pass a
                network-specific prefix if the upstream DNS64 uses one.
            timeout: Seconds to wait for each upstream reply; ``None`` waits
                indefinitely.
        """
        self.upstream_dns = upstream_dns
        self.listen_addresses = listen_addresses
        self.nat64_prefix = ipaddress.IPv6Network(nat64_prefix)
        self.timeout = timeout
        self.sockets: List[socket.socket] = []

    def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket:
        """Create a UDP socket bound to ``(listen_ip, listen_port)``.

        An address containing ':' is treated as IPv6. When binding to ``::``,
        IPV6_V6ONLY is cleared so the socket also accepts IPv4-mapped traffic.
        The socket is closed before re-raising if configuration or bind fails.
        """
        family = socket.AF_INET6 if ':' in listen_ip else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            if listen_ip == '::' and family == socket.AF_INET6:
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind((listen_ip, listen_port))
        except Exception:
            # Don't leak the descriptor when the address can't be bound.
            sock.close()
            raise
        print(f"DNS forwarder listening on {listen_ip}:{listen_port}")
        return sock

    def handle_socket(self, sock: socket.socket) -> None:
        """Serve queries arriving on ``sock`` until the socket is closed.

        Queries are processed one at a time, so a slow upstream delays later
        queries on the same socket (each wait is bounded by ``self.timeout``).
        Errors are printed and the loop continues.
        """
        while True:
            try:
                data, addr = sock.recvfrom(4096)
            except OSError as e:
                # A closed socket reports fileno() == -1; stop serving instead
                # of spinning on the same error forever.
                if sock.fileno() == -1:
                    return
                print(f"Error handling socket: {e}")
                continue
            try:
                response = self.process_query(data)
                if response:
                    sock.sendto(response, addr)
            except Exception as e:
                print(f"Error handling socket: {e}")

    def start(self) -> None:
        """Bind every listen address and serve each socket on a daemon thread.

        Addresses that fail to bind are reported and skipped. If none could be
        bound, the process exits with status 1. Otherwise this blocks joining
        the worker threads; on KeyboardInterrupt all sockets are closed and the
        process exits with status 0.
        """
        print(f"Forwarding to upstream DNS server: {self.upstream_dns}")

        threads = []
        for listen_ip, listen_port in self.listen_addresses:
            try:
                sock = self.create_socket(listen_ip, listen_port)
            except Exception as e:
                print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}")
                continue
            self.sockets.append(sock)
            thread = threading.Thread(target=self.handle_socket, args=(sock,), daemon=True)
            threads.append(thread)
            thread.start()

        if not threads:
            # Nothing is being served; don't report success to the caller/shell.
            print("No listening sockets could be created; exiting.")
            sys.exit(1)

        try:
            for thread in threads:
                thread.join()
        except KeyboardInterrupt:
            print("\nShutting down...")
            for sock in self.sockets:
                sock.close()
            sys.exit(0)

    def process_query(self, query_data: bytes) -> Optional[bytes]:
        """Relay one wire-format query upstream and return the filtered reply.

        Returns None if the query cannot be parsed, since there is nothing to
        answer. If the upstream exchange or the filtering fails, a SERVFAIL
        reply is returned so the client need not wait for its own timeout.
        """
        try:
            query = dns.message.from_wire(query_data)
        except Exception as e:
            print(f"Error processing query: {e}")
            return None

        try:
            response = dns.query.udp(query, self.upstream_dns, timeout=self.timeout)
            filtered_response = self.filter_dns64_responses(response)
            # Explicit max_size: the copied EDNS payload size describes the
            # upstream's receive buffer, not a limit on what we may send back.
            return filtered_response.to_wire(max_size=65535)
        except Exception as e:
            print(f"Error processing query: {e}")
            return self._servfail(query)

    def _servfail(self, query: "dns.message.Message") -> Optional[bytes]:
        """Build a wire-format SERVFAIL reply to ``query``; None if that fails too."""
        try:
            reply = dns.message.make_response(query)
            reply.set_rcode(dns.rcode.SERVFAIL)
            return reply.to_wire()
        except Exception as e:
            print(f"Error processing query: {e}")
            return None

    def _is_dns64_address(self, address: str) -> bool:
        """Return True if the IPv6 ``address`` lies inside ``self.nat64_prefix``."""
        return ipaddress.IPv6Address(address) in self.nat64_prefix

    def filter_dns64_responses(self, response: "dns.message.Message") -> "dns.message.Message":
        """Return a new message equal to ``response`` minus DNS64 AAAA answers.

        Header flags, opcode, rcode, EDNS settings, and the question, authority
        and additional sections are carried over (section lists are shared
        with ``response``, not copied). In the answer section, AAAA records
        inside ``self.nat64_prefix`` are dropped. Every other record keeps its
        original TTL, and RRsets left empty are omitted.
        """
        filtered_response = dns.message.Message(response.id)
        filtered_response.flags = response.flags
        # Carry over the OPT record (version, DO bit, payload size, options);
        # a negative edns version leaves EDNS disabled, matching the upstream.
        filtered_response.use_edns(
            response.edns, response.ednsflags, response.payload, options=response.options
        )
        filtered_response.set_opcode(response.opcode())
        filtered_response.set_rcode(response.rcode())

        filtered_response.question = response.question

        for rrset in response.answer:
            filtered_rrset = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype, rrset.covers)
            for rr in rrset:
                if rrset.rdtype == dns.rdatatype.AAAA and self._is_dns64_address(rr.address):
                    continue
                # Pass the TTL explicitly: add() without one leaves the new
                # RRset at TTL 0, which would stop clients from caching.
                filtered_rrset.add(rr, rrset.ttl)
            if len(filtered_rrset) > 0:
                filtered_response.answer.append(filtered_rrset)

        filtered_response.additional = response.additional
        filtered_response.authority = response.authority

        return filtered_response


def get_default_dns() -> str:
    """Return the first nameserver listed in the system resolver configuration."""
    system_nameservers = dns.resolver.Resolver().nameservers
    return system_nameservers[0]


def _is_ip_address(text: str) -> bool:
    """Return True if ``text`` is a literal IPv4 or IPv6 address."""
    try:
        ipaddress.ip_address(text)
    except ValueError:
        return False
    return True


def _parse_listen_address(addr: str) -> Tuple[str, int]:
    """Split one ``--listen`` value into ``(ip, port)``.

    Accepted forms (port defaults to 53; an empty host before ':' means ``::``):
    ``IP``, ``IP:PORT``, ``:PORT``, ``[IPv6]``, ``[IPv6]:PORT``, and a bare IPv6
    address such as ``::1``. For unbracketed input containing ':', the text
    after the last ':' is the port unless that would leave a host that is not
    an IP address while the whole string is one.

    Raises:
        ValueError: malformed brackets, non-numeric port, or port out of range.
    """
    if addr.startswith('['):
        host, closed, rest = addr[1:].partition(']')
        if not closed or (rest and not rest.startswith(':')):
            raise ValueError("expected [IPv6] or [IPv6]:PORT")
        port_text = rest[1:] if rest else "53"
    elif ':' not in addr:
        return addr, 53
    else:
        host, port_text = addr.rsplit(':', 1)
        # A bare IPv6 address like "::1" or "2001:db8::1" would otherwise be
        # split into a bogus host (":" / "2001:db8:") plus a "port".
        if host and not _is_ip_address(host) and _is_ip_address(addr):
            host, port_text = addr, "53"

    port = int(port_text)
    if not 0 <= port <= 65535:
        raise ValueError(f"port {port} out of range")
    return host or "::", port


def parse_arguments() -> Tuple[str, List[Tuple[str, int]]]:
    """Parse the command line into ``(upstream_dns, listen_addresses)``.

    ``--upstream`` defaults to the first system nameserver, looked up only when
    the option is omitted. ``--listen`` may be repeated and defaults to
    ``:::53`` (all addresses, port 53). Invalid listen values produce an
    argparse usage error (exit status 2).
    """
    parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses")
    parser.add_argument("--upstream", help="Upstream DNS server IP address", default=None)
    parser.add_argument("--listen",
                        help="IP address and port to listen on (format: IP:PORT or [IPv6]:PORT)",
                        action='append', default=None)
    args = parser.parse_args()

    listen_addresses = []
    for spec in args.listen or [":::53"]:
        try:
            listen_addresses.append(_parse_listen_address(spec))
        except ValueError as e:
            parser.error(f"invalid --listen value {spec!r}: {e}")

    # Resolve the default lazily so an explicit --upstream works even on hosts
    # without a usable resolver configuration.
    upstream = args.upstream if args.upstream is not None else get_default_dns()
    return upstream, listen_addresses


def main() -> None:
    """Parse command-line options, then construct and start the forwarder."""
    upstream, addresses = parse_arguments()
    DNSForwarder(upstream, addresses).start()


# Start the forwarder only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()