// Copyright 2013 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris package net import ( "context" "errors" "fmt" "os" "path" "reflect" "strings" "sync" "sync/atomic" "testing" "time" "golang.org/x/net/dns/dnsmessage" ) var goResolver = Resolver{PreferGo: true} // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01} // Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation. var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} func mustNewName(name string) dnsmessage.Name { nn, err := dnsmessage.NewName(name) if err != nil { panic(fmt.Sprint("creating name: ", err)) } return nn } func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question { return dnsmessage.Question{ Name: mustNewName(name), Type: qtype, Class: class, } } var dnsTransportFallbackTests = []struct { server string question dnsmessage.Question timeout int rcode dnsmessage.RCode }{ // Querying "com." with qtype=255 usually makes an answer // which requires more than 512 bytes. {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess}, {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess}, } func TestDNSTransportFallback(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, } if n == "udp" { r.Header.Truncated = true } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} for _, tt := range dnsTransportFallbackTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second, useUDPOrTCP) if err != nil { t.Error(err) continue } if h.RCode != tt.rcode { t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode) continue } } } // See RFC 6761 for further information about the reserved, pseudo // domain names. var specialDomainNameTests = []struct { question dnsmessage.Question rcode dnsmessage.RCode }{ // Name resolution APIs and libraries should not recognize the // followings as special. {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess}, // Name resolution APIs and libraries should recognize the // followings as special and should not send any queries. // Though, we test those names here for verifying negative // answers at DNS query-response interaction level. {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, } func TestSpecialDomainName(t *testing.T) { fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } switch q.Questions[0].Name.String() { case "example.com.": r.Header.RCode = dnsmessage.RCodeSuccess default: r.Header.RCode = dnsmessage.RCodeNameError } return r, nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} server := "8.8.8.8:53" for _, tt := range specialDomainNameTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second, useUDPOrTCP) if err != nil { t.Error(err) continue } if h.RCode != tt.rcode { t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode) continue } } } // Issue 13705: don't try to resolve onion addresses, etc func TestAvoidDNSName(t *testing.T) { tests := []struct { name string avoid bool }{ {"foo.com", false}, {"foo.com.", false}, {"foo.onion.", true}, {"foo.onion", true}, {"foo.ONION", true}, {"foo.ONION.", true}, // But do resolve *.local address; Issue 16739 {"foo.local.", false}, {"foo.local", false}, {"foo.LOCAL", false}, {"foo.LOCAL.", false}, {"", true}, // will be rejected earlier too // Without stuff before onion/local, they're fine to // use DNS. With a search path, // "onion.vegetables.com" can use DNS. Without a // search path (or with a trailing dot), the queries // are just kinda useless, but don't reveal anything // private. {"local", false}, {"onion", false}, {"local.", false}, {"onion.", false}, } for _, tt := range tests { got := avoidDNS(tt.name) if got != tt.avoid { t.Errorf("avoidDNS(%q) = %v; want %v", tt.name, got, tt.avoid) } } } var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA { r.Answers = []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }, } } return r, nil }} // Issue 13705: don't try to resolve onion addresses, etc func TestLookupTorOnion(t *testing.T) { defer dnsWaitGroup.Wait() r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} addrs, err := r.LookupIPAddr(context.Background(), "foo.onion") if err != nil { t.Fatalf("lookup = %v; want nil", err) } if len(addrs) > 0 { t.Errorf("unexpected addresses: %v", addrs) } } type resolvConfTest struct { dir string path string *resolverConfig } func newResolvConfTest() (*resolvConfTest, error) { dir, err := os.MkdirTemp("", "go-resolvconftest") if err != nil { return nil, err } conf := &resolvConfTest{ dir: dir, path: path.Join(dir, "resolv.conf"), resolverConfig: &resolvConf, } conf.initOnce.Do(conf.init) return conf, nil } func (conf *resolvConfTest) writeAndUpdate(lines []string) error { f, err := os.OpenFile(conf.path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { return err } if _, err := f.WriteString(strings.Join(lines, "\n")); err != nil { f.Close() return err } f.Close() if err := conf.forceUpdate(conf.path, time.Now().Add(time.Hour)); err != nil { return err } return nil } func (conf *resolvConfTest) forceUpdate(name string, lastChecked time.Time) error { dnsConf := dnsReadConfig(name) conf.mu.Lock() conf.dnsConfig = dnsConf conf.mu.Unlock() for i := 0; i < 5; i++ { if conf.tryAcquireSema() { conf.lastChecked = lastChecked conf.releaseSema() return nil } } return fmt.Errorf("tryAcquireSema for %s failed", name) } func (conf *resolvConfTest) servers() []string { conf.mu.RLock() servers := conf.dnsConfig.servers conf.mu.RUnlock() return servers } func (conf *resolvConfTest) teardown() error { err := conf.forceUpdate("/etc/resolv.conf", time.Time{}) os.RemoveAll(conf.dir) return err } var updateResolvConfTests = []struct { name string // query name lines []string // resolver configuration lines servers []string // expected name servers }{ { name: "golang.org", lines: []string{"nameserver 8.8.8.8"}, servers: []string{"8.8.8.8:53"}, }, { name: "", lines: nil, // an empty resolv.conf should use defaultNS as name servers servers: defaultNS, }, { name: "www.example.com", lines: []string{"nameserver 8.8.4.4"}, servers: []string{"8.8.4.4:53"}, }, } func TestUpdateResolvConf(t *testing.T) { defer dnsWaitGroup.Wait() r := Resolver{PreferGo: true, Dial: fakeDNSServerSuccessful.DialContext} conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() for i, tt := range updateResolvConfTests { if err := conf.writeAndUpdate(tt.lines); err != nil { t.Error(err) continue } if tt.name != "" { var wg sync.WaitGroup const N = 10 wg.Add(N) for j := 0; j < N; j++ { go func(name string) { defer wg.Done() ips, err := r.LookupIPAddr(context.Background(), name) if err != nil { t.Error(err) return } if len(ips) == 0 { t.Errorf("no records for %s", name) return } }(tt.name) } wg.Wait() } servers := conf.servers() if !reflect.DeepEqual(servers, tt.servers) { t.Errorf("#%d: got %v; want %v", i, servers, tt.servers) continue } } } var goLookupIPWithResolverConfigTests = []struct { name string lines []string // resolver configuration lines error a, aaaa bool // whether response contains A, AAAA-record }{ // no records, transport timeout { "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", []string{ "options timeout:1 attempts:1", "nameserver 255.255.255.255", // please forgive us for abuse of limited broadcast address }, &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "255.255.255.255:53", IsTimeout: true}, false, false, }, // no records, non-existent domain { "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", []string{ "options timeout:3 attempts:1", "nameserver 8.8.8.8", }, &DNSError{Name: "jgahvsekduiv9bw4b3qhn4ykdfgj0493iohkrjfhdvhjiu4j", Server: "8.8.8.8:53", IsTimeout: false}, false, false, }, // a few A records, no AAAA records { "ipv4.google.com.", []string{ "nameserver 8.8.8.8", "nameserver 2001:4860:4860::8888", }, nil, true, false, }, { "ipv4.google.com", []string{ "domain golang.org", "nameserver 2001:4860:4860::8888", "nameserver 8.8.8.8", }, nil, true, false, }, { "ipv4.google.com", []string{ "search x.golang.org y.golang.org", "nameserver 2001:4860:4860::8888", "nameserver 8.8.8.8", }, nil, true, false, }, // no A records, a few AAAA records { "ipv6.google.com.", []string{ "nameserver 2001:4860:4860::8888", "nameserver 8.8.8.8", }, nil, false, true, }, { "ipv6.google.com", []string{ "domain golang.org", "nameserver 8.8.8.8", "nameserver 2001:4860:4860::8888", }, nil, false, true, }, { "ipv6.google.com", []string{ "search x.golang.org y.golang.org", "nameserver 8.8.8.8", "nameserver 2001:4860:4860::8888", }, nil, false, true, }, // both A and AAAA records { "hostname.as112.net", // see RFC 7534 []string{ "domain golang.org", "nameserver 2001:4860:4860::8888", "nameserver 8.8.8.8", }, nil, true, true, }, { "hostname.as112.net", // see RFC 7534 []string{ "search x.golang.org y.golang.org", "nameserver 2001:4860:4860::8888", "nameserver 8.8.8.8", }, nil, true, true, }, } func TestGoLookupIPWithResolverConfig(t *testing.T) { defer dnsWaitGroup.Wait() fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { switch s { case "[2001:4860:4860::8888]:53", "8.8.8.8:53": break default: time.Sleep(10 * time.Millisecond) return dnsmessage.Message{}, os.ErrDeadlineExceeded } r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } for _, question := range q.Questions { switch question.Type { case dnsmessage.TypeA: switch question.Name.String() { case "hostname.as112.net.": break case "ipv4.google.com.": r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }) default: } case dnsmessage.TypeAAAA: switch question.Name.String() { case "hostname.as112.net.": break case "ipv6.google.com.": r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeAAAA, Class: dnsmessage.ClassINET, Length: 16, }, Body: &dnsmessage.AAAAResource{ AAAA: TestAddr6, }, }) } } } return r, nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() for _, tt := range goLookupIPWithResolverConfigTests { if err := conf.writeAndUpdate(tt.lines); err != nil { t.Error(err) continue } addrs, err := r.LookupIPAddr(context.Background(), tt.name) if err != nil { if err, ok := err.(*DNSError); !ok || tt.error != nil && (err.Name != tt.error.(*DNSError).Name || err.Server != tt.error.(*DNSError).Server || err.IsTimeout != tt.error.(*DNSError).IsTimeout) { t.Errorf("got %v; want %v", err, tt.error) } continue } if len(addrs) == 0 { t.Errorf("no records for %s", tt.name) } if !tt.a && !tt.aaaa && len(addrs) > 0 { t.Errorf("unexpected %v for %s", addrs, tt.name) } for _, addr := range addrs { if !tt.a && addr.IP.To4() != nil { t.Errorf("got %v; must not be IPv4 address", addr) } if !tt.aaaa && addr.IP.To16() != nil && addr.IP.To4() == nil { t.Errorf("got %v; must not be IPv6 address", addr) } } } } // Test that goLookupIPOrder falls back to the host file when no DNS servers are available. func TestGoLookupIPOrderFallbackToFile(t *testing.T) { defer dnsWaitGroup.Wait() fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } return r, nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} // Add a config that simulates no dns servers being available. conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() if err := conf.writeAndUpdate([]string{}); err != nil { t.Fatal(err) } // Redirect host file lookups. defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) testHookHostsPath = "testdata/hosts" for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} { name := fmt.Sprintf("order %v", order) // First ensure that we get an error when contacting a non-existent host. _, _, err := r.goLookupIPCNAMEOrder(context.Background(), "notarealhost", order) if err == nil { t.Errorf("%s: expected error while looking up name not in hosts file", name) continue } // Now check that we get an address when the name appears in the hosts file. addrs, _, err := r.goLookupIPCNAMEOrder(context.Background(), "thor", order) // entry is in "testdata/hosts" if err != nil { t.Errorf("%s: expected to successfully lookup host entry", name) continue } if len(addrs) != 1 { t.Errorf("%s: expected exactly one result, but got %v", name, addrs) continue } if got, want := addrs[0].String(), "127.1.1.1"; got != want { t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want) } } } // Issue 12712. // When using search domains, return the error encountered // querying the original name instead of an error encountered // querying a generated name. func TestErrorForOriginalNameWhenSearching(t *testing.T) { defer dnsWaitGroup.Wait() const fqdn = "doesnotexist.domain" conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil { t.Fatal(err) } fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } switch q.Questions[0].Name.String() { case fqdn + ".servfail.": r.Header.RCode = dnsmessage.RCodeServerFailure default: r.Header.RCode = dnsmessage.RCodeNameError } return r, nil }} cases := []struct { strictErrors bool wantErr *DNSError }{ {true, &DNSError{Name: fqdn, Err: "server misbehaving", IsTemporary: true}}, {false, &DNSError{Name: fqdn, Err: errNoSuchHost.Error(), IsNotFound: true}}, } for _, tt := range cases { r := Resolver{PreferGo: true, StrictErrors: tt.strictErrors, Dial: fake.DialContext} _, err = r.LookupIPAddr(context.Background(), fqdn) if err == nil { t.Fatal("expected an error") } want := tt.wantErr if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err || err.IsTemporary != want.IsTemporary { t.Errorf("got %v; want %v", err, want) } } } // Issue 15434. If a name server gives a lame referral, continue to the next. func TestIgnoreLameReferrals(t *testing.T) { defer dnsWaitGroup.Wait() conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() if err := conf.writeAndUpdate([]string{"nameserver 192.0.2.1", // the one that will give a lame referral "nameserver 192.0.2.2"}); err != nil { t.Fatal(err) } fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { t.Log(s, q) r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } if s == "192.0.2.2:53" { r.Header.RecursionAvailable = true if q.Questions[0].Type == dnsmessage.TypeA { r.Answers = []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }, } } } return r, nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} addrs, err := r.LookupIPAddr(context.Background(), "www.golang.org") if err != nil { t.Fatal(err) } if got := len(addrs); got != 1 { t.Fatalf("got %d addresses, want 1", got) } if got, want := addrs[0].String(), "192.0.2.1"; got != want { t.Fatalf("got address %v, want %v", got, want) } } func BenchmarkGoLookupIP(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) ctx := context.Background() b.ReportAllocs() for i := 0; i < b.N; i++ { goResolver.LookupIPAddr(ctx, "www.example.com") } } func BenchmarkGoLookupIPNoSuchHost(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) ctx := context.Background() b.ReportAllocs() for i := 0; i < b.N; i++ { goResolver.LookupIPAddr(ctx, "some.nonexistent") } } func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { testHookUninstaller.Do(uninstallTestHooks) conf, err := newResolvConfTest() if err != nil { b.Fatal(err) } defer conf.teardown() lines := []string{ "nameserver 203.0.113.254", // use TEST-NET-3 block, see RFC 5737 "nameserver 8.8.8.8", } if err := conf.writeAndUpdate(lines); err != nil { b.Fatal(err) } ctx := context.Background() b.ReportAllocs() for i := 0; i < b.N; i++ { goResolver.LookupIPAddr(ctx, "www.example.com") } } type fakeDNSServer struct { rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) alwaysTCP bool } func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" { return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil } return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil } type fakeDNSConn struct { Conn tcp bool server *fakeDNSServer n string s string q dnsmessage.Message t time.Time buf []byte } func (f *fakeDNSConn) Close() error { return nil } func (f *fakeDNSConn) Read(b []byte) (int, error) { if len(f.buf) > 0 { n := copy(b, f.buf) f.buf = f.buf[n:] return n, nil } resp, err := f.server.rh(f.n, f.s, f.q, f.t) if err != nil { return 0, err } bb := make([]byte, 2, 514) bb, err = resp.AppendPack(bb) if err != nil { return 0, fmt.Errorf("cannot marshal DNS message: %v", err) } if f.tcp { l := len(bb) - 2 bb[0] = byte(l >> 8) bb[1] = byte(l) f.buf = bb return f.Read(b) } bb = bb[2:] if len(b) < len(bb) { return 0, errors.New("read would fragment DNS message") } copy(b, bb) return len(bb), nil } func (f *fakeDNSConn) Write(b []byte) (int, error) { if f.tcp && len(b) >= 2 { b = b[2:] } if f.q.Unpack(b) != nil { return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b)) } return len(b), nil } func (f *fakeDNSConn) SetDeadline(t time.Time) error { f.t = t return nil } type fakeDNSPacketConn struct { PacketConn fakeDNSConn } func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error { return f.fakeDNSConn.SetDeadline(t) } func (f *fakeDNSPacketConn) Close() error { return f.fakeDNSConn.Close() } // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281). func TestIgnoreDNSForgeries(t *testing.T) { c, s := Pipe() go func() { b := make([]byte, 512) n, err := s.Read(b) if err != nil { t.Error(err) return } var msg dnsmessage.Message if msg.Unpack(b[:n]) != nil { t.Error("invalid DNS query:", err) return } s.Write([]byte("garbage DNS response packet")) msg.Header.Response = true msg.Header.ID++ // make invalid ID if b, err = msg.Pack(); err != nil { t.Error("failed to pack DNS response:", err) return } s.Write(b) msg.Header.ID-- // restore original ID msg.Answers = []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: mustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }, } b, err = msg.Pack() if err != nil { t.Error("failed to pack DNS response:", err) return } s.Write(b) }() msg := dnsmessage.Message{ Header: dnsmessage.Header{ ID: 42, }, Questions: []dnsmessage.Question{ { Name: mustNewName("www.example.com."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, }, }, } b, err := msg.Pack() if err != nil { t.Fatal("Pack failed:", err) } p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b) if err != nil { t.Fatalf("dnsPacketRoundTrip failed: %v", err) } p.SkipAllQuestions() as, err := p.AllAnswers() if err != nil { t.Fatal("AllAnswers failed:", err) } if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr { t.Errorf("got address %v, want %v", got, TestAddr) } } // Issue 16865. If a name server times out, continue to the next. func TestRetryTimeout(t *testing.T) { defer dnsWaitGroup.Wait() conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() testConf := []string{ "nameserver 192.0.2.1", // the one that will timeout "nameserver 192.0.2.2", } if err := conf.writeAndUpdate(testConf); err != nil { t.Fatal(err) } var deadline0 time.Time fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q, deadline) if deadline.IsZero() { t.Error("zero deadline") } if s == "192.0.2.1:53" { deadline0 = deadline time.Sleep(10 * time.Millisecond) return dnsmessage.Message{}, os.ErrDeadlineExceeded } if deadline.Equal(deadline0) { t.Error("deadline didn't change") } return mockTXTResponse(q), nil }} r := &Resolver{PreferGo: true, Dial: fake.DialContext} _, err = r.LookupTXT(context.Background(), "www.golang.org") if err != nil { t.Fatal(err) } if deadline0.IsZero() { t.Error("deadline0 still zero", deadline0) } } func TestRotate(t *testing.T) { // without rotation, always uses the first server testRotate(t, false, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.1:53", "192.0.2.1:53"}) // with rotation, rotates through back to first testRotate(t, true, []string{"192.0.2.1", "192.0.2.2"}, []string{"192.0.2.1:53", "192.0.2.2:53", "192.0.2.1:53"}) } func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { defer dnsWaitGroup.Wait() conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() var confLines []string for _, ns := range nameservers { confLines = append(confLines, "nameserver "+ns) } if rotate { confLines = append(confLines, "options rotate") } if err := conf.writeAndUpdate(confLines); err != nil { t.Fatal(err) } var usedServers []string fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { usedServers = append(usedServers, s) return mockTXTResponse(q), nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} // len(nameservers) + 1 to allow rotation to get back to start for i := 0; i < len(nameservers)+1; i++ { if _, err := r.LookupTXT(context.Background(), "www.golang.org"); err != nil { t.Fatal(err) } } if !reflect.DeepEqual(usedServers, wantServers) { t.Errorf("rotate=%t got used servers:\n%v\nwant:\n%v", rotate, usedServers, wantServers) } } func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RecursionAvailable: true, }, Questions: q.Questions, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeTXT, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.TXTResource{ TXT: []string{"ok"}, }, }, }, } return r } // Issue 17448. With StrictErrors enabled, temporary errors should make // LookupIP fail rather than return a partial result. func TestStrictErrorsLookupIP(t *testing.T) { defer dnsWaitGroup.Wait() conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() confData := []string{ "nameserver 192.0.2.53", "search x.golang.org y.golang.org", } if err := conf.writeAndUpdate(confData); err != nil { t.Fatal(err) } const name = "test-issue19592" const server = "192.0.2.53:53" const searchX = "test-issue19592.x.golang.org." const searchY = "test-issue19592.y.golang.org." const ip4 = "192.0.2.1" const ip6 = "2001:db8::1" type resolveWhichEnum int const ( resolveOK resolveWhichEnum = iota resolveOpError resolveServfail resolveTimeout ) makeTempError := func(err string) error { return &DNSError{ Err: err, Name: name, Server: server, IsTemporary: true, } } makeTimeout := func() error { return &DNSError{ Err: os.ErrDeadlineExceeded.Error(), Name: name, Server: server, IsTimeout: true, } } makeNxDomain := func() error { return &DNSError{ Err: errNoSuchHost.Error(), Name: name, Server: server, IsNotFound: true, } } cases := []struct { desc string resolveWhich func(quest dnsmessage.Question) resolveWhichEnum wantStrictErr error wantLaxErr error wantIPs []string }{ { desc: "No errors", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { return resolveOK }, wantIPs: []string{ip4, ip6}, }, { desc: "searchX error fails in strict mode", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchX { return resolveTimeout } return resolveOK }, wantStrictErr: makeTimeout(), wantIPs: []string{ip4, ip6}, }, { desc: "searchX IPv4-only timeout fails in strict mode", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA { return resolveTimeout } return resolveOK }, wantStrictErr: makeTimeout(), wantIPs: []string{ip4, ip6}, }, { desc: "searchX IPv6-only servfail fails in strict mode", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA { return resolveServfail } return resolveOK }, wantStrictErr: makeTempError("server misbehaving"), wantIPs: []string{ip4, ip6}, }, { desc: "searchY error always fails", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchY { return resolveTimeout } return resolveOK }, wantStrictErr: makeTimeout(), wantLaxErr: makeNxDomain(), // This one reaches the "test." FQDN. }, { desc: "searchY IPv4-only socket error fails in strict mode", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA { return resolveOpError } return resolveOK }, wantStrictErr: makeTempError("write: socket on fire"), wantIPs: []string{ip6}, }, { desc: "searchY IPv6-only timeout fails in strict mode", resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA { return resolveTimeout } return resolveOK }, wantStrictErr: makeTimeout(), wantIPs: []string{ip4}, }, } for i, tt := range cases { fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch tt.resolveWhich(q.Questions[0]) { case resolveOK: // Handle below. case resolveOpError: return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} case resolveServfail: return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeServerFailure, }, Questions: q.Questions, }, nil case resolveTimeout: return dnsmessage.Message{}, os.ErrDeadlineExceeded default: t.Fatal("Impossible resolveWhich") } switch q.Questions[0].Name.String() { case searchX, name + ".": // Return NXDOMAIN to utilize the search list. return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeNameError, }, Questions: q.Questions, }, nil case searchY: // Return records below. default: return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name) } r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } switch q.Questions[0].Type { case dnsmessage.TypeA: r.Answers = []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }, } case dnsmessage.TypeAAAA: r.Answers = []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeAAAA, Class: dnsmessage.ClassINET, Length: 16, }, Body: &dnsmessage.AAAAResource{ AAAA: TestAddr6, }, }, } default: return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type) } return r, nil }} for _, strict := range []bool{true, false} { r := Resolver{PreferGo: true, StrictErrors: strict, Dial: fake.DialContext} ips, err := r.LookupIPAddr(context.Background(), name) var wantErr error if strict { wantErr = tt.wantStrictErr } else { wantErr = tt.wantLaxErr } if !reflect.DeepEqual(err, wantErr) { t.Errorf("#%d (%s) strict=%v: got err %#v; want %#v", i, tt.desc, strict, err, wantErr) } gotIPs := map[string]struct{}{} for _, ip := range ips { gotIPs[ip.String()] = struct{}{} } wantIPs := map[string]struct{}{} if wantErr == nil { for _, ip := range tt.wantIPs { wantIPs[ip] = struct{}{} } } if !reflect.DeepEqual(gotIPs, wantIPs) { t.Errorf("#%d (%s) strict=%v: got ips %v; want %v", i, tt.desc, strict, gotIPs, wantIPs) } } } } // Issue 17448. With StrictErrors enabled, temporary errors should make // LookupTXT stop walking the search list. func TestStrictErrorsLookupTXT(t *testing.T) { defer dnsWaitGroup.Wait() conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() confData := []string{ "nameserver 192.0.2.53", "search x.golang.org y.golang.org", } if err := conf.writeAndUpdate(confData); err != nil { t.Fatal(err) } const name = "test" const server = "192.0.2.53:53" const searchX = "test.x.golang.org." const searchY = "test.y.golang.org." const txt = "Hello World" fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) switch q.Questions[0].Name.String() { case searchX: return dnsmessage.Message{}, os.ErrDeadlineExceeded case searchY: return mockTXTResponse(q), nil default: return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name) } }} for _, strict := range []bool{true, false} { r := Resolver{StrictErrors: strict, Dial: fake.DialContext} p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT) var wantErr error var wantRRs int if strict { wantErr = &DNSError{ Err: os.ErrDeadlineExceeded.Error(), Name: name, Server: server, IsTimeout: true, } } else { wantRRs = 1 } if !reflect.DeepEqual(err, wantErr) { t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr) } a, err := p.AllAnswers() if err != nil { a = nil } if len(a) != wantRRs { t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs) } } } // Test for a race between uninstalling the test hooks and closing a // socket connection. This used to fail when testing with -race. func TestDNSGoroutineRace(t *testing.T) { defer dnsWaitGroup.Wait() fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { time.Sleep(10 * time.Microsecond) return dnsmessage.Message{}, os.ErrDeadlineExceeded }} r := Resolver{PreferGo: true, Dial: fake.DialContext} // The timeout here is less than the timeout used by the server, // so the goroutine started to query the (fake) server will hang // around after this test is done if we don't call dnsWaitGroup.Wait. ctx, cancel := context.WithTimeout(context.Background(), 2*time.Microsecond) defer cancel() _, err := r.LookupIPAddr(ctx, "where.are.they.now") if err == nil { t.Fatal("fake DNS lookup unexpectedly succeeded") } } func lookupWithFake(fake fakeDNSServer, name string, typ dnsmessage.Type) error { r := Resolver{PreferGo: true, Dial: fake.DialContext} resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, _, err := r.tryOneName(ctx, conf, name, typ) return err } // Issue 8434: verify that Temporary returns true on an error when rcode // is SERVFAIL func TestIssue8434(t *testing.T) { err := lookupWithFake(fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeServerFailure, }, Questions: q.Questions, }, nil }, }, "golang.org.", dnsmessage.TypeALL) if err == nil { t.Fatal("expected an error") } if ne, ok := err.(Error); !ok { t.Fatalf("err = %#v; wanted something supporting net.Error", err) } else if !ne.Temporary() { t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err) } if de, ok := err.(*DNSError); !ok { t.Fatalf("err = %#v; wanted a *net.DNSError", err) } else if !de.IsTemporary { t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err) } } func TestIssueNoSuchHostExists(t *testing.T) { err := lookupWithFake(fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeNameError, }, Questions: q.Questions, }, nil }, }, "golang.org.", dnsmessage.TypeALL) if err == nil { t.Fatal("expected an error") } if _, ok := err.(Error); !ok { t.Fatalf("err = %#v; wanted something supporting net.Error", err) } if de, ok := err.(*DNSError); !ok { t.Fatalf("err = %#v; wanted a *net.DNSError", err) } else if !de.IsNotFound { t.Fatalf("IsNotFound = false for err = %#v; want IsNotFound == true", err) } } // TestNoSuchHost verifies that tryOneName works correctly when the domain does // not exist. // // Issue 12778: verify that NXDOMAIN without RA bit errors as "no such host" // and not "server misbehaving" // // Issue 25336: verify that NXDOMAIN errors fail fast. // // Issue 27525: verify that empty answers fail fast. func TestNoSuchHost(t *testing.T) { tests := []struct { name string f func(string, string, dnsmessage.Message, time.Time) (dnsmessage.Message, error) }{ { "NXDOMAIN", func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeNameError, RecursionAvailable: false, }, Questions: q.Questions, }, nil }, }, { "no answers", func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { return dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, RCode: dnsmessage.RCodeSuccess, RecursionAvailable: false, Authoritative: true, }, Questions: q.Questions, }, nil }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { lookups := 0 err := lookupWithFake(fakeDNSServer{ rh: func(n, s string, q dnsmessage.Message, d time.Time) (dnsmessage.Message, error) { lookups++ return test.f(n, s, q, d) }, }, ".", dnsmessage.TypeALL) if lookups != 1 { t.Errorf("got %d lookups, wanted 1", lookups) } if err == nil { t.Fatal("expected an error") } de, ok := err.(*DNSError) if !ok { t.Fatalf("err = %#v; wanted a *net.DNSError", err) } if de.Err != errNoSuchHost.Error() { t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error()) } if !de.IsNotFound { t.Fatalf("IsNotFound = %v wanted true", de.IsNotFound) } }) } } // Issue 26573: verify that Conns that don't implement PacketConn are treated // as streams even when udp was requested. func TestDNSDialTCP(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, } return r, nil }, alwaysTCP: true, } r := Resolver{PreferGo: true, Dial: fake.DialContext} ctx := context.Background() _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useUDPOrTCP) if err != nil { t.Fatal("exhange failed:", err) } } // Issue 27763: verify that two strings in one TXT record are concatenated. func TestTXTRecordTwoStrings(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.TXTResource{ TXT: []string{"string1 ", "string2"}, }, }, { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.TXTResource{ TXT: []string{"onestring"}, }, }, }, } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} txt, err := r.lookupTXT(context.Background(), "golang.org") if err != nil { t.Fatal("LookupTXT failed:", err) } if want := 2; len(txt) != want { t.Fatalf("len(txt), got %d, want %d", len(txt), want) } if want := "string1 string2"; txt[0] != want { t.Errorf("txt[0], got %q, want %q", txt[0], want) } if want := "onestring"; txt[1] != want { t.Errorf("txt[1], got %q, want %q", txt[1], want) } } // Issue 29644: support single-request resolv.conf option in pure Go resolver. // The A and AAAA queries will be sent sequentially, not in parallel. func TestSingleRequestLookup(t *testing.T) { defer dnsWaitGroup.Wait() var ( firstcalled int32 ipv4 int32 = 1 ipv6 int32 = 2 ) fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.ID, Response: true, }, Questions: q.Questions, } for _, question := range q.Questions { switch question.Type { case dnsmessage.TypeA: if question.Name.String() == "slowipv4.example.net." { time.Sleep(10 * time.Millisecond) } if !atomic.CompareAndSwapInt32(&firstcalled, 0, ipv4) { t.Errorf("the A query was received after the AAAA query !") } r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }) case dnsmessage.TypeAAAA: atomic.CompareAndSwapInt32(&firstcalled, 0, ipv6) r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeAAAA, Class: dnsmessage.ClassINET, Length: 16, }, Body: &dnsmessage.AAAAResource{ AAAA: TestAddr6, }, }) } } return r, nil }} r := Resolver{PreferGo: true, Dial: fake.DialContext} conf, err := newResolvConfTest() if err != nil { t.Fatal(err) } defer conf.teardown() if err := conf.writeAndUpdate([]string{"options single-request"}); err != nil { t.Fatal(err) } for _, name := range []string{"hostname.example.net", "slowipv4.example.net"} { firstcalled = 0 _, err := r.LookupIPAddr(context.Background(), name) if err != nil { t.Error(err) } } } // Issue 29358. Add configuration knob to force TCP-only DNS requests in the pure Go resolver. func TestDNSUseTCP(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, } if n == "udp" { t.Fatal("udp protocol was used instead of tcp") } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} ctx, cancel := context.WithCancel(context.Background()) defer cancel() _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second, useTCPOnly) if err != nil { t.Fatal("exchange failed:", err) } } // Issue 34660: PTR response with non-PTR answers should ignore non-PTR func TestPTRandNonPTR(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypePTR, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.PTRResource{ PTR: dnsmessage.MustNewName("golang.org."), }, }, { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeTXT, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.TXTResource{ TXT: []string{"PTR 8 6 60 ..."}, // fake RRSIG }, }, }, } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} names, err := r.lookupAddr(context.Background(), "192.0.2.123") if err != nil { t.Fatalf("LookupAddr: %v", err) } if want := []string{"golang.org."}; !reflect.DeepEqual(names, want) { t.Errorf("names = %q; want %q", names, want) } } func TestCVE202133195(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, RecursionAvailable: true, }, Questions: q.Questions, } switch q.Questions[0].Type { case dnsmessage.TypeCNAME: r.Answers = []dnsmessage.Resource{} case dnsmessage.TypeA: // CNAME lookup uses a A/AAAA as a proxy r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(".golang.org."), Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.AResource{ A: TestAddr, }, }, ) case dnsmessage.TypeSRV: n := q.Questions[0].Name if n.String() == "_hdr._tcp.golang.org." { n = dnsmessage.MustNewName(".golang.org.") } r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: n, Type: dnsmessage.TypeSRV, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.SRVResource{ Target: dnsmessage.MustNewName(".golang.org."), }, }, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: n, Type: dnsmessage.TypeSRV, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.SRVResource{ Target: dnsmessage.MustNewName("good.golang.org."), }, }, ) case dnsmessage.TypeMX: r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(".golang.org."), Type: dnsmessage.TypeMX, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.MXResource{ MX: dnsmessage.MustNewName(".golang.org."), }, }, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName("good.golang.org."), Type: dnsmessage.TypeMX, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.MXResource{ MX: dnsmessage.MustNewName("good.golang.org."), }, }, ) case dnsmessage.TypeNS: r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(".golang.org."), Type: dnsmessage.TypeNS, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.NSResource{ NS: dnsmessage.MustNewName(".golang.org."), }, }, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName("good.golang.org."), Type: dnsmessage.TypeNS, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.NSResource{ NS: dnsmessage.MustNewName("good.golang.org."), }, }, ) case dnsmessage.TypePTR: r.Answers = append(r.Answers, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName(".golang.org."), Type: dnsmessage.TypePTR, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.PTRResource{ PTR: dnsmessage.MustNewName(".golang.org."), }, }, dnsmessage.Resource{ Header: dnsmessage.ResourceHeader{ Name: dnsmessage.MustNewName("good.golang.org."), Type: dnsmessage.TypePTR, Class: dnsmessage.ClassINET, Length: 4, }, Body: &dnsmessage.PTRResource{ PTR: dnsmessage.MustNewName("good.golang.org."), }, }, ) } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} // Change the default resolver to match our manipulated resolver originalDefault := DefaultResolver DefaultResolver = &r defer func() { DefaultResolver = originalDefault }() // Redirect host file lookups. defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath) testHookHostsPath = "testdata/hosts" tests := []struct { name string f func(*testing.T) }{ { name: "CNAME", f: func(t *testing.T) { expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"} _, err := r.LookupCNAME(context.Background(), "golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } _, err = LookupCNAME("golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } }, }, { name: "SRV (bad record)", f: func(t *testing.T) { expected := []*SRV{ { Target: "good.golang.org.", }, } expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"} _, records, err := r.LookupSRV(context.Background(), "target", "tcp", "golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } _, records, err = LookupSRV("target", "tcp", "golang.org") if err.Error() != expectedErr.Error() { t.Errorf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } }, }, { name: "SRV (bad header)", f: func(t *testing.T) { _, _, err := r.LookupSRV(context.Background(), "hdr", "tcp", "golang.org.") if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected { t.Errorf("Resolver.LookupSRV returned unexpected error, got %q, want %q", err, expected) } _, _, err = LookupSRV("hdr", "tcp", "golang.org.") if expected := "lookup golang.org.: SRV header name is invalid"; err == nil || err.Error() != expected { t.Errorf("LookupSRV returned unexpected error, got %q, want %q", err, expected) } }, }, { name: "MX", f: func(t *testing.T) { expected := []*MX{ { Host: "good.golang.org.", }, } expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"} records, err := r.LookupMX(context.Background(), "golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } records, err = LookupMX("golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } }, }, { name: "NS", f: func(t *testing.T) { expected := []*NS{ { Host: "good.golang.org.", }, } expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "golang.org"} records, err := r.LookupNS(context.Background(), "golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } records, err = LookupNS("golang.org") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } }, }, { name: "Addr", f: func(t *testing.T) { expected := []string{"good.golang.org."} expectedErr := &DNSError{Err: errMalformedDNSRecordsDetail, Name: "192.0.2.42"} records, err := r.LookupAddr(context.Background(), "192.0.2.42") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } records, err = LookupAddr("192.0.2.42") if err.Error() != expectedErr.Error() { t.Fatalf("unexpected error: %s", err) } if !reflect.DeepEqual(records, expected) { t.Error("Unexpected record set") } }, }, } for _, tc := range tests { t.Run(tc.name, tc.f) } } func TestNullMX(t *testing.T) { fake := fakeDNSServer{ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { r := dnsmessage.Message{ Header: dnsmessage.Header{ ID: q.Header.ID, Response: true, RCode: dnsmessage.RCodeSuccess, }, Questions: q.Questions, Answers: []dnsmessage.Resource{ { Header: dnsmessage.ResourceHeader{ Name: q.Questions[0].Name, Type: dnsmessage.TypeMX, Class: dnsmessage.ClassINET, }, Body: &dnsmessage.MXResource{ MX: dnsmessage.MustNewName("."), }, }, }, } return r, nil }, } r := Resolver{PreferGo: true, Dial: fake.DialContext} rrset, err := r.LookupMX(context.Background(), "golang.org") if err != nil { t.Fatalf("LookupMX: %v", err) } if want := []*MX{&MX{Host: "."}}; !reflect.DeepEqual(rrset, want) { records := []string{} for _, rr := range rrset { records = append(records, fmt.Sprintf("%v", rr)) } t.Errorf("records = [%v]; want [%v]", strings.Join(records, " "), want[0]) } }