This commit is contained in:
Mechiel Lukkien
2023-01-30 14:27:06 +01:00
commit cb229cb6cf
1256 changed files with 491723 additions and 0 deletions

333
mtasts/mtasts.go Normal file
View File

@ -0,0 +1,333 @@
// Package mtasts implements MTA-STS (SMTP MTA Strict Transport Security, RFC 8461)
// which allows a domain to specify SMTP TLS requirements.
//
// SMTP for message delivery to a remote mail server always starts out unencrypted,
// in plain text. STARTTLS allows upgrading the connection to TLS, but is optional
// and by default mail servers will fall back to plain text communication if
// STARTTLS does not work (which can be sabotaged by DNS manipulation or SMTP
// connection manipulation). MTA-STS can specify a policy for requiring STARTTLS to
// be used for message delivery. A TXT DNS record at "_mta-sts.<domain>" specifies
// the version of the policy, and
// "https://mta-sts.<domain>/.well-known/mta-sts.txt" serves the policy.
package mtasts
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/metrics"
"github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/moxio"
)
var xlog = mlog.New("mtasts")
var (
metricGet = promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "mox_mtasts_get_duration_seconds",
Help: "MTA-STS get of policy, including lookup, duration and result.",
Buckets: []float64{0.01, 0.05, 0.100, 0.5, 1, 5, 10, 20},
},
[]string{
"result", // ok, lookuperror, fetcherror
},
)
)
// Pair is an extension key/value pair in a MTA-STS DNS record or policy.
type Pair struct {
Key string
Value string
}
// Record is an MTA-STS DNS record, served under "_mta-sts.<domain>" as a TXT
// record.
//
// Example:
//
// v=STSv1; id=20160831085700Z
type Record struct {
Version string // "STSv1", for "v=". Required.
ID string // Record version, for "id=". Required.
Extensions []Pair // Optional extensions.
}
// String returns a textual version of the MTA-STS record for use as DNS TXT
// record.
func (r Record) String() string {
b := &strings.Builder{}
fmt.Fprint(b, "v="+r.Version)
fmt.Fprint(b, "; id="+r.ID)
for _, p := range r.Extensions {
fmt.Fprint(b, "; "+p.Key+"="+p.Value)
}
return b.String()
}
// Mode indicates how the policy should be interpreted.
type Mode string
// ../rfc/8461:655
const (
ModeEnforce Mode = "enforce" // Policy must be followed, i.e. deliveries must fail if a TLS connection cannot be made.
ModeTesting Mode = "testing" // In case TLS cannot be negotiated, plain SMTP can be used, but failures must be reported, e.g. with TLS-RPT.
ModeNone Mode = "none" // In case MTA-STS is not or no longer implemented.
)
// STSMX is an allowlisted MX host name/pattern.
// todo: find a way to name this just STSMX without getting duplicate names for "MX" in the sherpa api.
type STSMX struct {
// "*." wildcard, e.g. if a subdomain matches. A wildcard must match exactly one
// label. *.example.com matches mail.example.com, but not example.com, and not
// foor.bar.example.com.
Wildcard bool
Domain dns.Domain
}
// Policy is an MTA-STS policy as served at "https://mta-sts.<domain>/.well-known/mta-sts.txt".
type Policy struct {
Version string // "STSv1"
Mode Mode
MX []STSMX
MaxAgeSeconds int // How long this policy can be cached. Suggested values are in weeks or more.
Extensions []Pair
}
// String returns a textual representation for serving at the well-known URL.
func (p Policy) String() string {
b := &strings.Builder{}
line := func(k, v string) {
fmt.Fprint(b, k+": "+v+"\n")
}
line("version", p.Version)
line("mode", string(p.Mode))
line("max_age", fmt.Sprintf("%d", p.MaxAgeSeconds))
for _, mx := range p.MX {
s := mx.Domain.Name()
if mx.Wildcard {
s = "*." + s
}
line("mx", s)
}
return b.String()
}
// Matches returns whether the hostname matches the mx list in the policy.
func (p *Policy) Matches(host dns.Domain) bool {
// ../rfc/8461:636
for _, mx := range p.MX {
if mx.Wildcard {
v := strings.SplitN(host.ASCII, ".", 2)
if len(v) == 2 && v[1] == mx.Domain.ASCII {
return true
}
} else if host == mx.Domain {
return true
}
}
return false
}
// Lookup errors.
var (
ErrNoRecord = errors.New("mtasts: no mta-sts dns txt record") // Domain does not implement MTA-STS. If a cached non-expired policy is available, it should still be used.
ErrMultipleRecords = errors.New("mtasts: multiple mta-sts records") // Should be treated as if domain does not implement MTA-STS, unless a cached non-expired policy is available.
ErrDNS = errors.New("mtasts: dns lookup") // For temporary DNS errors.
ErrRecordSyntax = errors.New("mtasts: record syntax error")
)
// LookupRecord looks up the MTA-STS TXT DNS record at "_mta-sts.<domain>",
// following CNAME records, and returns the parsed MTA-STS record, the DNS TXT
// record and any CNAMEs that were followed.
func LookupRecord(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (rrecord *Record, rtxt string, rcnames []string, rerr error) {
log := xlog.WithContext(ctx)
start := time.Now()
defer func() {
log.Debugx("mtasts lookup result", rerr, mlog.Field("domain", domain), mlog.Field("record", rrecord), mlog.Field("cnames", rcnames), mlog.Field("duration", time.Since(start)))
}()
// ../rfc/8461:289
// ../rfc/8461:351
// We lookup the txt record, but must follow CNAME records when the TXT does not exist.
var cnames []string
name := "_mta-sts." + domain.ASCII + "."
var txts []string
for {
var err error
txts, err = dns.WithPackage(resolver, "mtasts").LookupTXT(ctx, name)
if dns.IsNotFound(err) {
// DNS has no specified limit on how many CNAMEs to follow. Chains of 10 CNAMEs
// have been seen on the internet.
if len(cnames) > 16 {
return nil, "", cnames, fmt.Errorf("too many cnames")
}
cname, err := dns.WithPackage(resolver, "mtasts").LookupCNAME(ctx, name)
if dns.IsNotFound(err) {
return nil, "", cnames, ErrNoRecord
}
if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err)
}
cnames = append(cnames, cname)
name = cname
continue
} else if err != nil {
return nil, "", cnames, fmt.Errorf("%w: %s", ErrDNS, err)
} else {
break
}
}
var text string
var record *Record
for _, txt := range txts {
r, ismtasts, err := ParseRecord(txt)
if !ismtasts {
// ../rfc/8461:331 says we should essentially treat a record starting with e.g.
// "v=STSv1 ;" (note the space) as a non-STS record too in case of multiple TXT
// records. We treat it as an STS record that is invalid, which is possibly more
// reasonable.
continue
}
if err != nil {
return nil, "", cnames, err
}
if record != nil {
return nil, "", cnames, ErrMultipleRecords
}
record = r
text = txt
}
if record == nil {
return nil, "", cnames, ErrNoRecord
}
return record, text, cnames, nil
}
// Policy fetch errors.
var (
ErrNoPolicy = errors.New("mtasts: no policy served") // If the name "mta-sts.<domain>" does not exist in DNS or if webserver returns HTTP status 404 "File not found".
ErrPolicyFetch = errors.New("mtasts: cannot fetch policy") // E.g. for HTTP request errors.
ErrPolicySyntax = errors.New("mtasts: policy syntax error")
)
// HTTPClient is used by FetchPolicy for HTTP requests.
var HTTPClient = &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return fmt.Errorf("redirect not allowed for MTA-STS policies") // ../rfc/8461:549
},
}
// FetchPolicy fetches a new policy for the domain, at
// https://mta-sts.<domain>/.well-known/mta-sts.txt.
//
// FetchPolicy returns the parsed policy and the literal policy text as fetched
// from the server. If a policy was fetched but could not be parsed, the policyText
// return value will be set.
//
// Policies longer than 64KB result in a syntax error.
//
// If an error is returned, callers should back off for 5 minutes until the next
// attempt.
func FetchPolicy(ctx context.Context, domain dns.Domain) (policy *Policy, policyText string, rerr error) {
log := xlog.WithContext(ctx)
start := time.Now()
defer func() {
log.Debugx("mtasts fetch policy result", rerr, mlog.Field("domain", domain), mlog.Field("policy", policy), mlog.Field("policytext", policyText), mlog.Field("duration", time.Since(start)))
}()
// Timeout of 1 minute. ../rfc/8461:569
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
// TLS requirements are what the Go standard library checks: trusted, non-expired,
// hostname validated against DNS-ID supporting wildcard. ../rfc/8461:524
url := "https://mta-sts." + domain.Name() + "/.well-known/mta-sts.txt"
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, "", fmt.Errorf("%w: http request: %s", ErrPolicyFetch, err)
}
// We are not likely to reuse a connection: we cache policies and negative DNS
// responses. So don't keep connections open unnecessarily.
req.Close = true
resp, err := HTTPClient.Do(req)
if dns.IsNotFound(err) {
return nil, "", ErrNoPolicy
}
if err != nil {
return nil, "", fmt.Errorf("%w: http get: %s", ErrPolicyFetch, err)
}
metrics.HTTPClientObserve(ctx, "mtasts", req.Method, resp.StatusCode, err, start)
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, "", ErrNoPolicy
}
if resp.StatusCode != http.StatusOK {
// ../rfc/8461:548
return nil, "", fmt.Errorf("%w: http status %s while status 200 is required", ErrPolicyFetch, resp.Status)
}
// We don't look at Content-Type and charset. It should be ASCII or UTF-8, we'll
// just always whatever is sent as UTF-8. ../rfc/8461:367
// ../rfc/8461:570
buf, err := io.ReadAll(&moxio.LimitReader{R: resp.Body, Limit: 64 * 1024})
if err != nil {
return nil, "", fmt.Errorf("%w: reading policy: %s", ErrPolicySyntax, err)
}
policyText = string(buf)
policy, err = ParsePolicy(policyText)
if err != nil {
return nil, policyText, fmt.Errorf("parsing policy: %w", err)
}
return policy, policyText, nil
}
// Get looks up the MTA-STS DNS record and fetches the policy.
//
// Errors can be those returned by LookupRecord and FetchPolicy.
//
// If a valid policy cannot be retrieved, a sender must treat the domain as not
// implementing MTA-STS. If a sender has a non-expired cached policy, that policy
// would still apply.
//
// If a record was retrieved, but a policy could not be retrieved/parsed, the
// record is still returned.
//
// Also see Get in package mtastsdb.
func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (record *Record, policy *Policy, err error) {
log := xlog.WithContext(ctx)
start := time.Now()
result := "lookuperror"
defer func() {
metricGet.WithLabelValues(result).Observe(float64(time.Since(start)) / float64(time.Second))
log.Debugx("mtasts get result", err, mlog.Field("domain", domain), mlog.Field("record", record), mlog.Field("policy", policy), mlog.Field("duration", time.Since(start)))
}()
record, _, _, err = LookupRecord(ctx, resolver, domain)
if err != nil {
return nil, nil, err
}
result = "fetcherror"
policy, _, err = FetchPolicy(ctx, domain)
if err != nil {
return record, nil, err
}
result = "ok"
return record, policy, nil
}

267
mtasts/mtasts_test.go Normal file
View File

@ -0,0 +1,267 @@
package mtasts
import (
"context"
"crypto/ed25519"
cryptorand "crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"log"
"math/big"
"net"
"net/http"
"reflect"
"strings"
"sync"
"testing"
"time"
"github.com/mjl-/mox/dns"
)
func TestLookup(t *testing.T) {
resolver := dns.MockResolver{
TXT: map[string][]string{
"_mta-sts.a.example.": {"v=STSv1; id=1"},
"_mta-sts.one.example.": {"v=STSv1; id=1", "bogus"},
"_mta-sts.bad.example.": {"v=STSv1; bogus"},
"_mta-sts.multiple.example.": {"v=STSv1; id=1", "v=STSv1; id=2"},
"_mta-sts.c.cnames.example.": {"v=STSv1; id=1"},
"_mta-sts.temperror.example.": {"v=STSv1; id=1"},
"_mta-sts.other.example.": {"bogus", "more"},
},
CNAME: map[string]string{
"_mta-sts.a.cnames.example.": "_mta-sts.b.cnames.example.",
"_mta-sts.b.cnames.example.": "_mta-sts.c.cnames.example.",
"_mta-sts.followtemperror.example.": "_mta-sts.cnametemperror.example.",
},
Fail: map[dns.Mockreq]struct{}{
{Type: "txt", Name: "_mta-sts.temperror.example."}: {},
{Type: "cname", Name: "_mta-sts.cnametemperror.example."}: {},
},
}
test := func(host string, expRecord *Record, expCNAMEs []string, expErr error) {
t.Helper()
record, _, cnames, err := LookupRecord(context.Background(), resolver, dns.Domain{ASCII: host})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("lookup: got err %#v, expected %#v", err, expErr)
}
if err != nil {
return
}
if !reflect.DeepEqual(record, expRecord) || !reflect.DeepEqual(cnames, expCNAMEs) {
t.Fatalf("lookup: got record %#v, cnames %#v, expected %#v %#v", record, cnames, expRecord, expCNAMEs)
}
}
test("absent.example", nil, nil, ErrNoRecord)
test("other.example", nil, nil, ErrNoRecord)
test("a.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
test("one.example", &Record{Version: "STSv1", ID: "1"}, nil, nil)
test("bad.example", nil, nil, ErrRecordSyntax)
test("multiple.example", nil, nil, ErrMultipleRecords)
test("a.cnames.example", &Record{Version: "STSv1", ID: "1"}, []string{"_mta-sts.b.cnames.example.", "_mta-sts.c.cnames.example."}, nil)
test("temperror.example", nil, nil, ErrDNS)
test("cnametemperror.example", nil, nil, ErrDNS)
test("followtemperror.example", nil, nil, ErrDNS)
}
func TestMatches(t *testing.T) {
p, err := ParsePolicy("version: STSv1\nmode: enforce\nmax_age: 1\nmx: a.example\nmx: *.b.example\n")
if err != nil {
t.Fatalf("parsing policy: %s", err)
}
mustParseDomain := func(s string) dns.Domain {
t.Helper()
d, err := dns.ParseDomain(s)
if err != nil {
t.Fatalf("parsing domain %q: %s", s, err)
}
return d
}
match := func(s string) {
t.Helper()
if !p.Matches(mustParseDomain(s)) {
t.Fatalf("unexpected mismatch for %q", s)
}
}
not := func(s string) {
t.Helper()
if p.Matches(mustParseDomain(s)) {
t.Fatalf("unexpected match for %q", s)
}
}
match("a.example")
match("sub.b.example")
not("b.example")
not("sub.sub.b.example")
not("other")
}
type pipeListener struct {
sync.Mutex
closed bool
C chan net.Conn
}
var _ net.Listener = &pipeListener{}
func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
func (l *pipeListener) Dial() (net.Conn, error) {
l.Lock()
defer l.Unlock()
if l.closed {
return nil, errors.New("closed")
}
c, s := net.Pipe()
l.C <- s
return c, nil
}
func (l *pipeListener) Accept() (net.Conn, error) {
conn := <-l.C
if conn == nil {
return nil, io.EOF
}
return conn, nil
}
func (l *pipeListener) Close() error {
l.Lock()
defer l.Unlock()
if !l.closed {
l.closed = true
close(l.C)
}
return nil
}
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
type pipeAddr struct{}
func (a pipeAddr) Network() string { return "pipe" }
func (a pipeAddr) String() string { return "pipe" }
func fakeCert(t *testing.T, expired bool) tls.Certificate {
notAfter := time.Now()
if expired {
notAfter = notAfter.Add(-time.Hour)
} else {
notAfter = notAfter.Add(time.Hour)
}
privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
template := &x509.Certificate{
SerialNumber: big.NewInt(1), // Required field...
DNSNames: []string{"mta-sts.mox.example"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: notAfter,
}
localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
if err != nil {
t.Fatalf("making certificate: %s", err)
}
cert, err := x509.ParseCertificate(localCertBuf)
if err != nil {
t.Fatalf("parsing generated certificate: %s", err)
}
c := tls.Certificate{
Certificate: [][]byte{localCertBuf},
PrivateKey: privKey,
Leaf: cert,
}
return c
}
func TestFetch(t *testing.T) {
certok := fakeCert(t, false)
certbad := fakeCert(t, true)
defer func() {
HTTPClient.Transport = nil
}()
resolver := dns.MockResolver{
TXT: map[string][]string{
"_mta-sts.mox.example.": {"v=STSv1; id=1"},
"_mta-sts.other.example.": {"v=STSv1; id=1"},
},
}
test := func(cert tls.Certificate, domain string, status int, policyText string, expPolicy *Policy, expErr error) {
t.Helper()
pool := x509.NewCertPool()
pool.AddCert(cert.Leaf)
l := newPipeListener()
defer l.Close()
go func() {
mux := &http.ServeMux{}
mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Location", "/other") // Ignored except for redirect.
w.WriteHeader(status)
w.Write([]byte(policyText))
})
s := &http.Server{
Handler: mux,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
ErrorLog: log.New(io.Discard, "", 0),
}
s.ServeTLS(l, "", "")
}()
HTTPClient.Transport = &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
if strings.HasPrefix(addr, "mta-sts.doesnotexist.example") {
return nil, &net.DNSError{IsNotFound: true}
}
return l.Dial()
},
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
}
p, _, err := FetchPolicy(context.Background(), dns.Domain{ASCII: domain})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("policy: got err %#v, expected %#v", err, expErr)
}
if err == nil && !reflect.DeepEqual(p, expPolicy) {
t.Fatalf("policy: got %#v, expected %#v", p, expPolicy)
}
if domain == "doesnotexist.example" {
expErr = ErrNoRecord
}
_, p, err = Get(context.Background(), resolver, dns.Domain{ASCII: domain})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("get: got err %#v, expected %#v", err, expErr)
}
if err == nil && !reflect.DeepEqual(p, expPolicy) {
t.Fatalf("get: got %#v, expected %#v", p, expPolicy)
}
}
test(certok, "mox.example", 200, "bogus", nil, ErrPolicySyntax)
test(certok, "other.example", 200, "bogus", nil, ErrPolicyFetch)
test(certbad, "mox.example", 200, "bogus", nil, ErrPolicyFetch)
test(certok, "mox.example", 404, "bogus", nil, ErrNoPolicy)
test(certok, "doesnotexist.example", 200, "bogus", nil, ErrNoPolicy)
test(certok, "mox.example", 301, "bogus", nil, ErrPolicyFetch)
test(certok, "mox.example", 500, "bogus", nil, ErrPolicyFetch)
large := make([]byte, 64*1024+2)
test(certok, "mox.example", 200, string(large), nil, ErrPolicySyntax)
validPolicy := "version:STSv1\nmode:none\nmax_age:1"
test(certok, "mox.example", 200, validPolicy, &Policy{Version: "STSv1", Mode: "none", MaxAgeSeconds: 1}, nil)
}

347
mtasts/parse.go Normal file
View File

@ -0,0 +1,347 @@
package mtasts
import (
"fmt"
"strconv"
"strings"
"github.com/mjl-/mox/dns"
)
type parseErr string
func (e parseErr) Error() string {
return string(e)
}
var _ error = parseErr("")
// ParseRecord parses an MTA-STS record.
func ParseRecord(txt string) (record *Record, ismtasts bool, err error) {
defer func() {
x := recover()
if x == nil {
return
}
if xerr, ok := x.(parseErr); ok {
record = nil
err = fmt.Errorf("%w: %s", ErrRecordSyntax, xerr)
return
}
panic(x)
}()
// Parsing is mostly case-sensitive.
// ../rfc/8461:306
p := newParser(txt)
record = &Record{
Version: "STSv1",
}
seen := map[string]struct{}{}
p.xtake("v=STSv1")
p.xdelim()
ismtasts = true
for {
k := p.xkey()
p.xtake("=")
// Section 3.1 about the TXT record does not say anything about duplicate fields.
// But section 3.2 about (parsing) policies has a paragraph that starts
// requirements on both TXT and policy records. That paragraph ends with a note
// about handling duplicate fields. Let's assume that note also applies to TXT
// records. ../rfc/8461:517
_, dup := seen[k]
seen[k] = struct{}{}
switch k {
case "id":
if !dup {
record.ID = p.xid()
}
default:
v := p.xvalue()
record.Extensions = append(record.Extensions, Pair{k, v})
}
if !p.delim() || p.empty() {
break
}
}
if !p.empty() {
p.xerrorf("leftover characters")
}
if record.ID == "" {
p.xerrorf("missing id")
}
return
}
// ParsePolicy parses an MTA-STS policy.
func ParsePolicy(s string) (policy *Policy, err error) {
defer func() {
x := recover()
if x == nil {
return
}
if xerr, ok := x.(parseErr); ok {
policy = nil
err = fmt.Errorf("%w: %s", ErrPolicySyntax, xerr)
return
}
panic(x)
}()
// ../rfc/8461:426
p := newParser(s)
policy = &Policy{
Version: "STSv1",
}
seen := map[string]struct{}{}
for {
k := p.xkey()
// For fields except "mx", only the first must be used. ../rfc/8461:517
_, dup := seen[k]
seen[k] = struct{}{}
p.xtake(":")
p.wsp()
switch k {
case "version":
policy.Version = p.xtake("STSv1")
case "mode":
mode := Mode(p.xtakelist("testing", "enforce", "none"))
if !dup {
policy.Mode = mode
}
case "max_age":
maxage := p.xmaxage()
if !dup {
policy.MaxAgeSeconds = maxage
}
case "mx":
policy.MX = append(policy.MX, p.xmx())
default:
v := p.xpolicyvalue()
policy.Extensions = append(policy.Extensions, Pair{k, v})
}
p.wsp()
if !p.eol() || p.empty() {
break
}
}
if !p.empty() {
p.xerrorf("leftover characters")
}
required := []string{"version", "mode", "max_age"}
for _, req := range required {
if _, ok := seen[req]; !ok {
p.xerrorf("missing field %q", req)
}
}
if _, ok := seen["mx"]; !ok && policy.Mode != ModeNone {
// ../rfc/8461:437
p.xerrorf("missing mx given mode")
}
return
}
type parser struct {
s string
o int
}
func newParser(s string) *parser {
return &parser{s: s}
}
func (p *parser) xerrorf(format string, args ...any) {
msg := fmt.Sprintf(format, args...)
if p.o < len(p.s) {
msg += fmt.Sprintf(" (remain %q)", p.s[p.o:])
}
panic(parseErr(msg))
}
func (p *parser) xtake(s string) string {
if !p.prefix(s) {
p.xerrorf("expected %q", s)
}
p.o += len(s)
return s
}
func (p *parser) xdelim() {
if !p.delim() {
p.xerrorf("expected semicolon")
}
}
func (p *parser) xtaken(n int) string {
r := p.s[p.o : p.o+n]
p.o += n
return r
}
func (p *parser) xtakefn1(fn func(rune, int) bool) string {
for i, b := range p.s[p.o:] {
if !fn(b, i) {
if i == 0 {
p.xerrorf("expected at least one char")
}
return p.xtaken(i)
}
}
if p.empty() {
p.xerrorf("expected at least 1 char")
}
return p.xtaken(len(p.s) - p.o)
}
func (p *parser) prefix(s string) bool {
return strings.HasPrefix(p.s[p.o:], s)
}
// File name, the known values match this syntax.
// ../rfc/8461:482
func (p *parser) xkey() string {
return p.xtakefn1(func(b rune, i int) bool {
return i < 32 && (b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z' || b >= '0' && b <= '9' || (i > 0 && b == '_' || b == '-' || b == '.'))
})
}
// ../rfc/8461:319
func (p *parser) xid() string {
return p.xtakefn1(func(b rune, i int) bool {
return i < 32 && (b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z' || b >= '0' && b <= '9')
})
}
// ../rfc/8461:326
func (p *parser) xvalue() string {
return p.xtakefn1(func(b rune, i int) bool {
return b > ' ' && b < 0x7f && b != '=' && b != ';'
})
}
// ../rfc/8461:315
func (p *parser) delim() bool {
o := p.o
e := len(p.s)
for o < e && (p.s[o] == ' ' || p.s[o] == '\t') {
o++
}
if o >= e || p.s[o] != ';' {
return false
}
o++
for o < e && (p.s[o] == ' ' || p.s[o] == '\t') {
o++
}
p.o = o
return true
}
func (p *parser) empty() bool {
return p.o >= len(p.s)
}
// ../rfc/8461:485
func (p *parser) eol() bool {
return p.take("\n") || p.take("\r\n")
}
func (p *parser) xtakelist(l ...string) string {
for _, s := range l {
if p.prefix(s) {
return p.xtaken(len(s))
}
}
p.xerrorf("expected one of %s", strings.Join(l, ", "))
return "" // not reached
}
// ../rfc/8461:476
func (p *parser) xmaxage() int {
digits := p.xtakefn1(func(b rune, i int) bool {
return b >= '0' && b <= '9' && i < 10
})
v, err := strconv.ParseInt(digits, 10, 32)
if err != nil {
p.xerrorf("parsing int: %s", err)
}
return int(v)
}
func (p *parser) take(s string) bool {
if p.prefix(s) {
p.o += len(s)
return true
}
return false
}
// ../rfc/8461:469
func (p *parser) xmx() (mx STSMX) {
if p.prefix("*.") {
mx.Wildcard = true
p.o += 2
}
mx.Domain = p.xdomain()
return mx
}
// ../rfc/5321:2291
func (p *parser) xdomain() dns.Domain {
s := p.xsubdomain()
for p.take(".") {
s += "." + p.xsubdomain()
}
d, err := dns.ParseDomain(s)
if err != nil {
p.xerrorf("parsing domain %q: %s", s, err)
}
return d
}
// ../rfc/8461:487
func (p *parser) xsubdomain() string {
// note: utf-8 is valid, but U-labels are explicitly not allowed. ../rfc/8461:411 ../rfc/5321:2303
unicode := false
s := p.xtakefn1(func(c rune, i int) bool {
if c > 0x7f {
unicode = true
}
return c >= '0' && c <= '9' || c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || (i > 0 && c == '-') || c > 0x7f
})
if unicode {
p.xerrorf("domain must be specified in A labels, not U labels (unicode)")
}
return s
}
// ../rfc/8461:487
func (p *parser) xpolicyvalue() string {
e := len(p.s)
for i, c := range p.s[p.o:] {
if c > ' ' && c < 0x7f || c >= 0x80 || (c == ' ' && i > 0) {
continue
}
e = p.o + i
break
}
// Walk back on trailing spaces.
for e > p.o && p.s[e-1] == ' ' {
e--
}
n := e - p.o
if n <= 0 {
p.xerrorf("empty extension value")
}
return p.xtaken(n)
}
// "*WSP"
func (p *parser) wsp() {
n := len(p.s)
for p.o < n && (p.s[p.o] == ' ' || p.s[p.o] == '\t') {
p.o++
}
}

237
mtasts/parse_test.go Normal file
View File

@ -0,0 +1,237 @@
package mtasts
import (
"reflect"
"testing"
"github.com/mjl-/mox/dns"
)
func TestRecord(t *testing.T) {
good := func(txt string, want Record) {
t.Helper()
r, _, err := ParseRecord(txt)
if err != nil {
t.Fatalf("parse: %s", err)
}
if !reflect.DeepEqual(r, &want) {
t.Fatalf("want %#v, got %#v", want, *r)
}
}
bad := func(txt string) {
t.Helper()
r, _, err := ParseRecord(txt)
if err == nil {
t.Fatalf("parse, expected error, got record %v", r)
}
}
good("v=STSv1; id=20160831085700Z;", Record{Version: "STSv1", ID: "20160831085700Z"})
good("v=STSv1; \t id=20160831085700Z \t;", Record{Version: "STSv1", ID: "20160831085700Z"})
good("v=STSv1; id=a", Record{Version: "STSv1", ID: "a"})
good("v=STSv1; id=a; more=a; ext=2", Record{Version: "STSv1", ID: "a", Extensions: []Pair{{"more", "a"}, {"ext", "2"}}})
bad("v=STSv0")
bad("v=STSv10")
bad("v=STSv2")
bad("v=STSv1") // missing id
bad("v=STSv1;") // missing id
bad("v=STSv1; ext=1") // missing id
bad("v=STSv1; id=") // empty id
bad("v=STSv1; id=012345678901234567890123456789012") // id too long
bad("v=STSv1; id=test-123") // invalid id
bad("v=STSv1; id=a; more=") // empty value in extension
bad("v=STSv1; id=a; a12345678901234567890123456789012=1") // extension name too long
bad("v=STSv1; id=a; 1%=a") // invalid extension name
bad("v=STSv1; id=a; test==") // invalid extension name
bad("v=STSv1; id=a;;") // additional semicolon
const want = `v=STSv1; id=a; more=a; ext=2`
record := Record{Version: "STSv1", ID: "a", Extensions: []Pair{{"more", "a"}, {"ext", "2"}}}
got := record.String()
if got != want {
t.Fatalf("record string, got %q, want %q", got, want)
}
}
func TestParsePolicy(t *testing.T) {
good := func(s string, want Policy) {
t.Helper()
p, err := ParsePolicy(s)
if err != nil {
t.Fatalf("parse policy: %s", err)
}
if !reflect.DeepEqual(p, &want) {
t.Fatalf("want %v, got %v", want, p)
}
}
good(`version: STSv1
mode: testing
mx: mx1.example.com
mx: mx2.example.com
mx: mx.backup-example.com
max_age: 1296000
`,
Policy{
Version: "STSv1",
Mode: ModeTesting,
MX: []STSMX{
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
{Domain: dns.Domain{ASCII: "mx2.example.com"}},
{Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
},
MaxAgeSeconds: 1296000,
},
)
good("version: STSv1\nmode: enforce \nmx: *.example.com \nmax_age: 0 \n",
Policy{
Version: "STSv1",
Mode: ModeEnforce,
MX: []STSMX{
{Wildcard: true, Domain: dns.Domain{ASCII: "example.com"}},
},
MaxAgeSeconds: 0,
},
)
good("version:STSv1\r\nmode:\tenforce\r\nmx: \t\t *.example.com\nmax_age: 1\nmore:ext e ns ion",
Policy{
Version: "STSv1",
Mode: ModeEnforce,
MX: []STSMX{
{Wildcard: true, Domain: dns.Domain{ASCII: "example.com"}},
},
MaxAgeSeconds: 1,
Extensions: []Pair{{"more", "ext e ns ion"}},
},
)
bad := func(s string) {
t.Helper()
p, err := ParsePolicy(s)
if err == nil {
t.Fatalf("parsing policy did not fail: %v", p)
}
}
bad("") // missing version
bad("version:STSv0\nmode:none\nmax_age:0") // bad version
bad("version:STSv10\nmode:none\nmax_age:0") // bad version
bad("version:STSv2\nmode:none\nmax_age:0") // bad version
bad("version:STSv1\nmax_age:0\nmx:example.com") // missing mode
bad("version:STSv1\nmode:none") // missing max_age
bad("version:STSv1\nmax_age:0\nmode:enforce") // missing mx for mode
bad("version:STSv1\nmax_age:0\nmode:testing") // missing mx for mode
bad("max_age:0\nmode:none") // missing version
bad("version:STSv1\nmode:none\nmax_age:01234567890") // max_age too long
bad("version:STSv1\nmode:bad\nmax_age:1") // bad mode
bad("version:STSv1\nmode:none\nmax_age:a") // bad max_age
bad("version:STSv1\nmode:enforce\nmax_age:0\nmx:") // missing value
bad("version:STSv1\nmode:enforce\nmax_age:0\nmx:*.*.example") // bad mx
bad("version:STSv1\nmode:enforce\nmax_age:0\nmx:**.example") // bad mx
bad("version:STSv1\nmode:enforce\nmax_age:0\nmx:**.example-") // bad mx
bad("version:STSv1\nmode:enforce\nmax_age:0\nmx:test.example-") // bad mx
bad("version:STSv1\nmode:none\nmax_age:0\next:") // empty extension
bad("version:STSv1\nmode:none\nmax_age:0\na12345678901234567890123456789012:123") // long extension name
bad("version:STSv1\nmode:none\nmax_age:0\n_bad:test") // bad ext name
bad("version:STSv1\nmode:none\nmax_age:0\nmx: møx.example") // invalid u-label in mx
policy := Policy{
Version: "STSv1",
Mode: ModeTesting,
MX: []STSMX{
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
{Domain: dns.Domain{ASCII: "mx2.example.com"}},
{Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
},
MaxAgeSeconds: 1296000,
}
want := `version: STSv1
mode: testing
max_age: 1296000
mx: mx1.example.com
mx: mx2.example.com
mx: mx.backup-example.com
`
got := policy.String()
if got != want {
t.Fatalf("policy string, got %q, want %q", got, want)
}
}
func FuzzParseRecord(f *testing.F) {
f.Add("v=STSv1; id=20160831085700Z;")
f.Add("v=STSv1; \t id=20160831085700Z \t;")
f.Add("v=STSv1; id=a")
f.Add("v=STSv1; id=a; more=a; ext=2")
f.Add("v=STSv0")
f.Add("v=STSv10")
f.Add("v=STSv2")
f.Add("v=STSv1") // missing id
f.Add("v=STSv1;") // missing id
f.Add("v=STSv1; ext=1") // missing id
f.Add("v=STSv1; id=") // empty id
f.Add("v=STSv1; id=012345678901234567890123456789012") // id too long
f.Add("v=STSv1; id=test-123") // invalid id
f.Add("v=STSv1; id=a; more=") // empty value in extension
f.Add("v=STSv1; id=a; a12345678901234567890123456789012=1") // extension name too long
f.Add("v=STSv1; id=a; 1%=a") // invalid extension name
f.Add("v=STSv1; id=a; test==") // invalid extension name
f.Add("v=STSv1; id=a;;") // additional semicolon
f.Fuzz(func(t *testing.T, s string) {
r, _, err := ParseRecord(s)
if err == nil {
_ = r.String()
}
})
}
func FuzzParsePolicy(f *testing.F) {
f.Add(`version: STSv1
mode: testing
mx: mx1.example.com
mx: mx2.example.com
mx: mx.backup-example.com
max_age: 1296000
`)
f.Add(`version: STSv1
mode: enforce
mx: *.example.com
max_age: 0
`)
f.Add("version:STSv1\r\nmode:\tenforce\r\nmx: \t\t *.example.com\nmax_age: 1\nmore:ext e ns ion")
f.Add("") // missing version
f.Add("version:STSv0\nmode:none\nmax_age:0") // bad version
f.Add("version:STSv10\nmode:none\nmax_age:0") // bad version
f.Add("version:STSv2\nmode:none\nmax_age:0") // bad version
f.Add("version:STSv1\nmax_age:0\nmx:example.com") // missing mode
f.Add("version:STSv1\nmode:none") // missing max_age
f.Add("version:STSv1\nmax_age:0\nmode:enforce") // missing mx for mode
f.Add("version:STSv1\nmax_age:0\nmode:testing") // missing mx for mode
f.Add("max_age:0\nmode:none") // missing version
f.Add("version:STSv1\nmode:none\nmax_age:0 ") // trailing whitespace
f.Add("version:STSv1\nmode:none\nmax_age:01234567890") // max_age too long
f.Add("version:STSv1\nmode:bad\nmax_age:1") // bad mode
f.Add("version:STSv1\nmode:none\nmax_age:a") // bad max_age
f.Add("version:STSv1\nmode:enforce\nmax_age:0\nmx:") // missing value
f.Add("version:STSv1\nmode:enforce\nmax_age:0\nmx:*.*.example") // bad mx
f.Add("version:STSv1\nmode:enforce\nmax_age:0\nmx:**.example") // bad mx
f.Add("version:STSv1\nmode:enforce\nmax_age:0\nmx:**.example-") // bad mx
f.Add("version:STSv1\nmode:enforce\nmax_age:0\nmx:test.example-") // bad mx
f.Add("version:STSv1\nmode:none\nmax_age:0\next:") // empty extension
f.Add("version:STSv1\nmode:none\nmax_age:0\next:abc ") // trailing space
f.Add("version:STSv1\nmode:none\nmax_age:0\next:a\t") // invalid char
f.Add("version:STSv1\nmode:none\nmax_age:0\na12345678901234567890123456789012:123") // long extension name
f.Add("version:STSv1\nmode:none\nmax_age:0\n_bad:test") // bad ext name
f.Fuzz(func(t *testing.T, s string) {
r, err := ParsePolicy(s)
if err == nil {
_ = r.String()
}
})
}