From ac6921113db6ee4c786bfbe0e124c25569df2c68 Mon Sep 17 00:00:00 2001 From: Jordan Date: Sun, 17 Nov 2024 21:57:16 -0700 Subject: initial commit --- .gitignore | 1 + Dockerfile | 31 +++++++ README | 27 ++++++ entrypoint.sh | 2 + go.mod | 5 ++ go.sum | 2 + main.go | 223 ++++++++++++++++++++++++++++++++++++++++++++++++++ python/strip_dns64.py | 154 ++++++++++++++++++++++++++++++++++ 8 files changed, 445 insertions(+) create mode 100644 .gitignore create mode 100644 Dockerfile create mode 100644 README create mode 100755 entrypoint.sh create mode 100644 go.mod create mode 100644 go.sum create mode 100644 main.go create mode 100644 python/strip_dns64.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2c65bd2 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +strip-dns64 diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..36c0691 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,31 @@ +FROM docker.io/library/golang:1.22-alpine AS builder + +WORKDIR /app +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +RUN CGO_ENABLED=0 GOOS=linux go build -o strip-dns64 + +FROM alpine:3.19 + +RUN apk --no-cache add ca-certificates libcap + +WORKDIR /app +COPY --from=builder /app/strip-dns64 . +COPY entrypoint.sh . + +RUN adduser -D -H -h /app dnsuser +RUN chown dnsuser:dnsuser /app/strip-dns64 /app/entrypoint.sh +RUN chmod +x /app/entrypoint.sh + +RUN setcap cap_net_bind_service=+ep /app/strip-dns64 + +ENV UPSTREAM_DNS="[2606:4700:4700::1111]:53" +ENV LISTEN_ADDRS="[::]:53" + +USER dnsuser + +ENTRYPOINT ["/app/entrypoint.sh"] + +EXPOSE 53/udp diff --git a/README b/README new file mode 100644 index 0000000..39b1982 --- /dev/null +++ b/README @@ -0,0 +1,27 @@ +strip-dns64 is a DNS forwarder which removes synthesized AAAA answers from +upstream DNS64 servers, implemented initially in Python and then Go for improved +performance + +podman build -t strip-dns64 . + +podman run -d \ + --name strip-dns64 \ + --network host \ + -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \ + -e LISTEN_ADDRS="127.0.0.1:53,[::]:53" \ + strip-dns64 + +podman run -d \ + --name strip-dns64 \ + --cap-add=NET_BIND_SERVICE \ + -p 53:53/udp \ + -p [::]:53:53/udp \ + -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \ + strip-dns64 + +Usage of ./strip-dns64: + -listen string + Comma-separated list of IP:PORT to listen on (default "[::]:53") + -upstream string + Upstream DNS server (format: ip:port or [ipv6]:port) (default "[2606:4700:4700::1111]:53") + diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 0000000..d7cd260 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,2 @@ +#!/bin/sh +exec /app/strip-dns64 -upstream "${UPSTREAM_DNS}" -listen "${LISTEN_ADDRS}" diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f47130e --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module strip-dns64 + +go 1.22.2 + +require golang.org/x/net v0.30.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b338806 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= diff --git a/main.go b/main.go new file mode 100644 index 0000000..091fb1d --- /dev/null +++ b/main.go @@ -0,0 +1,223 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +type DNSForwarder struct { + upstreamDNS string + listenAddresses []ListenAddress + nat64Prefix *net.IPNet +} + +type ListenAddress struct { + IP string + Port int +} + +func NewDNSForwarder(upstreamDNS string, listenAddresses []ListenAddress) *DNSForwarder { + _, nat64Network, _ := net.ParseCIDR("64:ff9b::/96") + return &DNSForwarder{ + upstreamDNS: upstreamDNS, + listenAddresses: listenAddresses, + nat64Prefix: nat64Network, + } +} + +func (df *DNSForwarder) createSocket(listenIP string, listenPort int) (*net.UDPConn, error) { + addr := &net.UDPAddr{ + Port: listenPort, + } + + if listenIP == "::" { + addr.IP = net.IPv6zero + } else { + addr.IP = net.ParseIP(listenIP) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + + log.Printf("DNS forwarder listening on %s:%d", listenIP, listenPort) + return conn, nil +} + +func (df *DNSForwarder) handleSocket(conn *net.UDPConn) { + buffer := make([]byte, 4096) + for { + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + log.Printf("Error reading from socket: %v", err) + continue + } + + go func(data []byte, addr *net.UDPAddr) { + response := df.processQuery(data[:n]) + if response != nil { + _, err := conn.WriteToUDP(response, addr) + if err != nil { + log.Printf("Error sending response: %v", err) + } + } + }(buffer[:n], remoteAddr) + } +} + +func (df *DNSForwarder) processQuery(queryData []byte) []byte { + // Check minimum DNS message size + if len(queryData) < 12 { + log.Printf("Query too short to be valid DNS message") + return nil + } + + var msg dnsmessage.Message + if err := msg.Unpack(queryData); err != nil { + log.Printf("Error unpacking DNS message: %v", err) + errorMsg := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: msg.Header.ID, + Response: true, + RCode: dnsmessage.RCodeFormatError, + }, + } + response, packErr := errorMsg.Pack() + if packErr != nil { + log.Printf("Error packing error response: %v", packErr) + return nil + } + return response + } + + if len(msg.Questions) == 0 { + log.Printf("DNS query contains no questions") + return nil + } + + upstreamConn, err := net.Dial("udp", df.upstreamDNS) + if err != nil { + log.Printf("Error connecting to upstream DNS: %v", err) + return nil + } + defer upstreamConn.Close() + + // Set a reasonable timeout for upstream queries (5 seconds) + upstreamConn.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err = upstreamConn.Write(queryData) + if err != nil { + log.Printf("Error sending query to upstream: %v", err) + return nil + } + + responseBuffer := make([]byte, 4096) + n, err := upstreamConn.Read(responseBuffer) + if err != nil { + log.Printf("Error reading response from upstream: %v", err) + return nil + } + + if n < 12 { + log.Printf("Response too short to be valid DNS message") + return nil + } + + return df.filterDNS64Responses(responseBuffer[:n]) +} + +func (df *DNSForwarder) filterDNS64Responses(responseData []byte) []byte { + var msg dnsmessage.Message + if err := msg.Unpack(responseData); err != nil { + log.Printf("Error unpacking response: %v", err) + return responseData + } + + var filteredAnswers []dnsmessage.Resource + for _, answer := range msg.Answers { + if answer.Header.Type == dnsmessage.TypeAAAA { + aaaa, ok := answer.Body.(*dnsmessage.AAAAResource) + if ok { + ip := net.IP(aaaa.AAAA[:]) + if !df.nat64Prefix.Contains(ip) { + filteredAnswers = append(filteredAnswers, answer) + } + } + } else { + filteredAnswers = append(filteredAnswers, answer) + } + } + + msg.Answers = filteredAnswers + response, err := msg.Pack() + if err != nil { + log.Printf("Error packing filtered response: %v", err) + return responseData + } + + return response +} + +func (df *DNSForwarder) Start() error { + var wg sync.WaitGroup + for _, addr := range df.listenAddresses { + conn, err := df.createSocket(addr.IP, addr.Port) + if err != nil { + return fmt.Errorf("failed to create socket for %s:%d: %v", addr.IP, addr.Port, err) + } + + wg.Add(1) + go func() { + defer wg.Done() + df.handleSocket(conn) + }() + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("Shutting down...") + return nil +} + +func parseListenAddress(addr string) ListenAddress { + parts := strings.Split(addr, ":") + if len(parts) < 2 { + return ListenAddress{IP: "::", Port: 53} + } + port := 53 + fmt.Sscanf(parts[len(parts)-1], "%d", &port) + ip := strings.Join(parts[:len(parts)-1], ":") + if ip == "" { + ip = "::" + } + return ListenAddress{IP: ip, Port: port} +} + +func main() { + upstreamDNS := flag.String("upstream", "[2606:4700:4700::1111]:53", "Upstream DNS server (format: ip:port or [ipv6]:port)") + listenAddrs := flag.String("listen", "[::]:53", "Comma-separated list of IP:PORT to listen on") + flag.Parse() + + var addresses []ListenAddress + for _, addr := range strings.Split(*listenAddrs, ",") { + addresses = append(addresses, parseListenAddress(addr)) + } + + forwarder := NewDNSForwarder(*upstreamDNS, addresses) + if err := forwarder.Start(); err != nil { + log.Fatal(err) + } +} diff --git a/python/strip_dns64.py b/python/strip_dns64.py new file mode 100644 index 0000000..1cefd9f --- /dev/null +++ b/python/strip_dns64.py @@ -0,0 +1,154 @@ +import argparse +import ipaddress +import socket +import sys +import threading +from typing import Optional, Tuple, Union, List + +import dns.message +import dns.query +import dns.resolver +import dns.rrset + + +class DNSForwarder: + def __init__(self, upstream_dns: str, listen_addresses: List[Tuple[str, int]]): + self.upstream_dns = upstream_dns + self.listen_addresses = listen_addresses + self.nat64_prefix = ipaddress.IPv6Network("64:ff9b::/96") + self.sockets = [] + + def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket: + if ':' in listen_ip: + socket_family = socket.AF_INET6 + else: + socket_family = socket.AF_INET + + sock = socket.socket(socket_family, socket.SOCK_DGRAM) + + if listen_ip == '::' and socket_family == socket.AF_INET6: + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + + sock.bind((listen_ip, listen_port)) + print(f"DNS forwarder listening on {listen_ip}:{listen_port}") + return sock + + def handle_socket(self, sock: socket.socket) -> None: + while True: + try: + data, addr = sock.recvfrom(4096) + response = self.process_query(data) + if response: + sock.sendto(response, addr) + except Exception as e: + print(f"Error handling socket: {e}") + + def start(self) -> None: + print(f"Forwarding to upstream DNS server: {self.upstream_dns}") + + # Create sockets for all listen addresses + threads = [] + for listen_ip, listen_port in self.listen_addresses: + try: + sock = self.create_socket(listen_ip, listen_port) + self.sockets.append(sock) + + # Create a thread for each socket + thread = threading.Thread(target=self.handle_socket, args=(sock,)) + thread.daemon = True + threads.append(thread) + thread.start() + except Exception as e: + print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}") + + # Wait for all threads + try: + for thread in threads: + thread.join() + except KeyboardInterrupt: + print("\nShutting down...") + for sock in self.sockets: + sock.close() + sys.exit(0) + + def process_query(self, query_data: bytes) -> Optional[bytes]: + try: + query = dns.message.from_wire(query_data) + response = dns.query.udp(query, self.upstream_dns) + + filtered_response = self.filter_dns64_responses(response) + + return filtered_response.to_wire() + except Exception as e: + print(f"Error processing query: {e}") + return None + + def filter_dns64_responses(self, response: dns.message.Message) -> dns.message.Message: + filtered_response = dns.message.Message(response.id) + filtered_response.flags = response.flags + filtered_response.set_opcode(response.opcode()) + filtered_response.set_rcode(response.rcode()) + + # Copy questions + filtered_response.question = response.question + + for rrset in response.answer: + filtered_rrset = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype) + + for rr in rrset: + if rrset.rdtype == dns.rdatatype.AAAA: + ip = ipaddress.IPv6Address(rr.address) + if not ip in self.nat64_prefix: + filtered_rrset.add(rr) + else: + filtered_rrset.add(rr) + + if len(filtered_rrset) > 0: + filtered_response.answer.append(filtered_rrset) + + # Copy additional and authority sections + filtered_response.additional = response.additional + filtered_response.authority = response.authority + + return filtered_response + + +def get_default_dns() -> str: + resolver = dns.resolver.Resolver() + return resolver.nameservers[0] + + +def parse_arguments() -> Tuple[str, List[Tuple[str, int]]]: + parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses") + parser.add_argument("--upstream", help="Upstream DNS server IP address", default=get_default_dns()) + parser.add_argument("--listen", help="IP address and port to listen on (format: IP:PORT)", + action='append', default=[]) + args = parser.parse_args() + + # If no listen addresses specified, use default + if not args.listen: + args.listen = [":::53"] + + # Parse listen addresses + listen_addresses = [] + for addr in args.listen: + if ':' in addr: + ip, port = addr.rsplit(':', 1) + if not ip: + ip = "::" + else: + ip = addr + port = "53" + listen_addresses.append((ip, int(port))) + + return args.upstream, listen_addresses + + +def main() -> None: + upstream_dns, listen_addresses = parse_arguments() + forwarder = DNSForwarder(upstream_dns, listen_addresses) + forwarder.start() + + +if __name__ == "__main__": + main() -- cgit v1.2.3-54-g00ecf