mirror of
https://github.com/mjl-/mox.git
synced 2025-07-12 17:44:35 +03:00
mox!
This commit is contained in:
285
mtastsdb/db.go
Normal file
285
mtastsdb/db.go
Normal file
@ -0,0 +1,285 @@
|
||||
// Package mtastsdb stores MTA-STS policies for later use.
|
||||
//
|
||||
// An MTA-STS policy can specify how long it may be cached. By storing a
|
||||
// policy, it does not have to be fetched again during email delivery, which
|
||||
// makes it harder for attackers to intervene.
|
||||
package mtastsdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/mjl-/bstore"
|
||||
|
||||
"github.com/mjl-/mox/dns"
|
||||
"github.com/mjl-/mox/mlog"
|
||||
"github.com/mjl-/mox/mox-"
|
||||
"github.com/mjl-/mox/mtasts"
|
||||
)
|
||||
|
||||
var xlog = mlog.New("mtastsdb")
|
||||
|
||||
var (
|
||||
metricGet = promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "mox_mtastsdb_get_total",
|
||||
Help: "Number of Get by result.",
|
||||
},
|
||||
[]string{"result"},
|
||||
)
|
||||
)
|
||||
|
||||
var timeNow = time.Now // Tests override this.
|
||||
|
||||
// PolicyRecord is a cached policy or absence of a policy.
|
||||
type PolicyRecord struct {
|
||||
Domain string // Domain name, with unicode characters.
|
||||
Inserted time.Time `bstore:"default now"`
|
||||
ValidEnd time.Time
|
||||
LastUpdate time.Time // Policies are refreshed on use and periodically.
|
||||
LastUse time.Time `bstore:"index"`
|
||||
Backoff bool
|
||||
RecordID string // As retrieved from DNS.
|
||||
mtasts.Policy // As retrieved from the well-known HTTPS url.
|
||||
}
|
||||
|
||||
var (
|
||||
// No valid non-expired policy in database.
|
||||
ErrNotFound = errors.New("mtastsdb: policy not found")
|
||||
|
||||
// Indicates an MTA-STS TXT record was fetched recently, but fetching the policy
|
||||
// failed and should not yet be retried.
|
||||
ErrBackoff = errors.New("mtastsdb: policy fetch failed recently")
|
||||
)
|
||||
|
||||
var mtastsDB *bstore.DB
|
||||
var mutex sync.Mutex
|
||||
|
||||
func database() (rdb *bstore.DB, rerr error) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
if mtastsDB == nil {
|
||||
p := mox.DataDirPath("mtasts.db")
|
||||
os.MkdirAll(filepath.Dir(p), 0770)
|
||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mtastsDB = db
|
||||
}
|
||||
return mtastsDB, nil
|
||||
}
|
||||
|
||||
// Init opens the database and starts a goroutine that refreshes policies in
|
||||
// the database, and keeps doing so periodically.
|
||||
func Init(refresher bool) error {
|
||||
_, err := database()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if refresher {
|
||||
// todo: allow us to shut down cleanly?
|
||||
go refresh()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the database.
|
||||
func Close() {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
if mtastsDB != nil {
|
||||
mtastsDB.Close()
|
||||
mtastsDB = nil
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup looks up a policy for the domain in the database.
|
||||
//
|
||||
// Only non-expired records are returned.
|
||||
func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||
log := xlog.WithContext(ctx)
|
||||
db, err := database()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if domain.IsZero() {
|
||||
return nil, fmt.Errorf("empty domain")
|
||||
}
|
||||
now := timeNow()
|
||||
q := bstore.QueryDB[PolicyRecord](db)
|
||||
q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
|
||||
q.FilterGreater("ValidEnd", now)
|
||||
pr, err := q.Get()
|
||||
if err == bstore.ErrAbsent {
|
||||
return nil, ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pr.LastUse = now
|
||||
if err := db.Update(&pr); err != nil {
|
||||
log.Errorx("marking cached mta-sts policy as used in database", err)
|
||||
}
|
||||
if pr.Backoff {
|
||||
return nil, ErrBackoff
|
||||
}
|
||||
return &pr, nil
|
||||
}
|
||||
|
||||
// Upsert adds the policy to the database, overwriting an existing policy for the domain.
|
||||
// Policy can be nil, indicating a failure to fetch the policy.
|
||||
func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||
db, err := database()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Write(func(tx *bstore.Tx) error {
|
||||
pr := PolicyRecord{Domain: domain.Name()}
|
||||
err := tx.Get(&pr)
|
||||
if err != nil && err != bstore.ErrAbsent {
|
||||
return err
|
||||
}
|
||||
|
||||
now := timeNow()
|
||||
|
||||
var p mtasts.Policy
|
||||
if policy != nil {
|
||||
p = *policy
|
||||
} else {
|
||||
// ../rfc/8461:552
|
||||
p.Mode = mtasts.ModeNone
|
||||
p.MaxAgeSeconds = 5 * 60
|
||||
}
|
||||
backoff := policy == nil
|
||||
validEnd := now.Add(time.Duration(p.MaxAgeSeconds) * time.Second)
|
||||
|
||||
if err == bstore.ErrAbsent {
|
||||
pr = PolicyRecord{domain.Name(), now, validEnd, now, now, backoff, recordID, p}
|
||||
return tx.Insert(&pr)
|
||||
}
|
||||
|
||||
pr.ValidEnd = validEnd
|
||||
pr.LastUpdate = now
|
||||
pr.LastUse = now
|
||||
pr.Backoff = backoff
|
||||
pr.RecordID = recordID
|
||||
pr.Policy = p
|
||||
return tx.Update(&pr)
|
||||
})
|
||||
}
|
||||
|
||||
// PolicyRecords returns all policies in the database, sorted descending by last
|
||||
// use, domain.
|
||||
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
|
||||
db, err := database()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bstore.QueryDB[PolicyRecord](db).SortDesc("LastUse", "Domain").List()
|
||||
}
|
||||
|
||||
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
|
||||
//
|
||||
// If an error is returned, it should be considered a transient error, e.g. a
|
||||
// temporary DNS lookup failure.
|
||||
//
|
||||
// The returned policy can be nil also when there is no error. In this case, the
|
||||
// domain does not implement MTA-STS.
|
||||
//
|
||||
// If a policy is present in the local database, it is refreshed if needed. If no
|
||||
// policy is present for the domain, an attempt is made to fetch the policy and
|
||||
// store it in the local database.
|
||||
//
|
||||
// Some errors are logged but not otherwise returned, e.g. if a new policy is
|
||||
// supposedly published but could not be retrieved.
|
||||
func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy *mtasts.Policy, fresh bool, err error) {
|
||||
log := xlog.WithContext(ctx)
|
||||
defer func() {
|
||||
result := "ok"
|
||||
if err != nil && errors.Is(err, ErrBackoff) {
|
||||
result = "backoff"
|
||||
} else if err != nil && errors.Is(err, ErrNotFound) {
|
||||
result = "notfound"
|
||||
} else if err != nil {
|
||||
result = "error"
|
||||
}
|
||||
metricGet.WithLabelValues(result).Inc()
|
||||
log.Debugx("mtastsdb get result", err, mlog.Field("domain", domain), mlog.Field("fresh", fresh))
|
||||
}()
|
||||
|
||||
cachedPolicy, err := lookup(ctx, domain)
|
||||
if err != nil && errors.Is(err, ErrNotFound) {
|
||||
// We don't have a policy for this domain, not even a record that we tried recently
|
||||
// and should backoff. So attempt to fetch policy.
|
||||
nctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
record, p, err := mtasts.Get(nctx, resolver, domain)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, mtasts.ErrNoRecord) || errors.Is(err, mtasts.ErrMultipleRecords) || errors.Is(err, mtasts.ErrRecordSyntax) || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax):
|
||||
// Remote is not doing MTA-STS, continue below. ../rfc/8461:333 ../rfc/8461:574
|
||||
default:
|
||||
// Interpret as temporary error, e.g. mtasts.ErrDNS, try again later.
|
||||
return nil, false, fmt.Errorf("lookup up mta-sts policy: %w", err)
|
||||
}
|
||||
}
|
||||
// Insert policy into database. If we could not fetch the policy itself, we back
|
||||
// off for 5 minutes. ../rfc/8461:555
|
||||
if err == nil || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax) {
|
||||
var recordID string
|
||||
if record != nil {
|
||||
recordID = record.ID
|
||||
}
|
||||
if err := Upsert(domain, recordID, p); err != nil {
|
||||
log.Errorx("inserting policy into cache, continuing", err)
|
||||
}
|
||||
}
|
||||
return p, true, nil
|
||||
} else if err != nil && errors.Is(err, ErrBackoff) {
|
||||
// ../rfc/8461:552
|
||||
// We recently failed to fetch a policy, act as if MTA-STS is not implemented.
|
||||
return nil, false, nil
|
||||
} else if err != nil {
|
||||
return nil, false, fmt.Errorf("looking up mta-sts policy in cache: %w", err)
|
||||
}
|
||||
|
||||
// Policy was found in database. Check in DNS it is still fresh.
|
||||
policy = &cachedPolicy.Policy
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, domain)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mtasts.ErrNoRecord) {
|
||||
// Could be a temporary DNS or configuration error.
|
||||
log.Errorx("checking for freshness of cached mta-sts dns txt record for domain, continuing with previously cached policy", err)
|
||||
}
|
||||
return policy, false, nil
|
||||
} else if record.ID == cachedPolicy.RecordID {
|
||||
return policy, true, nil
|
||||
}
|
||||
// New policy should be available.
|
||||
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
p, _, err := mtasts.FetchPolicy(ctx, domain)
|
||||
if err != nil {
|
||||
log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)
|
||||
return policy, false, nil
|
||||
}
|
||||
if err := Upsert(domain, record.ID, p); err != nil {
|
||||
log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
|
||||
}
|
||||
return p, true, nil
|
||||
}
|
158
mtastsdb/db_test.go
Normal file
158
mtastsdb/db_test.go
Normal file
@ -0,0 +1,158 @@
|
||||
package mtastsdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mjl-/mox/dns"
|
||||
"github.com/mjl-/mox/mox-"
|
||||
"github.com/mjl-/mox/mtasts"
|
||||
)
|
||||
|
||||
func tcheckf(t *testing.T, err error, format string, args ...any) {
|
||||
if err != nil {
|
||||
t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||
mox.Conf.Static.DataDir = "."
|
||||
|
||||
dbpath := mox.DataDirPath("mtasts.db")
|
||||
os.MkdirAll(filepath.Dir(dbpath), 0770)
|
||||
os.Remove(dbpath)
|
||||
defer os.Remove(dbpath)
|
||||
|
||||
if err := Init(false); err != nil {
|
||||
t.Fatalf("init database: %s", err)
|
||||
}
|
||||
defer Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock time.
|
||||
now := time.Now().Round(0)
|
||||
timeNow = func() time.Time { return now }
|
||||
defer func() { timeNow = time.Now }()
|
||||
|
||||
if p, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
|
||||
t.Fatalf("expected not found, got %v, %#v", err, p)
|
||||
}
|
||||
|
||||
policy1 := mtasts.Policy{
|
||||
Version: "STSv1",
|
||||
Mode: mtasts.ModeTesting,
|
||||
MX: []mtasts.STSMX{
|
||||
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
|
||||
{Domain: dns.Domain{ASCII: "mx2.example.com"}},
|
||||
{Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
|
||||
},
|
||||
MaxAgeSeconds: 1296000,
|
||||
}
|
||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
t.Fatalf("lookup after insert: %s", err)
|
||||
} else if !reflect.DeepEqual(got.Policy, policy1) {
|
||||
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
|
||||
}
|
||||
|
||||
policy2 := mtasts.Policy{
|
||||
Version: "STSv1",
|
||||
Mode: mtasts.ModeEnforce,
|
||||
MX: []mtasts.STSMX{
|
||||
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
|
||||
},
|
||||
MaxAgeSeconds: 360000,
|
||||
}
|
||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
t.Fatalf("lookup after insert: %s", err)
|
||||
} else if !reflect.DeepEqual(got.Policy, policy2) {
|
||||
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
|
||||
}
|
||||
|
||||
// Check if database holds expected record.
|
||||
records, err := PolicyRecords(context.Background())
|
||||
tcheckf(t, err, "policyrecords")
|
||||
expRecords := []PolicyRecord{
|
||||
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
|
||||
}
|
||||
records[0].Policy = mtasts.Policy{}
|
||||
expRecords[0].Policy = mtasts.Policy{}
|
||||
if !reflect.DeepEqual(records, expRecords) {
|
||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||
}
|
||||
|
||||
if err := Upsert(dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
records, err = PolicyRecords(context.Background())
|
||||
tcheckf(t, err, "policyrecords")
|
||||
expRecords = []PolicyRecord{
|
||||
{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
|
||||
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
|
||||
}
|
||||
if !reflect.DeepEqual(records, expRecords) {
|
||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||
}
|
||||
|
||||
if _, err := lookup(context.Background(), dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
|
||||
t.Fatalf("got %#v, expected ErrBackoff", err)
|
||||
}
|
||||
|
||||
resolver := dns.MockResolver{
|
||||
TXT: map[string][]string{
|
||||
"_mta-sts.example.com.": {"v=STSv1; id=124"},
|
||||
"_mta-sts.other.example.com.": {"v=STSv1; id=1"},
|
||||
"_mta-sts.temperror.example.com.": {""},
|
||||
},
|
||||
Fail: map[dns.Mockreq]struct{}{
|
||||
{Type: "txt", Name: "_mta-sts.temperror.example.com."}: {},
|
||||
},
|
||||
}
|
||||
|
||||
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
|
||||
t.Helper()
|
||||
p, fresh, err := Get(context.Background(), resolver, dns.Domain{ASCII: domain})
|
||||
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||||
t.Fatalf("got err %v, expected %v", err, expErr)
|
||||
}
|
||||
if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
|
||||
t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
|
||||
}
|
||||
}
|
||||
|
||||
testGet("example.com", &policy2, true, nil)
|
||||
testGet("other.example.com", nil, false, nil) // Back off, already in database.
|
||||
testGet("absent.example.com", nil, true, nil) // No MTA-STS.
|
||||
testGet("temperror.example.com", nil, false, mtasts.ErrDNS)
|
||||
|
||||
// Force refetch of policy, that will fail.
|
||||
mtasts.HTTPClient.Transport = &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return nil, fmt.Errorf("bad")
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
mtasts.HTTPClient.Transport = nil
|
||||
}()
|
||||
resolver.TXT["_mta-sts.example.com."] = []string{"v=STSv1; id=125"}
|
||||
testGet("example.com", &policy2, false, nil)
|
||||
|
||||
// Cached policy but no longer a DNS record.
|
||||
delete(resolver.TXT, "_mta-sts.example.com.")
|
||||
testGet("example.com", &policy2, false, nil)
|
||||
}
|
176
mtastsdb/refresh.go
Normal file
176
mtastsdb/refresh.go
Normal file
@ -0,0 +1,176 @@
|
||||
package mtastsdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
mathrand "math/rand"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/mjl-/bstore"
|
||||
|
||||
"github.com/mjl-/mox/dns"
|
||||
"github.com/mjl-/mox/metrics"
|
||||
"github.com/mjl-/mox/mlog"
|
||||
"github.com/mjl-/mox/mox-"
|
||||
"github.com/mjl-/mox/mtasts"
|
||||
)
|
||||
|
||||
func refresh() int {
|
||||
interval := 24 * time.Hour
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
var refreshed int
|
||||
|
||||
// Pro-actively refresh policies every 24 hours. ../rfc/8461:583
|
||||
for {
|
||||
ticker.Reset(interval)
|
||||
|
||||
ctx := context.WithValue(mox.Context, mlog.CidKey, mox.Cid())
|
||||
n, err := refresh1(ctx, dns.StrictResolver{Pkg: "mtastsdb"}, time.Sleep)
|
||||
if err != nil {
|
||||
xlog.WithContext(ctx).Errorx("periodic refresh of cached mtasts policies", err)
|
||||
}
|
||||
if n > 0 {
|
||||
refreshed += n
|
||||
}
|
||||
|
||||
select {
|
||||
case <-mox.Shutdown:
|
||||
return refreshed
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refresh policies that have not been updated in the past 12 hours and remove
|
||||
// policies not used for 180 days. We start with the first domain immediately, so
|
||||
// an admin can see any (configuration) issues that are logged. We spread the
|
||||
// refreshes evenly over the next 3 hours, randomizing the domains, and we add some
|
||||
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
|
||||
// refresh doesn't mess up the timing.
|
||||
func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
|
||||
db, err := database()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
now := timeNow()
|
||||
qdel := bstore.QueryDB[PolicyRecord](db)
|
||||
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
|
||||
if _, err := qdel.Delete(); err != nil {
|
||||
return 0, fmt.Errorf("deleting old unused policies: %s", err)
|
||||
}
|
||||
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
|
||||
prs, err := qup.List()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("querying policies to refresh: %s", err)
|
||||
}
|
||||
|
||||
if len(prs) == 0 {
|
||||
// Nothing to do.
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Randomize list.
|
||||
rand := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
for i := range prs {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
j := rand.Intn(i + 1)
|
||||
prs[i], prs[j] = prs[j], prs[i]
|
||||
}
|
||||
|
||||
// Launch goroutine with the refresh.
|
||||
xlog.WithContext(ctx).Debug("will refresh mta-sts policies over next 3 hours", mlog.Field("count", len(prs)))
|
||||
start := timeNow()
|
||||
for i, pr := range prs {
|
||||
go refreshDomain(ctx, db, resolver, pr)
|
||||
if i < len(prs)-1 {
|
||||
interval := 3 * int64(time.Hour) / int64(len(prs)-1)
|
||||
extra := time.Duration(rand.Int63n(interval) - interval/2)
|
||||
next := start.Add(time.Duration(int64(i+1)*interval) + extra)
|
||||
d := next.Sub(timeNow())
|
||||
if d > 0 {
|
||||
sleep(d)
|
||||
}
|
||||
}
|
||||
}
|
||||
return len(prs), nil
|
||||
}
|
||||
|
||||
func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr PolicyRecord) {
|
||||
log := xlog.WithContext(ctx)
|
||||
defer func() {
|
||||
x := recover()
|
||||
if x != nil {
|
||||
// Should not happen, but make sure errors don't take down the application.
|
||||
log.Error("refresh1", mlog.Field("panic", x))
|
||||
debug.PrintStack()
|
||||
metrics.PanicInc("mtastsdb")
|
||||
}
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
|
||||
d, err := dns.ParseDomain(pr.Domain)
|
||||
if err != nil {
|
||||
log.Errorx("refreshing mta-sts policy: parsing policy domain", err, mlog.Field("domain", d))
|
||||
return
|
||||
}
|
||||
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
|
||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
|
||||
if err == nil && record.ID == pr.RecordID {
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||
now := timeNow()
|
||||
update := PolicyRecord{
|
||||
LastUpdate: now,
|
||||
ValidEnd: now.Add(time.Duration(pr.MaxAgeSeconds) * time.Second),
|
||||
}
|
||||
if n, err := qup.UpdateNonzero(update); err != nil {
|
||||
log.Errorx("updating refreshed, unmodified policy in database", err)
|
||||
} else if n != 1 {
|
||||
log.Info("expected to update 1 policy after refresh", mlog.Field("count", n))
|
||||
}
|
||||
return
|
||||
}
|
||||
// ../rfc/8461:587
|
||||
if err != nil && pr.Mode == mtasts.ModeNone {
|
||||
return
|
||||
} else if err != nil {
|
||||
log.Errorx("looking up mta-sts record for domain", err, mlog.Field("domain", d))
|
||||
// Try to fetch new policy. It could be just DNS that is down. We don't want to let our policy expire.
|
||||
}
|
||||
|
||||
p, _, err := mtasts.FetchPolicy(ctx, d)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mtasts.ErrNoPolicy) || pr.Mode != mtasts.ModeNone {
|
||||
log.Errorx("refreshing mtasts policy for domain", err, mlog.Field("domain", d))
|
||||
}
|
||||
return
|
||||
}
|
||||
now := timeNow()
|
||||
update := map[string]any{
|
||||
"LastUpdate": now,
|
||||
"ValidEnd": now.Add(time.Duration(p.MaxAgeSeconds) * time.Second),
|
||||
"Backoff": false,
|
||||
"Policy": *p,
|
||||
}
|
||||
if record != nil {
|
||||
update["RecordID"] = record.ID
|
||||
}
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||
if n, err := qup.UpdateFields(update); err != nil {
|
||||
log.Errorx("updating refreshed, modified policy in database", err)
|
||||
} else if n != 1 {
|
||||
log.Info("updating refreshed, did not update 1 policy", mlog.Field("count", n))
|
||||
}
|
||||
}
|
231
mtastsdb/refresh_test.go
Normal file
231
mtastsdb/refresh_test.go
Normal file
@ -0,0 +1,231 @@
|
||||
package mtastsdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mjl-/bstore"
|
||||
|
||||
"github.com/mjl-/mox/dns"
|
||||
"github.com/mjl-/mox/mox-"
|
||||
"github.com/mjl-/mox/mtasts"
|
||||
)
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||
mox.Conf.Static.DataDir = "."
|
||||
|
||||
dbpath := mox.DataDirPath("mtasts.db")
|
||||
os.MkdirAll(filepath.Dir(dbpath), 0770)
|
||||
os.Remove(dbpath)
|
||||
defer os.Remove(dbpath)
|
||||
|
||||
if err := Init(false); err != nil {
|
||||
t.Fatalf("init database: %s", err)
|
||||
}
|
||||
defer Close()
|
||||
|
||||
db, err := database()
|
||||
if err != nil {
|
||||
t.Fatalf("database: %s", err)
|
||||
}
|
||||
|
||||
cert := fakeCert(t, false)
|
||||
defer func() {
|
||||
mtasts.HTTPClient.Transport = nil
|
||||
}()
|
||||
|
||||
insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
|
||||
t.Helper()
|
||||
|
||||
mxd, err := dns.ParseDomain(mx)
|
||||
if err != nil {
|
||||
t.Fatalf("parsing mx domain %q: %s", mx, err)
|
||||
}
|
||||
policy := mtasts.Policy{
|
||||
Version: "STSv1",
|
||||
Mode: mode,
|
||||
MX: []mtasts.STSMX{{Wildcard: false, Domain: mxd}},
|
||||
MaxAgeSeconds: maxAge,
|
||||
Extensions: nil,
|
||||
}
|
||||
|
||||
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
|
||||
if err := db.Insert(&pr); err != nil {
|
||||
t.Fatalf("insert policy: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
// Updated just now.
|
||||
insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
|
||||
// To be removed.
|
||||
insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
|
||||
// To be refreshed, same id.
|
||||
insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
|
||||
// To be refreshed and succeed.
|
||||
insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
|
||||
// To be refreshed and fail to fetch.
|
||||
insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
|
||||
|
||||
resolver := dns.MockResolver{
|
||||
TXT: map[string][]string{
|
||||
"_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
|
||||
"_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
|
||||
"_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
|
||||
},
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(cert.Leaf)
|
||||
|
||||
l := newPipeListener()
|
||||
defer l.Close()
|
||||
go func() {
|
||||
mux := &http.ServeMux{}
|
||||
mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Host == "mta-sts.policybad.mox.example" {
|
||||
w.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
|
||||
})
|
||||
s := &http.Server{
|
||||
Handler: mux,
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
},
|
||||
ErrorLog: log.New(io.Discard, "", 0),
|
||||
}
|
||||
s.ServeTLS(l, "", "")
|
||||
}()
|
||||
|
||||
mtasts.HTTPClient.Transport = &http.Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
return l.Dial()
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: pool,
|
||||
},
|
||||
}
|
||||
|
||||
slept := 0
|
||||
sleep := func(d time.Duration) {
|
||||
slept++
|
||||
interval := 3 * time.Hour / 2
|
||||
if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
|
||||
t.Fatalf("bad sleep duration %v", d)
|
||||
}
|
||||
}
|
||||
if n, err := refresh1(context.Background(), resolver, sleep); err != nil || n != 3 {
|
||||
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
|
||||
}
|
||||
if slept != 2 {
|
||||
t.Fatalf("bad sleeps, %d instead of 2", slept)
|
||||
}
|
||||
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
|
||||
|
||||
// Should not do any more refreshes and return immediately.
|
||||
q := bstore.QueryDB[PolicyRecord](db)
|
||||
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
|
||||
if _, err := q.Delete(); err != nil {
|
||||
t.Fatalf("delete record that would be refreshed: %v", err)
|
||||
}
|
||||
mox.Context = context.Background()
|
||||
mox.Shutdown = make(chan struct{})
|
||||
close(mox.Shutdown)
|
||||
n := refresh()
|
||||
if n != 0 {
|
||||
t.Fatalf("refresh found unexpected work, n %d", n)
|
||||
}
|
||||
mox.Shutdown = make(chan struct{})
|
||||
}
|
||||
|
||||
type pipeListener struct {
|
||||
sync.Mutex
|
||||
closed bool
|
||||
C chan net.Conn
|
||||
}
|
||||
|
||||
var _ net.Listener = &pipeListener{}
|
||||
|
||||
func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
|
||||
func (l *pipeListener) Dial() (net.Conn, error) {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
if l.closed {
|
||||
return nil, errors.New("closed")
|
||||
}
|
||||
c, s := net.Pipe()
|
||||
l.C <- s
|
||||
return c, nil
|
||||
}
|
||||
func (l *pipeListener) Accept() (net.Conn, error) {
|
||||
conn := <-l.C
|
||||
if conn == nil {
|
||||
return nil, io.EOF
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
func (l *pipeListener) Close() error {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
if !l.closed {
|
||||
l.closed = true
|
||||
close(l.C)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
|
||||
|
||||
type pipeAddr struct{}
|
||||
|
||||
func (a pipeAddr) Network() string { return "pipe" }
|
||||
func (a pipeAddr) String() string { return "pipe" }
|
||||
|
||||
func fakeCert(t *testing.T, expired bool) tls.Certificate {
|
||||
notAfter := time.Now()
|
||||
if expired {
|
||||
notAfter = notAfter.Add(-time.Hour)
|
||||
} else {
|
||||
notAfter = notAfter.Add(time.Hour)
|
||||
}
|
||||
|
||||
privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1), // Required field...
|
||||
DNSNames: []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: notAfter,
|
||||
}
|
||||
localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
|
||||
if err != nil {
|
||||
t.Fatalf("making certificate: %s", err)
|
||||
}
|
||||
cert, err := x509.ParseCertificate(localCertBuf)
|
||||
if err != nil {
|
||||
t.Fatalf("parsing generated certificate: %s", err)
|
||||
}
|
||||
c := tls.Certificate{
|
||||
Certificate: [][]byte{localCertBuf},
|
||||
PrivateKey: privKey,
|
||||
Leaf: cert,
|
||||
}
|
||||
return c
|
||||
}
|
Reference in New Issue
Block a user