update to latest adns, synced with Go's net

This commit is contained in:
Mechiel Lukkien
2024-03-08 15:31:54 +01:00
parent a00b0ba6cd
commit 4fbd7abb57
17 changed files with 187 additions and 85 deletions

View File

@ -29,6 +29,7 @@ import (
"golang.org/x/net/dns/dnsmessage"
"github.com/mjl-/adns/internal/bytealg"
"github.com/mjl-/adns/internal/itoa"
)
@ -205,7 +206,14 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
}
if h.Truncated { // see RFC 5966
// RFC 5966 indicates that when a client receives a UDP response with
// the TC flag set, it should take the TC flag as an indication that it
// should retry over TCP instead.
// The case when the TC flag is set in a TCP response is not well specified,
// so this implements the glibc resolver behavior, returning the existing
// dns response instead of returning a "errNoAnswerFromDNSServer" error.
// See go.dev/issue/64896
if h.Truncated && network == "udp" {
continue
}
return p, h, nil
@ -215,7 +223,9 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
// checkHeader performs basic sanity checks on the header.
func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
if h.RCode == dnsmessage.RCodeNameError {
rcode, hasAdd := extractExtendedRCode(*p, h)
if rcode == dnsmessage.RCodeNameError {
return errNoSuchHost
}
@ -226,17 +236,17 @@ func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
// libresolv continues to the next server when it receives
// an invalid referral response. See golang.org/issue/15434.
if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
if rcode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone && !hasAdd {
return errLameReferral
}
if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
if rcode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
// the server is behaving incorrectly or
// having temporary trouble.
if h.RCode == dnsmessage.RCodeServerFailure {
if rcode == dnsmessage.RCodeServerFailure {
// Look for Extended DNS Error (EDE), RFC 8914.
if p.SkipAllAnswers() != nil || p.SkipAllAuthorities() != nil {
@ -302,6 +312,26 @@ func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
}
}
// extractExtendedRCode extracts the extended RCode from the OPT resource (EDNS(0))
// If an OPT record is not found, the RCode from the hdr is returned.
// Another return value indicates whether an additional resource was found.
func extractExtendedRCode(p dnsmessage.Parser, hdr dnsmessage.Header) (dnsmessage.RCode, bool) {
p.SkipAllAnswers()
p.SkipAllAuthorities()
hasAdd := false
for {
ahdr, err := p.AdditionalHeader()
if err != nil {
return hdr.RCode, hasAdd
}
hasAdd = true
if ahdr.Type == dnsmessage.TypeOPT {
return ahdr.ExtendedRCode(hdr.RCode), hasAdd
}
p.SkipAdditional()
}
}
// Do a lookup for a single name, which must be rooted
// (otherwise answer will not find the answers).
func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, Result, error) {
@ -542,10 +572,6 @@ func avoidDNS(name string) bool {
// nameList returns a list of names for sequential DNS queries.
func (conf *dnsConfig) nameList(name string) []string {
if avoidDNS(name) {
return nil
}
// Check name length (see isDomainName).
l := len(name)
rooted := l > 0 && name[l-1] == '.'
@ -555,27 +581,31 @@ func (conf *dnsConfig) nameList(name string) []string {
// If name is rooted (trailing dot), try only that name.
if rooted {
if avoidDNS(name) {
return nil
}
return []string{name}
}
hasNdots := count(name, '.') >= conf.ndots
hasNdots := bytealg.CountString(name, '.') >= conf.ndots
name += "."
l++
// Build list of search choices.
names := make([]string, 0, 1+len(conf.search))
// If name has enough dots, try unsuffixed first.
if hasNdots {
if hasNdots && !avoidDNS(name) {
names = append(names, name)
}
// Try suffixes that are not too long (see isDomainName).
for _, suffix := range conf.search {
if l+len(suffix) <= 254 {
names = append(names, name+suffix)
fqdn := name + suffix
if !avoidDNS(fqdn) && len(fqdn) <= 254 {
names = append(names, fqdn)
}
}
// Try unsuffixed, if not tried first above.
if !hasNdots {
if !hasNdots && !avoidDNS(name) {
names = append(names, name)
}
return names
@ -767,7 +797,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
h, err := result0.p.AnswerHeader()
if err != nil && err != dnsmessage.ErrSectionDone {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result0.server,
}
@ -780,7 +810,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
a, err := result0.p.AResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result0.server,
}
@ -795,7 +825,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
aaaa, err := result0.p.AAAAResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result0.server,
}
@ -810,7 +840,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
c, err := result0.p.CNAMEResource()
if err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result0.server,
}
@ -823,7 +853,7 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
default:
if err := result0.p.SkipAnswer(); err != nil {
lastErr = &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: name,
Server: result0.server,
}
@ -915,7 +945,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
}
if err != nil {
return nil, result, &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}
@ -924,7 +954,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
err := p.SkipAnswer()
if err != nil {
return nil, result, &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}
@ -934,7 +964,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
ptr, err := p.PTRResource()
if err != nil {
return nil, result, &DNSError{
Err: "cannot marshal DNS message",
Err: errCannotUnmarshalDNSMessage.Error(),
Name: addr,
Server: server,
}