1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
|
import argparse
import ipaddress
import socket
import sys
import threading
from typing import Optional, Tuple, Union, List

import dns.message
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver
import dns.rrset
class DNSForwarder:
    """UDP DNS forwarder that strips NAT64-synthesized AAAA records.

    Each query received on a listen socket is sent to the upstream server with
    ``dns.query.udp``. AAAA records whose address lies inside the NAT64 prefix
    are removed from the answer section before the reply is relayed back.
    """

    def __init__(
        self,
        upstream_dns: str,
        listen_addresses: List[Tuple[str, int]],
        nat64_prefix: str = "64:ff9b::/96",
        upstream_timeout: Optional[float] = 5.0,
    ):
        """
        Args:
            upstream_dns: Address of the upstream DNS server.
            listen_addresses: ``(ip, port)`` pairs to bind UDP listeners on.
            nat64_prefix: IPv6 network whose AAAA answers are stripped; the
                default is the RFC 6052 well-known prefix.
            upstream_timeout: Seconds to wait for an upstream reply before the
                client is answered with SERVFAIL; ``None`` waits forever.
        """
        self.upstream_dns = upstream_dns
        self.listen_addresses = listen_addresses
        self.nat64_prefix = ipaddress.IPv6Network(nat64_prefix)
        self.upstream_timeout = upstream_timeout
        self.sockets: List[socket.socket] = []

    def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket:
        """Create a UDP socket bound to ``(listen_ip, listen_port)``.

        An address containing ':' is treated as IPv6. For the IPv6 wildcard
        '::' the IPV6_V6ONLY option is cleared so the socket can also receive
        IPv4 traffic where the OS supports dual-stack sockets.

        Raises:
            OSError: if the socket cannot be created, configured or bound. A
                socket that was created is closed before the error propagates.
        """
        family = socket.AF_INET6 if ':' in listen_ip else socket.AF_INET
        sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            if listen_ip == '::' and family == socket.AF_INET6:
                sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            sock.bind((listen_ip, listen_port))
        except BaseException:
            # Don't leak the descriptor when bind/setsockopt fails.
            sock.close()
            raise
        print(f"DNS forwarder listening on {listen_ip}:{listen_port}")
        return sock

    def handle_socket(self, sock: socket.socket) -> None:
        """Serve queries arriving on ``sock`` one at a time until it is closed.

        Each datagram is passed to :meth:`process_query`. A non-empty result is
        sent back to the sender's address. Per-packet errors are logged and
        the loop continues.
        """
        while True:
            try:
                data, addr = sock.recvfrom(4096)
                response = self.process_query(data)
                if response:
                    sock.sendto(response, addr)
            except Exception as e:
                if sock.fileno() == -1:
                    # The socket was closed (e.g. during shutdown); stop instead
                    # of looping forever on "bad file descriptor" errors.
                    return
                print(f"Error handling socket: {e}")

    def start(self) -> None:
        """Bind every listen address and serve until interrupted.

        One daemon thread per bound socket runs :meth:`handle_socket`.
        Addresses that fail to bind are reported and skipped. If none can be
        bound, the process exits with status 1. Ctrl+C closes all sockets and
        exits with status 0.
        """
        print(f"Forwarding to upstream DNS server: {self.upstream_dns}")
        threads = []
        for listen_ip, listen_port in self.listen_addresses:
            try:
                sock = self.create_socket(listen_ip, listen_port)
                self.sockets.append(sock)
                thread = threading.Thread(target=self.handle_socket, args=(sock,), daemon=True)
                threads.append(thread)
                thread.start()
            except Exception as e:
                print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}")
        if not threads:
            # Previously the process silently exited 0 here.
            print("No listen sockets could be created; exiting.")
            sys.exit(1)
        try:
            for thread in threads:
                thread.join()
        except KeyboardInterrupt:
            print("\nShutting down...")
            for sock in self.sockets:
                sock.close()
            sys.exit(0)

    def process_query(self, query_data: bytes) -> Optional[bytes]:
        """Forward one wire-format query upstream and return the filtered reply.

        Returns ``None`` (nothing is sent) when ``query_data`` cannot be parsed
        as a DNS message. If the upstream exchange or filtering fails,
        including on timeout, a SERVFAIL response is returned so the client
        is not left waiting for its own timeout.
        """
        try:
            query = dns.message.from_wire(query_data)
        except Exception as e:
            print(f"Error processing query: {e}")
            return None
        try:
            response = dns.query.udp(query, self.upstream_dns, timeout=self.upstream_timeout)
            return self.filter_dns64_responses(response).to_wire()
        except Exception as e:
            print(f"Error processing query: {e}")
        try:
            servfail = dns.message.make_response(query)
            servfail.set_rcode(dns.rcode.SERVFAIL)
            return servfail.to_wire()
        except Exception as e:
            # e.g. the datagram was itself a response, not a query.
            print(f"Error building SERVFAIL response: {e}")
            return None

    def filter_dns64_responses(self, response: dns.message.Message) -> dns.message.Message:
        """Return a copy of ``response`` without NAT64-prefixed AAAA answers.

        Only the answer section is filtered. AAAA records whose address lies
        in ``self.nat64_prefix`` are dropped, and an AAAA RRset left empty is
        omitted. These are carried over unchanged:
        - non-AAAA answer RRsets;
        - the question, authority and additional sections;
        - header flags, opcode, rcode and EDNS settings.
        ``response`` itself is not modified.
        """
        filtered_response = dns.message.Message(response.id)
        filtered_response.flags = response.flags
        # dnspython keeps EDNS state on the Message rather than in the section
        # lists copied below, so carry it over explicitly (payload size,
        # options, DO bit).
        filtered_response.use_edns(response.edns, response.ednsflags, response.payload,
                                   options=response.options)
        filtered_response.set_opcode(response.opcode())
        filtered_response.set_rcode(response.rcode())
        filtered_response.question = response.question
        for rrset in response.answer:
            if rrset.rdtype != dns.rdatatype.AAAA:
                filtered_response.answer.append(rrset)
                continue
            kept = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype, rrset.covers)
            # A fresh RRset starts with TTL 0; restore the upstream TTL so
            # clients can still cache the answer.
            kept.update_ttl(rrset.ttl)
            for rr in rrset:
                if ipaddress.IPv6Address(rr.address) not in self.nat64_prefix:
                    kept.add(rr)
            if len(kept) > 0:
                filtered_response.answer.append(kept)
        filtered_response.additional = response.additional
        filtered_response.authority = response.authority
        return filtered_response
def get_default_dns() -> str:
    """Return the first nameserver from the system resolver configuration.

    The configuration is whatever ``dns.resolver.Resolver()`` loads when it is
    constructed. Indexing assumes at least one nameserver is configured.
    """
    return dns.resolver.Resolver().nameservers[0]
def _is_ip_address(text: str) -> bool:
    """Return True if *text* is a literal IPv4 or IPv6 address."""
    try:
        ipaddress.ip_address(text)
    except ValueError:
        return False
    return True


def _split_listen_address(addr: str) -> Tuple[str, str]:
    """Split one --listen value into ``(host, port_text)``.

    Accepted forms:
    - ``IP:PORT``, split at the last ':';
    - ``[IPv6]:PORT`` and ``[IPv6]``;
    - ``:PORT``, where the host defaults to '::';
    - a bare address with no port, where the port defaults to 53.
    An unbracketed value such as ``fe80::1`` is taken as a bare address when
    its part before the last ':' is not an IP but the whole value is.

    Raises:
        ValueError: for an unterminated or malformed bracketed address.
    """
    if addr.startswith("["):
        host, bracket, rest = addr[1:].partition("]")
        if not bracket or (rest and not rest.startswith(":")):
            raise ValueError("malformed bracketed address")
        return (host or "::"), (rest[1:] if rest else "53")
    if ':' not in addr:
        return addr, "53"
    host, _, port_text = addr.rpartition(':')
    if not _is_ip_address(host) and _is_ip_address(addr):
        # e.g. "::" or "fe80::1": an IPv6 address given without a port.
        return addr, "53"
    return (host or "::"), port_text


def parse_arguments(argv: Optional[List[str]] = None) -> Tuple[str, List[Tuple[str, int]]]:
    """Parse the command line into ``(upstream_dns, listen_addresses)``.

    Args:
        argv: Arguments to parse. ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        The upstream server address and a list of ``(ip, port)`` pairs to bind.
        Without ``--listen`` the forwarder binds ``[::]:53``. Without
        ``--upstream`` the first system nameserver is used; it is looked up
        only in that case, so a missing resolver configuration doesn't break
        an explicit ``--upstream``.

    Exits through ``parser.error`` (status 2) on a malformed ``--listen`` value.
    """
    parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses")
    parser.add_argument("--upstream", help="Upstream DNS server IP address", default=None)
    parser.add_argument("--listen", help="IP address and port to listen on (format: IP:PORT)",
                        action='append', default=None)
    args = parser.parse_args(argv)
    listen_specs = args.listen or [":::53"]
    listen_addresses = []
    for addr in listen_specs:
        try:
            ip, port_text = _split_listen_address(addr)
            port = int(port_text)
            if not 0 <= port <= 65535:
                raise ValueError("port out of range")
        except ValueError as e:
            parser.error(f"invalid --listen value {addr!r}: {e}")
        listen_addresses.append((ip, port))
    upstream = args.upstream if args.upstream is not None else get_default_dns()
    return upstream, listen_addresses
def main() -> None:
    """Command-line entry point: build a forwarder from argv and run it."""
    upstream, listeners = parse_arguments()
    DNSForwarder(upstream, listeners).start()
if __name__ == "__main__":
    # Start the forwarder only when run as a script, not when imported.
    main()
|