diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 223 |
1 files changed, 223 insertions, 0 deletions
@@ -0,0 +1,223 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "golang.org/x/net/dns/dnsmessage" +) + +type DNSForwarder struct { + upstreamDNS string + listenAddresses []ListenAddress + nat64Prefix *net.IPNet +} + +type ListenAddress struct { + IP string + Port int +} + +func NewDNSForwarder(upstreamDNS string, listenAddresses []ListenAddress) *DNSForwarder { + _, nat64Network, _ := net.ParseCIDR("64:ff9b::/96") + return &DNSForwarder{ + upstreamDNS: upstreamDNS, + listenAddresses: listenAddresses, + nat64Prefix: nat64Network, + } +} + +func (df *DNSForwarder) createSocket(listenIP string, listenPort int) (*net.UDPConn, error) { + addr := &net.UDPAddr{ + Port: listenPort, + } + + if listenIP == "::" { + addr.IP = net.IPv6zero + } else { + addr.IP = net.ParseIP(listenIP) + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + + log.Printf("DNS forwarder listening on %s:%d", listenIP, listenPort) + return conn, nil +} + +func (df *DNSForwarder) handleSocket(conn *net.UDPConn) { + buffer := make([]byte, 4096) + for { + n, remoteAddr, err := conn.ReadFromUDP(buffer) + if err != nil { + log.Printf("Error reading from socket: %v", err) + continue + } + + go func(data []byte, addr *net.UDPAddr) { + response := df.processQuery(data[:n]) + if response != nil { + _, err := conn.WriteToUDP(response, addr) + if err != nil { + log.Printf("Error sending response: %v", err) + } + } + }(buffer[:n], remoteAddr) + } +} + +func (df *DNSForwarder) processQuery(queryData []byte) []byte { + // Check minimum DNS message size + if len(queryData) < 12 { + log.Printf("Query too short to be valid DNS message") + return nil + } + + var msg dnsmessage.Message + if err := msg.Unpack(queryData); err != nil { + log.Printf("Error unpacking DNS message: %v", err) + errorMsg := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: msg.Header.ID, + Response: true, + RCode: dnsmessage.RCodeFormatError, + }, + } + response, packErr := errorMsg.Pack() + if packErr != nil { + log.Printf("Error packing error response: %v", packErr) + return nil + } + return response + } + + if len(msg.Questions) == 0 { + log.Printf("DNS query contains no questions") + return nil + } + + upstreamConn, err := net.Dial("udp", df.upstreamDNS) + if err != nil { + log.Printf("Error connecting to upstream DNS: %v", err) + return nil + } + defer upstreamConn.Close() + + // Set a reasonable timeout for upstream queries (5 seconds) + upstreamConn.SetDeadline(time.Now().Add(5 * time.Second)) + + _, err = upstreamConn.Write(queryData) + if err != nil { + log.Printf("Error sending query to upstream: %v", err) + return nil + } + + responseBuffer := make([]byte, 4096) + n, err := upstreamConn.Read(responseBuffer) + if err != nil { + log.Printf("Error reading response from upstream: %v", err) + return nil + } + + if n < 12 { + log.Printf("Response too short to be valid DNS message") + return nil + } + + return df.filterDNS64Responses(responseBuffer[:n]) +} + +func (df *DNSForwarder) filterDNS64Responses(responseData []byte) []byte { + var msg dnsmessage.Message + if err := msg.Unpack(responseData); err != nil { + log.Printf("Error unpacking response: %v", err) + return responseData + } + + var filteredAnswers []dnsmessage.Resource + for _, answer := range msg.Answers { + if answer.Header.Type == dnsmessage.TypeAAAA { + aaaa, ok := answer.Body.(*dnsmessage.AAAAResource) + if ok { + ip := net.IP(aaaa.AAAA[:]) + if !df.nat64Prefix.Contains(ip) { + filteredAnswers = append(filteredAnswers, answer) + } + } + } else { + filteredAnswers = append(filteredAnswers, answer) + } + } + + msg.Answers = filteredAnswers + response, err := msg.Pack() + if err != nil { + log.Printf("Error packing filtered response: %v", err) + return responseData + } + + return response +} + +func (df *DNSForwarder) Start() error { + var wg sync.WaitGroup + for _, addr := range df.listenAddresses { + conn, err := df.createSocket(addr.IP, addr.Port) + if err != nil { + return fmt.Errorf("failed to create socket for %s:%d: %v", addr.IP, addr.Port, err) + } + + wg.Add(1) + go func() { + defer wg.Done() + df.handleSocket(conn) + }() + } + + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("Shutting down...") + return nil +} + +func parseListenAddress(addr string) ListenAddress { + parts := strings.Split(addr, ":") + if len(parts) < 2 { + return ListenAddress{IP: "::", Port: 53} + } + port := 53 + fmt.Sscanf(parts[len(parts)-1], "%d", &port) + ip := strings.Join(parts[:len(parts)-1], ":") + if ip == "" { + ip = "::" + } + return ListenAddress{IP: ip, Port: port} +} + +func main() { + upstreamDNS := flag.String("upstream", "[2606:4700:4700::1111]:53", "Upstream DNS server (format: ip:port or [ipv6]:port)") + listenAddrs := flag.String("listen", "[::]:53", "Comma-separated list of IP:PORT to listen on") + flag.Parse() + + var addresses []ListenAddress + for _, addr := range strings.Split(*listenAddrs, ",") { + addresses = append(addresses, parseListenAddress(addr)) + } + + forwarder := NewDNSForwarder(*upstreamDNS, addresses) + if err := forwarder.Start(); err != nil { + log.Fatal(err) + } +} |