aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan <me@jordan.im>2024-11-17 21:57:16 -0700
committerJordan <me@jordan.im>2024-11-17 21:57:16 -0700
commitac6921113db6ee4c786bfbe0e124c25569df2c68 (patch)
treedf584ebc160bfdc3f20262a49c62d49472ecae5b
downloadstrip-dns64-master.tar.gz
strip-dns64-master.zip
initial commitHEADmaster
-rw-r--r--.gitignore1
-rw-r--r--Dockerfile31
-rw-r--r--README27
-rwxr-xr-xentrypoint.sh2
-rw-r--r--go.mod5
-rw-r--r--go.sum2
-rw-r--r--main.go223
-rw-r--r--python/strip_dns64.py154
8 files changed, 445 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2c65bd2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+strip-dns64
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..36c0691
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,31 @@
+FROM docker.io/library/golang:1.22-alpine AS builder
+
+WORKDIR /app
+COPY go.mod go.sum ./
+RUN go mod download
+
+COPY . .
+RUN CGO_ENABLED=0 GOOS=linux go build -o strip-dns64
+
+FROM alpine:3.19
+
+RUN apk --no-cache add ca-certificates libcap
+
+WORKDIR /app
+COPY --from=builder /app/strip-dns64 .
+COPY entrypoint.sh .
+
+RUN adduser -D -H -h /app dnsuser
+RUN chown dnsuser:dnsuser /app/strip-dns64 /app/entrypoint.sh
+RUN chmod +x /app/entrypoint.sh
+
+RUN setcap cap_net_bind_service=+ep /app/strip-dns64
+
+ENV UPSTREAM_DNS="[2606:4700:4700::1111]:53"
+ENV LISTEN_ADDRS="[::]:53"
+
+USER dnsuser
+
+ENTRYPOINT ["/app/entrypoint.sh"]
+
+EXPOSE 53/udp
diff --git a/README b/README
new file mode 100644
index 0000000..39b1982
--- /dev/null
+++ b/README
@@ -0,0 +1,27 @@
+strip-dns64 is a DNS forwarder which removes synthesized AAAA answers from
+upstream DNS64 servers, implemented initially in Python and then Go for improved
+performance
+
+podman build -t strip-dns64 .
+
+podman run -d \
+ --name strip-dns64 \
+ --network host \
+ -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \
+ -e LISTEN_ADDRS="127.0.0.1:53,[::]:53" \
+ strip-dns64
+
+podman run -d \
+ --name strip-dns64 \
+ --cap-add=NET_BIND_SERVICE \
+ -p 53:53/udp \
+ -p [::]:53:53/udp \
+ -e UPSTREAM_DNS="[2001:4860:4860::8888]:53" \
+ strip-dns64
+
+Usage of ./strip-dns64:
+ -listen string
+ Comma-separated list of IP:PORT to listen on (default "[::]:53")
+ -upstream string
+ Upstream DNS server (format: ip:port or [ipv6]:port) (default "[2606:4700:4700::1111]:53")
+
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100755
index 0000000..d7cd260
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,2 @@
+#!/bin/sh
+exec /app/strip-dns64 -upstream "${UPSTREAM_DNS}" -listen "${LISTEN_ADDRS}"
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..f47130e
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,5 @@
+module strip-dns64
+
+go 1.22.2
+
+require golang.org/x/net v0.30.0
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..b338806
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
+golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..091fb1d
--- /dev/null
+++ b/main.go
@@ -0,0 +1,223 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "os/signal"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+type DNSForwarder struct {
+ upstreamDNS string
+ listenAddresses []ListenAddress
+ nat64Prefix *net.IPNet
+}
+
+type ListenAddress struct {
+ IP string
+ Port int
+}
+
+func NewDNSForwarder(upstreamDNS string, listenAddresses []ListenAddress) *DNSForwarder {
+ _, nat64Network, _ := net.ParseCIDR("64:ff9b::/96")
+ return &DNSForwarder{
+ upstreamDNS: upstreamDNS,
+ listenAddresses: listenAddresses,
+ nat64Prefix: nat64Network,
+ }
+}
+
+func (df *DNSForwarder) createSocket(listenIP string, listenPort int) (*net.UDPConn, error) {
+ addr := &net.UDPAddr{
+ Port: listenPort,
+ }
+
+ if listenIP == "::" {
+ addr.IP = net.IPv6zero
+ } else {
+ addr.IP = net.ParseIP(listenIP)
+ }
+
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ log.Printf("DNS forwarder listening on %s:%d", listenIP, listenPort)
+ return conn, nil
+}
+
+func (df *DNSForwarder) handleSocket(conn *net.UDPConn) {
+ buffer := make([]byte, 4096)
+ for {
+ n, remoteAddr, err := conn.ReadFromUDP(buffer)
+ if err != nil {
+ log.Printf("Error reading from socket: %v", err)
+ continue
+ }
+
+ go func(data []byte, addr *net.UDPAddr) {
+ response := df.processQuery(data[:n])
+ if response != nil {
+ _, err := conn.WriteToUDP(response, addr)
+ if err != nil {
+ log.Printf("Error sending response: %v", err)
+ }
+ }
+ }(buffer[:n], remoteAddr)
+ }
+}
+
+func (df *DNSForwarder) processQuery(queryData []byte) []byte {
+ // Check minimum DNS message size
+ if len(queryData) < 12 {
+ log.Printf("Query too short to be valid DNS message")
+ return nil
+ }
+
+ var msg dnsmessage.Message
+ if err := msg.Unpack(queryData); err != nil {
+ log.Printf("Error unpacking DNS message: %v", err)
+ errorMsg := &dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: msg.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeFormatError,
+ },
+ }
+ response, packErr := errorMsg.Pack()
+ if packErr != nil {
+ log.Printf("Error packing error response: %v", packErr)
+ return nil
+ }
+ return response
+ }
+
+ if len(msg.Questions) == 0 {
+ log.Printf("DNS query contains no questions")
+ return nil
+ }
+
+ upstreamConn, err := net.Dial("udp", df.upstreamDNS)
+ if err != nil {
+ log.Printf("Error connecting to upstream DNS: %v", err)
+ return nil
+ }
+ defer upstreamConn.Close()
+
+ // Set a reasonable timeout for upstream queries (5 seconds)
+ upstreamConn.SetDeadline(time.Now().Add(5 * time.Second))
+
+ _, err = upstreamConn.Write(queryData)
+ if err != nil {
+ log.Printf("Error sending query to upstream: %v", err)
+ return nil
+ }
+
+ responseBuffer := make([]byte, 4096)
+ n, err := upstreamConn.Read(responseBuffer)
+ if err != nil {
+ log.Printf("Error reading response from upstream: %v", err)
+ return nil
+ }
+
+ if n < 12 {
+ log.Printf("Response too short to be valid DNS message")
+ return nil
+ }
+
+ return df.filterDNS64Responses(responseBuffer[:n])
+}
+
+func (df *DNSForwarder) filterDNS64Responses(responseData []byte) []byte {
+ var msg dnsmessage.Message
+ if err := msg.Unpack(responseData); err != nil {
+ log.Printf("Error unpacking response: %v", err)
+ return responseData
+ }
+
+ var filteredAnswers []dnsmessage.Resource
+ for _, answer := range msg.Answers {
+ if answer.Header.Type == dnsmessage.TypeAAAA {
+ aaaa, ok := answer.Body.(*dnsmessage.AAAAResource)
+ if ok {
+ ip := net.IP(aaaa.AAAA[:])
+ if !df.nat64Prefix.Contains(ip) {
+ filteredAnswers = append(filteredAnswers, answer)
+ }
+ }
+ } else {
+ filteredAnswers = append(filteredAnswers, answer)
+ }
+ }
+
+ msg.Answers = filteredAnswers
+ response, err := msg.Pack()
+ if err != nil {
+ log.Printf("Error packing filtered response: %v", err)
+ return responseData
+ }
+
+ return response
+}
+
+func (df *DNSForwarder) Start() error {
+ var wg sync.WaitGroup
+ for _, addr := range df.listenAddresses {
+ conn, err := df.createSocket(addr.IP, addr.Port)
+ if err != nil {
+ return fmt.Errorf("failed to create socket for %s:%d: %v", addr.IP, addr.Port, err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ df.handleSocket(conn)
+ }()
+ }
+
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+ <-sigChan
+
+ log.Println("Shutting down...")
+ return nil
+}
+
+func parseListenAddress(addr string) ListenAddress {
+ parts := strings.Split(addr, ":")
+ if len(parts) < 2 {
+ return ListenAddress{IP: "::", Port: 53}
+ }
+ port := 53
+ fmt.Sscanf(parts[len(parts)-1], "%d", &port)
+ ip := strings.Join(parts[:len(parts)-1], ":")
+ if ip == "" {
+ ip = "::"
+ }
+ return ListenAddress{IP: ip, Port: port}
+}
+
+func main() {
+ upstreamDNS := flag.String("upstream", "[2606:4700:4700::1111]:53", "Upstream DNS server (format: ip:port or [ipv6]:port)")
+ listenAddrs := flag.String("listen", "[::]:53", "Comma-separated list of IP:PORT to listen on")
+ flag.Parse()
+
+ var addresses []ListenAddress
+ for _, addr := range strings.Split(*listenAddrs, ",") {
+ addresses = append(addresses, parseListenAddress(addr))
+ }
+
+ forwarder := NewDNSForwarder(*upstreamDNS, addresses)
+ if err := forwarder.Start(); err != nil {
+ log.Fatal(err)
+ }
+}
diff --git a/python/strip_dns64.py b/python/strip_dns64.py
new file mode 100644
index 0000000..1cefd9f
--- /dev/null
+++ b/python/strip_dns64.py
@@ -0,0 +1,154 @@
+import argparse
+import ipaddress
+import socket
+import sys
+import threading
+from typing import Optional, Tuple, Union, List
+
+import dns.message
+import dns.query
+import dns.resolver
+import dns.rrset
+
+
+class DNSForwarder:
+ def __init__(self, upstream_dns: str, listen_addresses: List[Tuple[str, int]]):
+ self.upstream_dns = upstream_dns
+ self.listen_addresses = listen_addresses
+ self.nat64_prefix = ipaddress.IPv6Network("64:ff9b::/96")
+ self.sockets = []
+
+ def create_socket(self, listen_ip: str, listen_port: int) -> socket.socket:
+ if ':' in listen_ip:
+ socket_family = socket.AF_INET6
+ else:
+ socket_family = socket.AF_INET
+
+ sock = socket.socket(socket_family, socket.SOCK_DGRAM)
+
+ if listen_ip == '::' and socket_family == socket.AF_INET6:
+ sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+
+ sock.bind((listen_ip, listen_port))
+ print(f"DNS forwarder listening on {listen_ip}:{listen_port}")
+ return sock
+
+ def handle_socket(self, sock: socket.socket) -> None:
+ while True:
+ try:
+ data, addr = sock.recvfrom(4096)
+ response = self.process_query(data)
+ if response:
+ sock.sendto(response, addr)
+ except Exception as e:
+ print(f"Error handling socket: {e}")
+
+ def start(self) -> None:
+ print(f"Forwarding to upstream DNS server: {self.upstream_dns}")
+
+ # Create sockets for all listen addresses
+ threads = []
+ for listen_ip, listen_port in self.listen_addresses:
+ try:
+ sock = self.create_socket(listen_ip, listen_port)
+ self.sockets.append(sock)
+
+ # Create a thread for each socket
+ thread = threading.Thread(target=self.handle_socket, args=(sock,))
+ thread.daemon = True
+ threads.append(thread)
+ thread.start()
+ except Exception as e:
+ print(f"Failed to create socket for [{listen_ip}]:{listen_port}: {e}")
+
+ # Wait for all threads
+ try:
+ for thread in threads:
+ thread.join()
+ except KeyboardInterrupt:
+ print("\nShutting down...")
+ for sock in self.sockets:
+ sock.close()
+ sys.exit(0)
+
+ def process_query(self, query_data: bytes) -> Optional[bytes]:
+ try:
+ query = dns.message.from_wire(query_data)
+ response = dns.query.udp(query, self.upstream_dns)
+
+ filtered_response = self.filter_dns64_responses(response)
+
+ return filtered_response.to_wire()
+ except Exception as e:
+ print(f"Error processing query: {e}")
+ return None
+
+ def filter_dns64_responses(self, response: dns.message.Message) -> dns.message.Message:
+ filtered_response = dns.message.Message(response.id)
+ filtered_response.flags = response.flags
+ filtered_response.set_opcode(response.opcode())
+ filtered_response.set_rcode(response.rcode())
+
+ # Copy questions
+ filtered_response.question = response.question
+
+ for rrset in response.answer:
+ filtered_rrset = dns.rrset.RRset(rrset.name, rrset.rdclass, rrset.rdtype)
+
+ for rr in rrset:
+ if rrset.rdtype == dns.rdatatype.AAAA:
+ ip = ipaddress.IPv6Address(rr.address)
+ if not ip in self.nat64_prefix:
+ filtered_rrset.add(rr)
+ else:
+ filtered_rrset.add(rr)
+
+ if len(filtered_rrset) > 0:
+ filtered_response.answer.append(filtered_rrset)
+
+ # Copy additional and authority sections
+ filtered_response.additional = response.additional
+ filtered_response.authority = response.authority
+
+ return filtered_response
+
+
+def get_default_dns() -> str:
+ resolver = dns.resolver.Resolver()
+ return resolver.nameservers[0]
+
+
+def parse_arguments() -> Tuple[str, List[Tuple[str, int]]]:
+ parser = argparse.ArgumentParser(description="DNS forwarder that strips DNS64 responses")
+ parser.add_argument("--upstream", help="Upstream DNS server IP address", default=get_default_dns())
+ parser.add_argument("--listen", help="IP address and port to listen on (format: IP:PORT)",
+ action='append', default=[])
+ args = parser.parse_args()
+
+ # If no listen addresses specified, use default
+ if not args.listen:
+ args.listen = [":::53"]
+
+ # Parse listen addresses
+ listen_addresses = []
+ for addr in args.listen:
+ if ':' in addr:
+ ip, port = addr.rsplit(':', 1)
+ if not ip:
+ ip = "::"
+ else:
+ ip = addr
+ port = "53"
+ listen_addresses.append((ip, int(port)))
+
+ return args.upstream, listen_addresses
+
+
+def main() -> None:
+ upstream_dns, listen_addresses = parse_arguments()
+ forwarder = DNSForwarder(upstream_dns, listen_addresses)
+ forwarder.start()
+
+
+if __name__ == "__main__":
+ main()