diff options
author | Ian Gudger <igudger@google.com> | 2017-11-22 17:12:30 -0800 |
---|---|---|
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2018-03-15 04:18:11 +0000 |
commit | 672729ebbd15e0b0dfac1ba22e35e92557215e1a (patch) | |
tree | afeb2a230d9962496edcf86de3a37f431f0933b5 /src/net/dnsclient_unix_test.go | |
parent | c830e05a20cbc193a3b0cb4d4bc9b3acab3643b6 (diff) | |
download | go-672729ebbd15e0b0dfac1ba22e35e92557215e1a.tar.gz go-672729ebbd15e0b0dfac1ba22e35e92557215e1a.zip |
net: use golang.org/x/net/dns/dnsmessage for DNS resolution
Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6
Updates #16218
Updates #21160
Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Diffstat (limited to 'src/net/dnsclient_unix_test.go')
-rw-r--r-- | src/net/dnsclient_unix_test.go | 617 |
1 files changed, 376 insertions, 241 deletions
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 295ed9770c..1d3b78284c 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -19,42 +19,59 @@ import ( "sync" "testing" "time" + + "golang_org/x/net/dns/dnsmessage" ) var goResolver = Resolver{PreferGo: true} // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. -const TestAddr uint32 = 0xc0000201 +var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01} // Test address from 2001:db8::/32 block, reserved by RFC 3849 for documentation. var TestAddr6 = [16]byte{0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} +func mustNewName(name string) dnsmessage.Name { + nn, err := dnsmessage.NewName(name) + if err != nil { + panic(fmt.Sprint("creating name: ", err)) + } + return nn +} + +func mustQuestion(name string, qtype dnsmessage.Type, class dnsmessage.Class) dnsmessage.Question { + return dnsmessage.Question{ + Name: mustNewName(name), + Type: qtype, + Class: class, + } +} + var dnsTransportFallbackTests = []struct { - server string - name string - qtype uint16 - timeout int - rcode int + server string + question dnsmessage.Question + timeout int + rcode dnsmessage.RCode }{ // Querying "com." with qtype=255 usually makes an answer // which requires more than 512 bytes. - {"8.8.8.8:53", "com.", dnsTypeALL, 2, dnsRcodeSuccess}, - {"8.8.4.4:53", "com.", dnsTypeALL, 4, dnsRcodeSuccess}, + {"8.8.8.8:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 2, dnsmessage.RCodeSuccess}, + {"8.8.4.4:53", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), 4, dnsmessage.RCodeSuccess}, } func TestDNSTransportFallback(t *testing.T) { fake := fakeDNSServer{ - rh: func(n, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, - rcode: dnsRcodeSuccess, + rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.Header.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, }, - question: q.question, + Questions: q.Questions, } if n == "udp" { - r.truncated = true + r.Header.Truncated = true } return r, nil }, @@ -63,15 +80,13 @@ func TestDNSTransportFallback(t *testing.T) { for _, tt := range dnsTransportFallbackTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := r.exchange(ctx, tt.server, tt.name, tt.qtype, time.Second) + _, h, err := r.exchange(ctx, tt.server, tt.question, time.Second) if err != nil { t.Error(err) continue } - switch msg.rcode { - case tt.rcode: - default: - t.Errorf("got %v from %v; want %v", msg.rcode, tt.server, tt.rcode) + if h.RCode != tt.rcode { + t.Errorf("got %v from %v; want %v", h.RCode, tt.server, tt.rcode) continue } } @@ -80,39 +95,38 @@ func TestDNSTransportFallback(t *testing.T) { // See RFC 6761 for further information about the reserved, pseudo // domain names. var specialDomainNameTests = []struct { - name string - qtype uint16 - rcode int + question dnsmessage.Question + rcode dnsmessage.RCode }{ // Name resolution APIs and libraries should not recognize the // followings as special. - {"1.0.168.192.in-addr.arpa.", dnsTypePTR, dnsRcodeNameError}, - {"test.", dnsTypeALL, dnsRcodeNameError}, - {"example.com.", dnsTypeALL, dnsRcodeSuccess}, + {mustQuestion("1.0.168.192.in-addr.arpa.", dnsmessage.TypePTR, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, + {mustQuestion("test.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, + {mustQuestion("example.com.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeSuccess}, // Name resolution APIs and libraries should recognize the // followings as special and should not send any queries. // Though, we test those names here for verifying negative // answers at DNS query-response interaction level. - {"localhost.", dnsTypeALL, dnsRcodeNameError}, - {"invalid.", dnsTypeALL, dnsRcodeNameError}, + {mustQuestion("localhost.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, + {mustQuestion("invalid.", dnsmessage.TypeALL, dnsmessage.ClassINET), dnsmessage.RCodeNameError}, } func TestSpecialDomainName(t *testing.T) { - fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } - switch q.question[0].Name { + switch q.Questions[0].Name.String() { case "example.com.": - r.rcode = dnsRcodeSuccess + r.Header.RCode = dnsmessage.RCodeSuccess default: - r.rcode = dnsRcodeNameError + r.Header.RCode = dnsmessage.RCodeNameError } return r, nil @@ -122,15 +136,13 @@ func TestSpecialDomainName(t *testing.T) { for _, tt := range specialDomainNameTests { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - msg, err := r.exchange(ctx, server, tt.name, tt.qtype, 3*time.Second) + _, h, err := r.exchange(ctx, server, tt.question, 3*time.Second) if err != nil { t.Error(err) continue } - switch msg.rcode { - case tt.rcode, dnsRcodeServerFailure: - default: - t.Errorf("got %v from %v; want %v", msg.rcode, server, tt.rcode) + if h.RCode != tt.rcode { + t.Errorf("got %v from %v; want %v", h.RCode, server, tt.rcode) continue } } @@ -177,24 +189,26 @@ func TestAvoidDNSName(t *testing.T) { } } -var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, +var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, - } - if len(q.question) == 1 && q.question[0].Qtype == dnsTypeA { - r.answer = []dnsRR{ - &dnsRR_A{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeA, - Class: dnsClassINET, - Rdlength: 4, + Questions: q.Questions, + } + if len(q.Questions) == 1 && q.Questions[0].Type == dnsmessage.TypeA { + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: TestAddr, }, - A: TestAddr, }, } } @@ -459,54 +473,57 @@ var goLookupIPWithResolverConfigTests = []struct { func TestGoLookupIPWithResolverConfig(t *testing.T) { defer dnsWaitGroup.Wait() - - fake := fakeDNSServer{func(n, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { switch s { case "[2001:4860:4860::8888]:53", "8.8.8.8:53": break default: time.Sleep(10 * time.Millisecond) - return nil, poll.ErrTimeout + return dnsmessage.Message{}, poll.ErrTimeout } - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } - for _, question := range q.question { - switch question.Qtype { - case dnsTypeA: - switch question.Name { + for _, question := range q.Questions { + switch question.Type { + case dnsmessage.TypeA: + switch question.Name.String() { case "hostname.as112.net.": break case "ipv4.google.com.": - r.answer = append(r.answer, &dnsRR_A{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeA, - Class: dnsClassINET, - Rdlength: 4, + r.Answers = append(r.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: TestAddr, }, - A: TestAddr, }) default: } - case dnsTypeAAAA: - switch question.Name { + case dnsmessage.TypeAAAA: + switch question.Name.String() { case "hostname.as112.net.": break case "ipv6.google.com.": - r.answer = append(r.answer, &dnsRR_AAAA{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeAAAA, - Class: dnsClassINET, - Rdlength: 16, + r.Answers = append(r.Answers, dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + Length: 16, + }, + Body: &dnsmessage.AAAAResource{ + AAAA: TestAddr6, }, - AAAA: TestAddr6, }) } } @@ -554,13 +571,13 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) { func TestGoLookupIPOrderFallbackToFile(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q *dnsMsg, tm time.Time) (*dnsMsg, error) { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } return r, nil }} @@ -624,20 +641,20 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, _ string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } - switch q.question[0].Name { + switch q.Questions[0].Name.String() { case fqdn + ".servfail.": - r.rcode = dnsRcodeServerFailure + r.Header.RCode = dnsmessage.RCodeServerFailure default: - r.rcode = dnsRcodeNameError + r.Header.RCode = dnsmessage.RCodeNameError } return r, nil @@ -679,28 +696,30 @@ func TestIgnoreLameReferrals(t *testing.T) { t.Fatal(err) } - fake := fakeDNSServer{func(_, s string, q *dnsMsg, _ time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) { t.Log(s, q) - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } if s == "192.0.2.2:53" { - r.recursion_available = true - if q.question[0].Qtype == dnsTypeA { - r.answer = []dnsRR{ - &dnsRR_A{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeA, - Class: dnsClassINET, - Rdlength: 4, + r.Header.RecursionAvailable = true + if q.Questions[0].Type == dnsmessage.TypeA { + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: TestAddr, }, - A: TestAddr, }, } } @@ -766,20 +785,23 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) { } type fakeDNSServer struct { - rh func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) + rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) } func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) { - return &fakeDNSConn{nil, server, n, s, nil, time.Time{}}, nil + tcp := n == "tcp" || n == "tcp4" || n == "tcp6" + return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil } type fakeDNSConn struct { Conn + tcp bool server *fakeDNSServer n string s string - q *dnsMsg + q dnsmessage.Message t time.Time + buf []byte } func (f *fakeDNSConn) Close() error { @@ -787,15 +809,32 @@ func (f *fakeDNSConn) Close() error { } func (f *fakeDNSConn) Read(b []byte) (int, error) { + if len(f.buf) > 0 { + n := copy(b, f.buf) + f.buf = f.buf[n:] + return n, nil + } + resp, err := f.server.rh(f.n, f.s, f.q, f.t) if err != nil { return 0, err } - bb, ok := resp.Pack() - if !ok { - return 0, errors.New("cannot marshal DNS message") + bb := make([]byte, 2, 514) + bb, err = resp.AppendPack(bb) + if err != nil { + return 0, fmt.Errorf("cannot marshal DNS message: %v", err) } + + if f.tcp { + l := len(bb) - 2 + bb[0] = byte(l >> 8) + bb[1] = byte(l) + f.buf = bb + return f.Read(b) + } + + bb = bb[2:] if len(b) < len(bb) { return 0, errors.New("read would fragment DNS message") } @@ -809,9 +848,11 @@ func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) { } func (f *fakeDNSConn) Write(b []byte) (int, error) { - f.q = new(dnsMsg) - if !f.q.Unpack(b) { - return 0, errors.New("cannot unmarshal DNS message") + if f.tcp && len(b) >= 2 { + b = b[2:] + } + if f.q.Unpack(b) != nil { + return 0, fmt.Errorf("cannot unmarshal DNS message fake %s (%d)", f.n, len(b)) } return len(b), nil } @@ -836,64 +877,75 @@ func TestIgnoreDNSForgeries(t *testing.T) { return } - msg := &dnsMsg{} - if !msg.Unpack(b[:n]) { - t.Error("invalid DNS query") + var msg dnsmessage.Message + if msg.Unpack(b[:n]) != nil { + t.Error("invalid DNS query:", err) return } s.Write([]byte("garbage DNS response packet")) - msg.response = true - msg.id++ // make invalid ID - b, ok := msg.Pack() - if !ok { - t.Error("failed to pack DNS response") + msg.Header.Response = true + msg.Header.ID++ // make invalid ID + + if b, err = msg.Pack(); err != nil { + t.Error("failed to pack DNS response:", err) return } s.Write(b) - msg.id-- // restore original ID - msg.answer = []dnsRR{ - &dnsRR_A{ - Hdr: dnsRR_Header{ - Name: "www.example.com.", - Rrtype: dnsTypeA, - Class: dnsClassINET, - Rdlength: 4, + msg.Header.ID-- // restore original ID + msg.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: mustNewName("www.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: TestAddr, }, - A: TestAddr, }, } - b, ok = msg.Pack() - if !ok { - t.Error("failed to pack DNS response") + b, err = msg.Pack() + if err != nil { + t.Error("failed to pack DNS response:", err) return } s.Write(b) }() - msg := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: 42, + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 42, }, - question: []dnsQuestion{ + Questions: []dnsmessage.Question{ { - Name: "www.example.com.", - Qtype: dnsTypeA, - Qclass: dnsClassINET, + Name: mustNewName("www.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, }, }, } - dc := &dnsPacketConn{c} - resp, err := dc.dnsRoundTrip(msg) + b, err := msg.Pack() if err != nil { - t.Fatalf("dnsRoundTripUDP failed: %v", err) + t.Fatal("Pack failed:", err) } - if got := resp.answer[0].(*dnsRR_A).A; got != TestAddr { + p, _, err := dnsPacketRoundTrip(c, 42, msg.Questions[0], b) + if err != nil { + t.Fatalf("dnsPacketRoundTrip failed: %v", err) + } + + p.SkipAllQuestions() + as, err := p.AllAnswers() + if err != nil { + t.Fatal("AllAnswers failed:", err) + } + if got := as[0].Body.(*dnsmessage.AResource).A; got != TestAddr { t.Errorf("got address %v, want %v", got, TestAddr) } } @@ -918,7 +970,7 @@ func TestRetryTimeout(t *testing.T) { var deadline0 time.Time - fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q, deadline) if deadline.IsZero() { @@ -928,7 +980,7 @@ func TestRetryTimeout(t *testing.T) { if s == "192.0.2.1:53" { deadline0 = deadline time.Sleep(10 * time.Millisecond) - return nil, poll.ErrTimeout + return dnsmessage.Message{}, poll.ErrTimeout } if deadline.Equal(deadline0) { @@ -979,7 +1031,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { } var usedServers []string - fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { usedServers = append(usedServers, s) return mockTXTResponse(q), nil }} @@ -997,22 +1049,24 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) { } } -func mockTXTResponse(q *dnsMsg) *dnsMsg { - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, - recursion_available: true, +func mockTXTResponse(q dnsmessage.Message) dnsmessage.Message { + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, + RecursionAvailable: true, }, - question: q.question, - answer: []dnsRR{ - &dnsRR_TXT{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeTXT, - Class: dnsClassINET, + Questions: q.Questions, + Answers: []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeTXT, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.TXTResource{ + TXT: []string{"ok"}, }, - Txt: "ok", }, }, } @@ -1080,22 +1134,22 @@ func TestStrictErrorsLookupIP(t *testing.T) { cases := []struct { desc string - resolveWhich func(quest *dnsQuestion) resolveWhichEnum + resolveWhich func(quest dnsmessage.Question) resolveWhichEnum wantStrictErr error wantLaxErr error wantIPs []string }{ { desc: "No errors", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { return resolveOK }, wantIPs: []string{ip4, ip6}, }, { desc: "searchX error fails in strict mode", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchX { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchX { return resolveTimeout } return resolveOK @@ -1105,8 +1159,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { }, { desc: "searchX IPv4-only timeout fails in strict mode", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchX && quest.Qtype == dnsTypeA { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeA { return resolveTimeout } return resolveOK @@ -1116,8 +1170,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { }, { desc: "searchX IPv6-only servfail fails in strict mode", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchX && quest.Qtype == dnsTypeAAAA { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchX && quest.Type == dnsmessage.TypeAAAA { return resolveServfail } return resolveOK @@ -1127,8 +1181,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { }, { desc: "searchY error always fails", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchY { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchY { return resolveTimeout } return resolveOK @@ -1138,8 +1192,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { }, { desc: "searchY IPv4-only socket error fails in strict mode", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchY && quest.Qtype == dnsTypeA { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeA { return resolveOpError } return resolveOK @@ -1149,8 +1203,8 @@ func TestStrictErrorsLookupIP(t *testing.T) { }, { desc: "searchY IPv6-only timeout fails in strict mode", - resolveWhich: func(quest *dnsQuestion) resolveWhichEnum { - if quest.Name == searchY && quest.Qtype == dnsTypeAAAA { + resolveWhich: func(quest dnsmessage.Question) resolveWhichEnum { + if quest.Name.String() == searchY && quest.Type == dnsmessage.TypeAAAA { return resolveTimeout } return resolveOK @@ -1161,80 +1215,84 @@ func TestStrictErrorsLookupIP(t *testing.T) { } for i, tt := range cases { - fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) - switch tt.resolveWhich(&q.question[0]) { + switch tt.resolveWhich(q.Questions[0]) { case resolveOK: // Handle below. case resolveOpError: - return nil, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} + return dnsmessage.Message{}, &OpError{Op: "write", Err: fmt.Errorf("socket on fire")} case resolveServfail: - return &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, - rcode: dnsRcodeServerFailure, + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, + RCode: dnsmessage.RCodeServerFailure, }, - question: q.question, + Questions: q.Questions, }, nil case resolveTimeout: - return nil, poll.ErrTimeout + return dnsmessage.Message{}, poll.ErrTimeout default: t.Fatal("Impossible resolveWhich") } - switch q.question[0].Name { + switch q.Questions[0].Name.String() { case searchX, name + ".": // Return NXDOMAIN to utilize the search list. - return &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, - rcode: dnsRcodeNameError, + return dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, + RCode: dnsmessage.RCodeNameError, }, - question: q.question, + Questions: q.Questions, }, nil case searchY: // Return records below. default: - return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) + return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name) } - r := &dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - id: q.id, - response: true, + r := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: q.ID, + Response: true, }, - question: q.question, + Questions: q.Questions, } - switch q.question[0].Qtype { - case dnsTypeA: - r.answer = []dnsRR{ - &dnsRR_A{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeA, - Class: dnsClassINET, - Rdlength: 4, + switch q.Questions[0].Type { + case dnsmessage.TypeA: + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + Length: 4, + }, + Body: &dnsmessage.AResource{ + A: TestAddr, }, - A: TestAddr, }, } - case dnsTypeAAAA: - r.answer = []dnsRR{ - &dnsRR_AAAA{ - Hdr: dnsRR_Header{ - Name: q.question[0].Name, - Rrtype: dnsTypeAAAA, - Class: dnsClassINET, - Rdlength: 16, + case dnsmessage.TypeAAAA: + r.Answers = []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: q.Questions[0].Name, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + Length: 16, + }, + Body: &dnsmessage.AAAAResource{ + AAAA: TestAddr6, }, - AAAA: TestAddr6, }, } default: - return nil, fmt.Errorf("Unexpected Qtype: %v", q.question[0].Qtype) + return dnsmessage.Message{}, fmt.Errorf("Unexpected Type: %v", q.Questions[0].Type) } return r, nil }} @@ -1295,22 +1353,22 @@ func TestStrictErrorsLookupTXT(t *testing.T) { const searchY = "test.y.golang.org." const txt = "Hello World" - fake := fakeDNSServer{func(_, s string, q *dnsMsg, deadline time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) { t.Log(s, q) - switch q.question[0].Name { + switch q.Questions[0].Name.String() { case searchX: - return nil, poll.ErrTimeout + return dnsmessage.Message{}, poll.ErrTimeout case searchY: return mockTXTResponse(q), nil default: - return nil, fmt.Errorf("Unexpected Name: %v", q.question[0].Name) + return dnsmessage.Message{}, fmt.Errorf("Unexpected Name: %v", q.Questions[0].Name) } }} for _, strict := range []bool{true, false} { r := Resolver{StrictErrors: strict, Dial: fake.DialContext} - _, rrs, err := r.lookup(context.Background(), name, dnsTypeTXT) + p, _, err := r.lookup(context.Background(), name, dnsmessage.TypeTXT) var wantErr error var wantRRs int if strict { @@ -1326,8 +1384,12 @@ func TestStrictErrorsLookupTXT(t *testing.T) { if !reflect.DeepEqual(err, wantErr) { t.Errorf("strict=%v: got err %#v; want %#v", strict, err, wantErr) } - if len(rrs) != wantRRs { - t.Errorf("strict=%v: got %v; want %v", strict, len(rrs), wantRRs) + a, err := p.AllAnswers() + if err != nil { + a = nil + } + if len(a) != wantRRs { + t.Errorf("strict=%v: got %v; want %v", strict, len(a), wantRRs) } } } @@ -1337,9 +1399,9 @@ func TestStrictErrorsLookupTXT(t *testing.T) { func TestDNSGoroutineRace(t *testing.T) { defer dnsWaitGroup.Wait() - fake := fakeDNSServer{func(n, s string, q *dnsMsg, t time.Time) (*dnsMsg, error) { + fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) { time.Sleep(10 * time.Microsecond) - return nil, poll.ErrTimeout + return dnsmessage.Message{}, poll.ErrTimeout }} r := Resolver{PreferGo: true, Dial: fake.DialContext} @@ -1353,3 +1415,76 @@ func TestDNSGoroutineRace(t *testing.T) { t.Fatal("fake DNS lookup unexpectedly succeeded") } } + +// Issue 8434: verify that Temporary returns true on an error when rcode +// is SERVFAIL +func TestIssue8434(t *testing.T) { + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ + RCode: dnsmessage.RCodeServerFailure, + }, + } + b, err := msg.Pack() + if err != nil { + t.Fatal("Pack failed:", err) + } + var p dnsmessage.Parser + h, err := p.Start(b) + if err != nil { + t.Fatal("Start failed:", err) + } + if err := p.SkipAllQuestions(); err != nil { + t.Fatal("SkipAllQuestions failed:", err) + } + + err = checkHeaders(&p, h, "golang.org", "foo:53") + if err == nil { + t.Fatal("expected an error") + } + if ne, ok := err.(Error); !ok { + t.Fatalf("err = %#v; wanted something supporting net.Error", err) + } else if !ne.Temporary() { + t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err) + } + if de, ok := err.(*DNSError); !ok { + t.Fatalf("err = %#v; wanted a *net.DNSError", err) + } else if !de.IsTemporary { + t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err) + } +} + +// Issue 12778: verify that NXDOMAIN without RA bit errors as +// "no such host" and not "server misbehaving" +func TestIssue12778(t *testing.T) { + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ + RCode: dnsmessage.RCodeNameError, + RecursionAvailable: false, + }, + } + + b, err := msg.Pack() + if err != nil { + t.Fatal("Pack failed:", err) + } + var p dnsmessage.Parser + h, err := p.Start(b) + if err != nil { + t.Fatal("Start failed:", err) + } + if err := p.SkipAllQuestions(); err != nil { + t.Fatal("SkipAllQuestions failed:", err) + } + + err = checkHeaders(&p, h, "golang.org", "foo:53") + if err == nil { + t.Fatal("expected an error") + } + de, ok := err.(*DNSError) + if !ok { + t.Fatalf("err = %#v; wanted a *net.DNSError", err) + } + if de.Err != errNoSuchHost.Error() { + t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error()) + } +} |