diff options
author | Ian Gudger <igudger@google.com> | 2017-11-22 17:12:30 -0800 |
---|---|---|
committer | Brad Fitzpatrick <bradfitz@golang.org> | 2018-03-15 04:18:11 +0000 |
commit | 672729ebbd15e0b0dfac1ba22e35e92557215e1a (patch) | |
tree | afeb2a230d9962496edcf86de3a37f431f0933b5 /src/net/dnsclient_unix.go | |
parent | c830e05a20cbc193a3b0cb4d4bc9b3acab3643b6 (diff) | |
download | go-672729ebbd15e0b0dfac1ba22e35e92557215e1a.tar.gz go-672729ebbd15e0b0dfac1ba22e35e92557215e1a.zip |
net: use golang.org/x/net/dns/dnsmessage for DNS resolution
Vendors golang.org/x/net/dns/dnsmessage from x/net git rev
892bf7b0c6e2f93b51166bf3882e50277fa5afc6
Updates #16218
Updates #21160
Change-Id: Ic4e8f3c3d83c2936354ec14c5be93b0d2b42dd91
Reviewed-on: https://go-review.googlesource.com/37879
Run-TryBot: Matthew Dempsky <mdempsky@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Diffstat (limited to 'src/net/dnsclient_unix.go')
-rw-r--r-- | src/net/dnsclient_unix.go | 417 |
1 files changed, 290 insertions, 127 deletions
diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index 9026fd8c74..d2a6dc4a81 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -23,142 +23,231 @@ import ( "os" "sync" "time" -) - -// A dnsConn represents a DNS transport endpoint. -type dnsConn interface { - io.Closer - SetDeadline(time.Time) error + "golang_org/x/net/dns/dnsmessage" +) - // dnsRoundTrip executes a single DNS transaction, returning a - // DNS response message for the provided DNS query message. - dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err } -// dnsPacketConn implements the dnsConn interface for RFC 1035's -// "UDP usage" transport mechanism. Conn is a packet-oriented connection, -// such as a *UDPConn. -type dnsPacketConn struct { - Conn +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true } -func (c *dnsPacketConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { - b, ok := query.Pack() - if !ok { - return nil, errors.New("cannot marshal DNS message") - } +func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 512) // see RFC 1035 for { n, err := c.Read(b) if err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } - resp := &dnsMsg{} - if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) { - // Ignore invalid responses as they may be malicious - // forgery attempts. Instead continue waiting until - // timeout. See golang.org/issue/13281. + var p dnsmessage.Parser + // Ignore invalid responses as they may be malicious + // forgery attempts. Instead continue waiting until + // timeout. See golang.org/issue/13281. + h, err := p.Start(b[:n]) + if err != nil { continue } - return resp, nil + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil } } -// dnsStreamConn implements the dnsConn interface for RFC 1035's -// "TCP usage" transport mechanism. Conn is a stream-oriented connection, -// such as a *TCPConn. -type dnsStreamConn struct { - Conn -} - -func (c *dnsStreamConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) { - b, ok := query.Pack() - if !ok { - return nil, errors.New("cannot marshal DNS message") - } - l := len(b) - b = append([]byte{byte(l >> 8), byte(l)}, b...) +func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { if _, err := c.Write(b); err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035 if _, err := io.ReadFull(c, b[:2]); err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } - l = int(b[0])<<8 | int(b[1]) + l := int(b[0])<<8 | int(b[1]) if l > len(b) { b = make([]byte, l) } n, err := io.ReadFull(c, b[:l]) if err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } - resp := &dnsMsg{} - if !resp.Unpack(b[:n]) { - return nil, errors.New("cannot unmarshal DNS message") + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message") } - if !resp.IsResponseTo(query) { - return nil, errors.New("invalid DNS response") + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot unmarshal DNS message") } - return resp, nil + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response") + } + return p, h, nil } // exchange sends a query on the connection and hopes for a response. -func (r *Resolver) exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) { - out := dnsMsg{ - dnsMsgHdr: dnsMsgHdr{ - recursion_desired: true, - }, - question: []dnsQuestion{ - {name, qtype, dnsClassINET}, - }, +func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("cannot marshal DNS message") } for _, network := range []string{"udp", "tcp"} { - // TODO(mdempsky): Refactor so defers from UDP-based - // exchanges happen before TCP-based exchange. - ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) defer cancel() c, err := r.dial(ctx, network, server) if err != nil { - return nil, err + return dnsmessage.Parser{}, dnsmessage.Header{}, err } - defer c.Close() if d, ok := ctx.Deadline(); ok && !d.IsZero() { c.SetDeadline(d) } - out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano()) - in, err := c.dnsRoundTrip(&out) + var p dnsmessage.Parser + var h dnsmessage.Header + if network == "tcp" { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } else { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } + c.Close() if err != nil { - return nil, mapErr(err) + return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err) + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("invalid DNS response") } - if in.truncated { // see RFC 5966 + if h.Truncated { // see RFC 5966 continue } - return in, nil + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errors.New("no answer from DNS server") +} + +func checkHeaders(p *dnsmessage.Parser, h dnsmessage.Header, name, server string) error { + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + + // libresolv continues to the next server when it receives + // an invalid referral response. See golang.org/issue/15434. + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return &DNSError{Err: "lame referral", Name: name, Server: server} + } + + // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, + // it means the response in msg was not useful and trying another + // server probably won't help. Return now in those cases. + // TODO: indicate this in a more obvious way, such as a field on DNSError? + if h.RCode == dnsmessage.RCodeNameError { + return &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server} + } + if h.RCode != dnsmessage.RCodeSuccess { + // None of the error codes make sense + // for the query we sent. If we didn't get + // a name error and we didn't get success, + // the server is behaving incorrectly or + // having temporary trouble. + err := &DNSError{Err: "server misbehaving", Name: name, Server: server} + if h.RCode == dnsmessage.RCodeServerFailure { + err.IsTemporary = true + } + return err + } + + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type, name, server string) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return &DNSError{ + Err: errNoSuchHost.Error(), + Name: name, + Server: server, + } + } + if err != nil { + return &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return &DNSError{ + Err: "cannot unmarshal DNS message", + Name: name, + Server: server, + } + } } - return nil, errors.New("no answer from DNS server") } // Do a lookup for a single name, which must be rooted // (otherwise answer will not find the answers). -func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) { +func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { var lastErr error serverOffset := cfg.serverOffset() sLen := uint32(len(cfg.servers)) + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errors.New("cannot marshal DNS message") + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + for i := 0; i < cfg.attempts; i++ { for j := uint32(0); j < sLen; j++ { server := cfg.servers[(serverOffset+j)%sLen] - msg, err := r.exchange(ctx, server, name, qtype, cfg.timeout) + p, h, err := r.exchange(ctx, server, q, cfg.timeout) if err != nil { lastErr = &DNSError{ Err: err.Error(), @@ -175,41 +264,19 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, } continue } - // libresolv continues to the next server when it receives - // an invalid referral response. See golang.org/issue/15434. - if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 { - lastErr = &DNSError{Err: "lame referral", Name: name, Server: server} + + lastErr = checkHeaders(&p, h, name, server) + if lastErr != nil { continue } - cname, rrs, err := answer(name, server, msg, qtype) - // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError, - // it means the response in msg was not useful and trying another - // server probably won't help. Return now in those cases. - // TODO: indicate this in a more obvious way, such as a field on DNSError? - if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError { - return cname, rrs, err + + lastErr = skipToAnswer(&p, qtype, name, server) + if lastErr == nil { + return p, server, nil } - lastErr = err } } - return "", nil, lastErr -} - -// addrRecordList converts and returns a list of IP addresses from DNS -// address records (both A and AAAA). Other record types are ignored. -func addrRecordList(rrs []dnsRR) []IPAddr { - addrs := make([]IPAddr, 0, 4) - for _, rr := range rrs { - switch rr := rr.(type) { - case *dnsRR_A: - addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))}) - case *dnsRR_AAAA: - ip := make(IP, IPv6len) - copy(ip, rr.AAAA[:]) - addrs = append(addrs, IPAddr{IP: ip}) - } - } - return addrs + return dnsmessage.Parser{}, "", lastErr } // A resolverConfig represents a DNS stub resolver configuration. @@ -287,21 +354,26 @@ func (conf *resolverConfig) releaseSema() { <-conf.ch } -func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) { +func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { if !isDomainName(name) { // We used to use "invalid domain name" as the error, // but that is a detail of the specific lookup mechanism. // Other lookups might allow broader name syntax // (for example Multicast DNS allows UTF-8; see RFC 6762). // For consistency with libc resolvers, report no such host. - return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name} + return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() + var ( + p dnsmessage.Parser + server string + err error + ) for _, fqdn := range conf.nameList(name) { - cname, rrs, err = r.tryOneName(ctx, conf, fqdn, qtype) + p, server, err = r.tryOneName(ctx, conf, fqdn, qtype) if err == nil { break } @@ -311,13 +383,16 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype uint16) (cname break } } + if err == nil { + return p, server, nil + } if err, ok := err.(*DNSError); ok { // Show original name passed to lookup, not suffixed one. // In general we might have tried many suffixes; showing // just one is misleading. See also golang.org/issue/6324. err.Name = name } - return + return dnsmessage.Parser{}, "", err } // avoidDNS reports whether this is a hostname for which we should not @@ -454,36 +529,36 @@ func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, return } -func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) { +func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) { if order == hostLookupFilesDNS || order == hostLookupFiles { addrs = goLookupIPFiles(name) if len(addrs) > 0 || order == hostLookupFiles { - return addrs, name, nil + return addrs, dnsmessage.Name{}, nil } } if !isDomainName(name) { // See comment in func lookup above about use of errNoSuchHost. - return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name} + return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name} } resolvConf.tryUpdate("/etc/resolv.conf") resolvConf.mu.RLock() conf := resolvConf.dnsConfig resolvConf.mu.RUnlock() type racer struct { - cname string - rrs []dnsRR + p dnsmessage.Parser + server string error } lane := make(chan racer, 1) - qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA} + qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA} var lastErr error for _, fqdn := range conf.nameList(name) { for _, qtype := range qtypes { dnsWaitGroup.Add(1) - go func(qtype uint16) { - defer dnsWaitGroup.Done() - cname, rrs, err := r.tryOneName(ctx, conf, fqdn, qtype) - lane <- racer{cname, rrs, err} + go func(qtype dnsmessage.Type) { + p, server, err := r.tryOneName(ctx, conf, fqdn, qtype) + lane <- racer{p, server, err} + dnsWaitGroup.Done() }(qtype) } hitStrictError := false @@ -500,9 +575,74 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order } continue } - addrs = append(addrs, addrRecordList(racer.rrs)...) - if cname == "" { - cname = racer.cname + + // Presotto says it's okay to assume that servers listed in + // /etc/resolv.conf are recursive resolvers. + // + // We asked for recursion, so it should have included all the + // answers we need in this one packet. + // + // Further, RFC 1035 section 4.3.1 says that "the recursive + // response to a query will be... The answer to the query, + // possibly preface by one or more CNAME RRs that specify + // aliases encountered on the way to an answer." + // + // Therefore, we should be able to assume that we can ignore + // CNAMEs and that the A and AAAA records we requested are + // for the canonical name. + + loop: + for { + h, err := racer.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: racer.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := racer.p.AResource() + if err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: racer.server, + } + break loop + } + addrs = append(addrs, IPAddr{IP: IP(a.A[:])}) + + case dnsmessage.TypeAAAA: + aaaa, err := racer.p.AAAAResource() + if err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: racer.server, + } + break loop + } + addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])}) + + default: + if err := racer.p.SkipAnswer(); err != nil { + lastErr = &DNSError{ + Err: "cannot marshal DNS message", + Name: name, + Server: racer.server, + } + break loop + } + continue + } + if cname.Length == 0 && h.Name.Length != 0 { + cname = h.Name + } } } if hitStrictError { @@ -528,17 +668,17 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order addrs = goLookupIPFiles(name) } if len(addrs) == 0 && lastErr != nil { - return nil, "", lastErr + return nil, dnsmessage.Name{}, lastErr } } return addrs, cname, nil } // goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME. -func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (cname string, err error) { +func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) { order := systemConf().hostLookupOrder(host) - _, cname, err = r.goLookupIPCNAMEOrder(ctx, host, order) - return + _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order) + return cname.String(), err } // goLookupPTR is the native Go implementation of LookupAddr. @@ -555,13 +695,36 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, erro if err != nil { return nil, err } - _, rrs, err := r.lookup(ctx, arpa, dnsTypePTR) + p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR) if err != nil { return nil, err } - ptrs := make([]string, len(rrs)) - for i, rr := range rrs { - ptrs[i] = rr.(*dnsRR_PTR).Ptr + var ptrs []string + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, &DNSError{ + Err: "cannot marshal DNS message", + Name: addr, + Server: server, + } + } + if h.Type != dnsmessage.TypePTR { + continue + } + ptr, err := p.PTRResource() + if err != nil { + return nil, &DNSError{ + Err: "cannot marshal DNS message", + Name: addr, + Server: server, + } + } + ptrs = append(ptrs, ptr.PTR.String()) + } return ptrs, nil } |