aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go223
1 files changed, 223 insertions, 0 deletions
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..091fb1d
--- /dev/null
+++ b/main.go
@@ -0,0 +1,223 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "os/signal"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "golang.org/x/net/dns/dnsmessage"
+)
+
+type DNSForwarder struct {
+ upstreamDNS string
+ listenAddresses []ListenAddress
+ nat64Prefix *net.IPNet
+}
+
+type ListenAddress struct {
+ IP string
+ Port int
+}
+
+func NewDNSForwarder(upstreamDNS string, listenAddresses []ListenAddress) *DNSForwarder {
+ _, nat64Network, _ := net.ParseCIDR("64:ff9b::/96")
+ return &DNSForwarder{
+ upstreamDNS: upstreamDNS,
+ listenAddresses: listenAddresses,
+ nat64Prefix: nat64Network,
+ }
+}
+
+func (df *DNSForwarder) createSocket(listenIP string, listenPort int) (*net.UDPConn, error) {
+ addr := &net.UDPAddr{
+ Port: listenPort,
+ }
+
+ if listenIP == "::" {
+ addr.IP = net.IPv6zero
+ } else {
+ addr.IP = net.ParseIP(listenIP)
+ }
+
+ conn, err := net.ListenUDP("udp", addr)
+ if err != nil {
+ return nil, err
+ }
+
+ log.Printf("DNS forwarder listening on %s:%d", listenIP, listenPort)
+ return conn, nil
+}
+
+func (df *DNSForwarder) handleSocket(conn *net.UDPConn) {
+ buffer := make([]byte, 4096)
+ for {
+ n, remoteAddr, err := conn.ReadFromUDP(buffer)
+ if err != nil {
+ log.Printf("Error reading from socket: %v", err)
+ continue
+ }
+
+ go func(data []byte, addr *net.UDPAddr) {
+ response := df.processQuery(data[:n])
+ if response != nil {
+ _, err := conn.WriteToUDP(response, addr)
+ if err != nil {
+ log.Printf("Error sending response: %v", err)
+ }
+ }
+ }(buffer[:n], remoteAddr)
+ }
+}
+
+func (df *DNSForwarder) processQuery(queryData []byte) []byte {
+ // Check minimum DNS message size
+ if len(queryData) < 12 {
+ log.Printf("Query too short to be valid DNS message")
+ return nil
+ }
+
+ var msg dnsmessage.Message
+ if err := msg.Unpack(queryData); err != nil {
+ log.Printf("Error unpacking DNS message: %v", err)
+ errorMsg := &dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: msg.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeFormatError,
+ },
+ }
+ response, packErr := errorMsg.Pack()
+ if packErr != nil {
+ log.Printf("Error packing error response: %v", packErr)
+ return nil
+ }
+ return response
+ }
+
+ if len(msg.Questions) == 0 {
+ log.Printf("DNS query contains no questions")
+ return nil
+ }
+
+ upstreamConn, err := net.Dial("udp", df.upstreamDNS)
+ if err != nil {
+ log.Printf("Error connecting to upstream DNS: %v", err)
+ return nil
+ }
+ defer upstreamConn.Close()
+
+ // Set a reasonable timeout for upstream queries (5 seconds)
+ upstreamConn.SetDeadline(time.Now().Add(5 * time.Second))
+
+ _, err = upstreamConn.Write(queryData)
+ if err != nil {
+ log.Printf("Error sending query to upstream: %v", err)
+ return nil
+ }
+
+ responseBuffer := make([]byte, 4096)
+ n, err := upstreamConn.Read(responseBuffer)
+ if err != nil {
+ log.Printf("Error reading response from upstream: %v", err)
+ return nil
+ }
+
+ if n < 12 {
+ log.Printf("Response too short to be valid DNS message")
+ return nil
+ }
+
+ return df.filterDNS64Responses(responseBuffer[:n])
+}
+
+func (df *DNSForwarder) filterDNS64Responses(responseData []byte) []byte {
+ var msg dnsmessage.Message
+ if err := msg.Unpack(responseData); err != nil {
+ log.Printf("Error unpacking response: %v", err)
+ return responseData
+ }
+
+ var filteredAnswers []dnsmessage.Resource
+ for _, answer := range msg.Answers {
+ if answer.Header.Type == dnsmessage.TypeAAAA {
+ aaaa, ok := answer.Body.(*dnsmessage.AAAAResource)
+ if ok {
+ ip := net.IP(aaaa.AAAA[:])
+ if !df.nat64Prefix.Contains(ip) {
+ filteredAnswers = append(filteredAnswers, answer)
+ }
+ }
+ } else {
+ filteredAnswers = append(filteredAnswers, answer)
+ }
+ }
+
+ msg.Answers = filteredAnswers
+ response, err := msg.Pack()
+ if err != nil {
+ log.Printf("Error packing filtered response: %v", err)
+ return responseData
+ }
+
+ return response
+}
+
+func (df *DNSForwarder) Start() error {
+ var wg sync.WaitGroup
+ for _, addr := range df.listenAddresses {
+ conn, err := df.createSocket(addr.IP, addr.Port)
+ if err != nil {
+ return fmt.Errorf("failed to create socket for %s:%d: %v", addr.IP, addr.Port, err)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ df.handleSocket(conn)
+ }()
+ }
+
+ sigChan := make(chan os.Signal, 1)
+ signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
+ <-sigChan
+
+ log.Println("Shutting down...")
+ return nil
+}
+
+func parseListenAddress(addr string) ListenAddress {
+ parts := strings.Split(addr, ":")
+ if len(parts) < 2 {
+ return ListenAddress{IP: "::", Port: 53}
+ }
+ port := 53
+ fmt.Sscanf(parts[len(parts)-1], "%d", &port)
+ ip := strings.Join(parts[:len(parts)-1], ":")
+ if ip == "" {
+ ip = "::"
+ }
+ return ListenAddress{IP: ip, Port: port}
+}
+
+func main() {
+ upstreamDNS := flag.String("upstream", "[2606:4700:4700::1111]:53", "Upstream DNS server (format: ip:port or [ipv6]:port)")
+ listenAddrs := flag.String("listen", "[::]:53", "Comma-separated list of IP:PORT to listen on")
+ flag.Parse()
+
+ var addresses []ListenAddress
+ for _, addr := range strings.Split(*listenAddrs, ",") {
+ addresses = append(addresses, parseListenAddress(addr))
+ }
+
+ forwarder := NewDNSForwarder(*upstreamDNS, addresses)
+ if err := forwarder.Start(); err != nil {
+ log.Fatal(err)
+ }
+}