simplify dns.MockResolver, changing MockReq to just a string representing the request

similar to Authentic/Inauthentic
This commit is contained in:
Mechiel Lukkien
2023-10-12 16:07:25 +02:00
parent c095f3f39c
commit 7dce883097
10 changed files with 57 additions and 54 deletions

View File

@ -20,20 +20,24 @@ type MockResolver struct {
MX map[string][]*net.MX
TLSA map[string][]adns.TLSA // Keys are e.g. _25._tcp.<host>.
CNAME map[string]string
Fail map[Mockreq]struct{}
Fail []string // Records of the form "type name", e.g. "cname localhost." that will return a servfail.
AllAuthentic bool // Default value for authentic in responses. Overridden with Authentic and Inauthentic
Authentic []string // Records of the form "type name", e.g. "cname localhost."
Inauthentic []string
Authentic []string // Like Fail, but records that cause the response to be authentic.
Inauthentic []string // Like Authentic, but making response inauthentic.
}
type Mockreq struct {
type mockReq struct {
Type string // E.g. "cname", "txt", "mx", "ptr", etc.
Name string // Name of request. For TLSA, the full requested DNS name, e.g. _25._tcp.<host>.
}
func (mr mockReq) String() string {
return mr.Type + " " + mr.Name
}
var _ Resolver = MockResolver{}
func (r MockResolver) result(ctx context.Context, mr Mockreq) (string, adns.Result, error) {
func (r MockResolver) result(ctx context.Context, mr mockReq) (string, adns.Result, error) {
result := adns.Result{Authentic: r.AllAuthentic}
if err := ctx.Err(); err != nil {
@ -50,14 +54,14 @@ func (r MockResolver) result(ctx context.Context, mr Mockreq) (string, adns.Resu
}
for {
if _, ok := r.Fail[mr]; ok {
updateAuthentic(mr.Type + " " + mr.Name)
if slices.Contains(r.Fail, mr.String()) {
updateAuthentic(mr.String())
return mr.Name, adns.Result{}, r.servfail(mr.Name)
}
cname, ok := r.CNAME[mr.Name]
if !ok {
updateAuthentic(mr.Type + " " + mr.Name)
updateAuthentic(mr.String())
break
}
updateAuthentic("cname " + mr.Name)
@ -95,7 +99,7 @@ func (r MockResolver) LookupPort(ctx context.Context, network, service string) (
}
func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, adns.Result, error) {
mr := Mockreq{"cname", name}
mr := mockReq{"cname", name}
name, result, err := r.result(ctx, mr)
if err != nil {
return name, result, err
@ -108,7 +112,7 @@ func (r MockResolver) LookupCNAME(ctx context.Context, name string) (string, adn
}
func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, adns.Result, error) {
mr := Mockreq{"ptr", ip}
mr := mockReq{"ptr", ip}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -121,7 +125,7 @@ func (r MockResolver) LookupAddr(ctx context.Context, ip string) ([]string, adns
}
func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, adns.Result, error) {
mr := Mockreq{"ns", name}
mr := mockReq{"ns", name}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -131,7 +135,7 @@ func (r MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, adn
func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, adns.Result, error) {
xname := fmt.Sprintf("_%s._%s.%s", service, proto, name)
mr := Mockreq{"srv", xname}
mr := mockReq{"srv", xname}
name, result, err := r.result(ctx, mr)
if err != nil {
return name, nil, result, err
@ -141,7 +145,7 @@ func (r MockResolver) LookupSRV(ctx context.Context, service, proto, name string
func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, adns.Result, error) {
// todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
mr := Mockreq{"ipaddr", host}
mr := mockReq{"ipaddr", host}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -164,7 +168,7 @@ func (r MockResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAd
func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, adns.Result, error) {
// todo: make closer to resolver, doing a & aaaa lookups, including their error/(in)secure status.
mr := Mockreq{"host", host}
mr := mockReq{"host", host}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -179,7 +183,7 @@ func (r MockResolver) LookupHost(ctx context.Context, host string) ([]string, ad
}
func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net.IP, adns.Result, error) {
mr := Mockreq{"ip", host}
mr := mockReq{"ip", host}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -204,7 +208,7 @@ func (r MockResolver) LookupIP(ctx context.Context, network, host string) ([]net
}
func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adns.Result, error) {
mr := Mockreq{"mx", name}
mr := mockReq{"mx", name}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -217,7 +221,7 @@ func (r MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, adn
}
func (r MockResolver) LookupTXT(ctx context.Context, name string) ([]string, adns.Result, error) {
mr := Mockreq{"txt", name}
mr := mockReq{"txt", name}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err
@ -236,7 +240,7 @@ func (r MockResolver) LookupTLSA(ctx context.Context, port int, protocol string,
} else {
name = fmt.Sprintf("_%d._%s.%s", port, protocol, host)
}
mr := Mockreq{"tlsa", name}
mr := mockReq{"tlsa", name}
_, result, err := r.result(ctx, mr)
if err != nil {
return nil, result, err