diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | Dockerfile | 31 | ||||
-rw-r--r-- | README | 27 | ||||
-rwxr-xr-x | entrypoint.sh | 2 | ||||
-rw-r--r-- | go.mod | 5 | ||||
-rw-r--r-- | go.sum | 2 | ||||
-rw-r--r-- | main.go | 223 | ||||
-rw-r--r-- | python/strip_dns64.py | 154 |
8 files changed, 445 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c65bd2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +strip-dns64 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..36c0691 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,31 @@ +FROM docker.io/library/golang:1.22-alpine AS builder + +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +RUN CGO_ENABLED=0 GOOS=linux go build -o strip-dns64 + +FROM alpine:3.19 + +RUN apk --no-cache add ca-certificates libcap + +WORKDIR /app +COPY --from=builder /app/strip-dns64 . +COPY entrypoint.sh . + +RUN adduser -D -H -h /app dnsuser +RUN chown dnsuser:dnsuser /app/strip-dns64 /app/entrypoint.sh +RUN chmod +x /app/entrypoint.sh + +RUN setcap cap_net_bind_service=+ep /app/strip-dns64 + +ENV UPSTREAM_DNS="[2606:4700:4700::1111]:53" +ENV LISTEN_ADDRS="[::]:53" + +USER dnsuser + +ENTRYPOINT ["/app/entrypoint.sh"] + +EXPOSE 53/udp @@ -0,0 +1,27 @@ +strip-dns64 is a DNS forwarder which removes synthesized AAAA answers from +upstream DNS64 servers, implemented initially in Python and then Go for improved +performance + +podman build -t strip-dns64 . + +podman run -d \ + --name strip-dns64 \ + --network host \ + -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \ + -e LISTEN_ADDRS="127.0.0.1:53,[::]:53" \ + strip-dns64 + +podman run -d \ + --name strip-dns64 \ + --cap-add=NET_BIND_SERVICE \ + -p 53:53/udp \ + -p [::]:53:53/udp \ + -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \ + strip-dns64 + +Usage of ./strip-dns64: + -listen string + Comma-separated list of IP:PORT to listen on (default "[::]:53") + -upstream string + Upstream DNS server (format: ip:port or [ipv6]:port) (default "[2606:4700:4700::1111]:53") + diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 0000000..d7cd260 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,2 @@ +#!/bin/sh +exec /app/strip-dns64 -upstream "${UPSTREAM_DNS}" -listen "${LISTEN_ADDRS}" @@ -0,0 +1,5 @@ +module strip-dns64 + +go 1.22.2 + +require golang.org/x/net v0.30.0 @@ -0,0 +1,2 @@ +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= @@ -0,0 +1,223 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +type DNSForwarder struct { + upstreamDNS string + listenAddresses []ListenAddress + nat64Prefix *net.IPNet +} + +type ListenAddress struct { + IP string + Port int +} + +func NewDNSForwarder(upstreamDNS string, listenAddresses []ListenAddress) *DNSForwarder { + _, nat64Network, _ := net.ParseCIDR("64:ff9b::/96") + return &DNSForwarder{ + upstreamDNS: upstreamDNS, + listenAddresses: listenAddresses, + nat64Prefix: nat64Network, + } +} + +func (df *DNSForwarder) createSocket(listenIP string, listenPort int) (*net.UDPConn, error) { + addr := &net.UDPAddr{ + Port: listenPort, + } + + if listenIP == "::" { + addr.IP = net.IPv6zero + } else { + addr.IP = net.ParseIP(listenIP) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + + log.Printf("DNS forwarder listening on %s:%d", listenIP, listenPort) + return conn, nil +} + +func (df *DNSForwarder) handleSocket(conn *net.UDPConn) { + buffer := make([]byte, 4096) + for { + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + log.Printf("Error reading from socket: %v", err) + continue + } + + go func(data []byte, addr *net.UDPAddr) { + response := df.processQuery(data[:n]) + if response != nil { + _, err := conn.WriteToUDP(response, addr) + if err != nil { + log.Printf("Error sending response: %v", err) + } + } + }(buffer[:n], remoteAddr) + } +} + +func (df *DNSForwarder) processQuery(queryData []byte) []byte { + // Check minimum DNS message size + if len(queryData) < 12 { + log.Printf("Query too short to be valid DNS message") + return nil + } + + var msg dnsmessage.Message + if err := msg.Unpack(queryData); err != nil { + log.Printf("Error unpacking DNS message: %v", err) + errorMsg := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: msg.Header.ID, + Response: true, + RCode: dnsmessage.RCodeFormatError, + }, + } + response, packErr := errorMsg.Pack() + if packErr != nil { + log.Printf("Error packing error response: %v", packErr) + return nil + } + return response + } + + if len(msg.Questions) == 0 { + log.Printf("DNS query contains no questions") + return nil + } + + upstreamConn, err := net.Dial("udp", df.upstreamDNS) + if err != nil { + log.Printf("Error connecting to upstream DNS: %v", err) + return nil + } + defer upstreamConn.Close() + + // Set a reasonable timeout for upstream queries (5 seconds) + upstreamConn.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err = upstreamConn.Write(queryData) + if err != nil { + log.Printf("Error sending query to upstream: %v", err) + return nil + } + + responseBuffer := make([]byte, 4096) + n, err := upstreamConn.Read(responseBuffer) + if err != nil { + log.Printf("Error reading response from upstream: %v", err) + return nil + } + + if n < 12 { + log.Printf("Response too short to be valid DNS message") + return nil + } + + return df.filterDNS64Responses(responseBuffer[:n]) +} + +func (df *DNSForwarder) filterDNS64Responses(responseData []byte) []byte { + var msg dnsmessage.Message + if err := msg.Unpack(responseData); err != nil { + log.Printf("Error unpacking response: %v", err) + return responseData + } + + var filteredAnswers []dnsmessage.Resource + for _, answer := range msg.Answers { + if answer.Header.Type == dnsmessage.TypeAAAA { + aaaa, ok := answer.Body.(*dnsmessage.AAAAResource) + if ok { + ip := net.IP(aaaa.AAAA[:]) + if !df.nat64Prefix.Contains(ip) { + filteredAnswers = append(filteredAnswers, answer) + } + } + } else { + filteredAnswers = append(filteredAnswers, answer) + } + } + + msg.Answers = filteredAnswers + response, err := msg.Pack() + if err != nil { + log.Printf("Error packing filtered response: %v", err) + return responseData + } + + return response +} + +func (df *DNSForwarder) Start() error { + var wg sync.WaitGroup + for _, addr := range df.listenAddresses { + conn, err := df.createSocket(addr.IP, addr.Port) + if err != nil { + return fmt.Errorf("failed to create socket for %s:%d: %v", addr.IP, addr.Port, err) + } + + wg.Add(1) + go func() { + defer wg.Done() + df.handleSocket(conn) + }() + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("Shutting down...") + return nil +} + +func parseListenAddress(addr string) ListenAddress { + parts := strings.Split(addr, ":") + if len(parts) < 2 { + return ListenAddress{IP: "::", Port: 53} + } + port := 53 + fmt.Sscanf(parts[len(parts)-1], "%d", &port) + ip := strings.Join(parts[:len(parts)-1], ":") + if ip == "" { + ip = "::" + } + return ListenAddress{IP: ip, Port: port} +} + +func main() { + upstreamDNS := flag.String("upstream", "[2606:4700:4700::1111]:53", "Upstream DNS server (format: ip:port or [ipv6]:port)") + listenAddrs := flag.String("listen", "[::]:53", "Comma-separated list of IP:PORT to listen on") + flag.Parse() + + var addresses []ListenAddress + for _, addr := range strings.Split(*listenAddrs, ",") { + addresses = append(addresses, parseListenAddress(addr)) + } + + forwarder := NewDNSForwarder(*upstreamDNS, addresses) + if err := forwarder.Start(); err != nil { + log.Fatal(err) + } +} diff --git a/python/strip_dns64.py b/python/strip_dns64.py new file mode 100644 index 0000000..1cefd9f --- /dev/null +++ b/python/strip_dns64.py @@ -0,0 +1,154 @@ +import argparse
+import ipaddress
+import socket
+import sys
+import threading
+from typing import Optional, Tuple, Union, List
+
+import dns.message
+import dns.query
+import dns.resolver
+import dns.rrset
+
+
+class DNSForwarder:
+ def __init__(self, upstream_dns: str, listen_addresses: List[Tuple[str, int]]):
+ self.upstream_dns = upstream_dns
+ self.listen_addresses = listen_addresses
+ self.nat64_prefix = ipaddress.IPv6Network("64:ff9b::/96")
+ self.sockets = []
+
+ def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket:
+ if ':' in listen_ip:
+ socket_family = socket.AF_INET6
+ else:
+ socket_family = socket.AF_INET
+
+ sock = socket.socket(socket_family, socket.SOCK_DGRAM)
+
+ if listen_ip == '::' and socket_family == socket.AF_INET6:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+
+ sock.bind((listen_ip, listen_port))
+ print(f"DNS forwarder listening on {listen_ip}:{listen_port}")
+ return sock
+
+ def handle_socket(self, sock: socket.socket) -> None:
+ while True:
+ try:
+ data, addr = sock.recvfrom(4096)
+ response = self.process_query(data)
+ if response:
+ sock.sendto(response, addr)
+ except Exception as e:
+ print(f"Error handling socket: {e}")
+
+ def start(self) -> None:
+ print(f"Forwarding to upstream DNS server: {self.upstream_dns}")
+
+ # Create sockets for all listen addresses
+ threads = []
+ for listen_ip, listen_port in self.listen_addresses:
+ try:
+ sock = self.create_socket(listen_ip, listen_port)
+ self.sockets.append(sock)
+
+ # Create a thread for each socket
+ thread = threading.Thread(target=self.handle_socket, args=(sock,))
+ thread.daemon = True
+ threads.append(thread)
+ thread.start()
+ except Exception as e:
+ print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}")
+
+ # Wait for all threads
+ try:
+ for thread in threads:
+ thread.join()
+ except KeyboardInterrupt:
+ print("\nShutting down...")
+ for sock in self.sockets:
+ sock.close()
+ sys.exit(0)
+
+ def process_query(self, query_data: bytes) -> Optional[bytes]:
+ try:
+ query = dns.message.from_wire(query_data)
+ response = dns.query.udp(query, self.upstream_dns)
+
+ filtered_response = self.filter_dns64_responses(response)
+
+ return filtered_response.to_wire()
+ except Exception as e:
+ print(f"Error processing query: {e}")
+ return None
+
+ def filter_dns64_responses(self, response: dns.message.Message) -> dns.message.Message:
+ filtered_response = dns.message.Message(response.id)
+ filtered_response.flags = response.flags
+ filtered_response.set_opcode(response.opcode())
+ filtered_response.set_rcode(response.rcode())
+
+ # Copy questions
+ filtered_response.question = response.question
+
+ for rrset in response.answer:
+ filtered_rrset = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype)
+
+ for rr in rrset:
+ if rrset.rdtype == dns.rdatatype.AAAA:
+ ip = ipaddress.IPv6Address(rr.address)
+ if not ip in self.nat64_prefix:
+ filtered_rrset.add(rr)
+ else:
+ filtered_rrset.add(rr)
+
+ if len(filtered_rrset) > 0:
+ filtered_response.answer.append(filtered_rrset)
+
+ # Copy additional and authority sections
+ filtered_response.additional = response.additional
+ filtered_response.authority = response.authority
+
+ return filtered_response
+
+
+def get_default_dns() -> str:
+ resolver = dns.resolver.Resolver()
+ return resolver.nameservers[0]
+
+
+def parse_arguments() -> Tuple[str, List[Tuple[str, int]]]:
+ parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses")
+ parser.add_argument("--upstream", help="Upstream DNS server IP address", default=get_default_dns())
+ parser.add_argument("--listen", help="IP address and port to listen on (format: IP:PORT)",
+ action='append', default=[])
+ args = parser.parse_args()
+
+ # If no listen addresses specified, use default
+ if not args.listen:
+ args.listen = [":::53"]
+
+ # Parse listen addresses
+ listen_addresses = []
+ for addr in args.listen:
+ if ':' in addr:
+ ip, port = addr.rsplit(':', 1)
+ if not ip:
+ ip = "::"
+ else:
+ ip = addr
+ port = "53"
+ listen_addresses.append((ip, int(port)))
+
+ return args.upstream, listen_addresses
+
+
+def main() -> None:
+ upstream_dns, listen_addresses = parse_arguments()
+ forwarder = DNSForwarder(upstream_dns, listen_addresses)
+ forwarder.start()
+
+
+if __name__ == "__main__":
+ main()
|