This commit is contained in:
Mechiel Lukkien
2023-01-30 14:27:06 +01:00
commit cb229cb6cf
1256 changed files with 491723 additions and 0 deletions

285
mtastsdb/db.go Normal file
View File

@ -0,0 +1,285 @@
// Package mtastsdb stores MTA-STS policies for later use.
//
// An MTA-STS policy can specify how long it may be cached. By storing a
// policy, it does not have to be fetched again during email delivery, which
// makes it harder for attackers to intervene.
package mtastsdb
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
var xlog = mlog.New("mtastsdb")
var (
metricGet = promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "mox_mtastsdb_get_total",
Help: "Number of Get by result.",
},
[]string{"result"},
)
)
var timeNow = time.Now // Tests override this.
// PolicyRecord is a cached policy or absence of a policy.
type PolicyRecord struct {
Domain string // Domain name, with unicode characters.
Inserted time.Time `bstore:"default now"`
ValidEnd time.Time
LastUpdate time.Time // Policies are refreshed on use and periodically.
LastUse time.Time `bstore:"index"`
Backoff bool
RecordID string // As retrieved from DNS.
mtasts.Policy // As retrieved from the well-known HTTPS url.
}
var (
// No valid non-expired policy in database.
ErrNotFound = errors.New("mtastsdb: policy not found")
// Indicates an MTA-STS TXT record was fetched recently, but fetching the policy
// failed and should not yet be retried.
ErrBackoff = errors.New("mtastsdb: policy fetch failed recently")
)
var mtastsDB *bstore.DB
var mutex sync.Mutex
func database() (rdb *bstore.DB, rerr error) {
mutex.Lock()
defer mutex.Unlock()
if mtastsDB == nil {
p := mox.DataDirPath("mtasts.db")
os.MkdirAll(filepath.Dir(p), 0770)
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
if err != nil {
return nil, err
}
mtastsDB = db
}
return mtastsDB, nil
}
// Init opens the database and starts a goroutine that refreshes policies in
// the database, and keeps doing so periodically.
func Init(refresher bool) error {
_, err := database()
if err != nil {
return err
}
if refresher {
// todo: allow us to shut down cleanly?
go refresh()
}
return nil
}
// Close closes the database.
func Close() {
mutex.Lock()
defer mutex.Unlock()
if mtastsDB != nil {
mtastsDB.Close()
mtastsDB = nil
}
}
// Lookup looks up a policy for the domain in the database.
//
// Only non-expired records are returned.
func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
log := xlog.WithContext(ctx)
db, err := database()
if err != nil {
return nil, err
}
if domain.IsZero() {
return nil, fmt.Errorf("empty domain")
}
now := timeNow()
q := bstore.QueryDB[PolicyRecord](db)
q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
q.FilterGreater("ValidEnd", now)
pr, err := q.Get()
if err == bstore.ErrAbsent {
return nil, ErrNotFound
} else if err != nil {
return nil, err
}
pr.LastUse = now
if err := db.Update(&pr); err != nil {
log.Errorx("marking cached mta-sts policy as used in database", err)
}
if pr.Backoff {
return nil, ErrBackoff
}
return &pr, nil
}
// Upsert adds the policy to the database, overwriting an existing policy for the domain.
// Policy can be nil, indicating a failure to fetch the policy.
func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
db, err := database()
if err != nil {
return err
}
return db.Write(func(tx *bstore.Tx) error {
pr := PolicyRecord{Domain: domain.Name()}
err := tx.Get(&pr)
if err != nil && err != bstore.ErrAbsent {
return err
}
now := timeNow()
var p mtasts.Policy
if policy != nil {
p = *policy
} else {
// ../rfc/8461:552
p.Mode = mtasts.ModeNone
p.MaxAgeSeconds = 5 * 60
}
backoff := policy == nil
validEnd := now.Add(time.Duration(p.MaxAgeSeconds) * time.Second)
if err == bstore.ErrAbsent {
pr = PolicyRecord{domain.Name(), now, validEnd, now, now, backoff, recordID, p}
return tx.Insert(&pr)
}
pr.ValidEnd = validEnd
pr.LastUpdate = now
pr.LastUse = now
pr.Backoff = backoff
pr.RecordID = recordID
pr.Policy = p
return tx.Update(&pr)
})
}
// PolicyRecords returns all policies in the database, sorted descending by last
// use, domain.
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
db, err := database()
if err != nil {
return nil, err
}
return bstore.QueryDB[PolicyRecord](db).SortDesc("LastUse", "Domain").List()
}
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
//
// If an error is returned, it should be considered a transient error, e.g. a
// temporary DNS lookup failure.
//
// The returned policy can be nil also when there is no error. In this case, the
// domain does not implement MTA-STS.
//
// If a policy is present in the local database, it is refreshed if needed. If no
// policy is present for the domain, an attempt is made to fetch the policy and
// store it in the local database.
//
// Some errors are logged but not otherwise returned, e.g. if a new policy is
// supposedly published but could not be retrieved.
func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy *mtasts.Policy, fresh bool, err error) {
log := xlog.WithContext(ctx)
defer func() {
result := "ok"
if err != nil && errors.Is(err, ErrBackoff) {
result = "backoff"
} else if err != nil && errors.Is(err, ErrNotFound) {
result = "notfound"
} else if err != nil {
result = "error"
}
metricGet.WithLabelValues(result).Inc()
log.Debugx("mtastsdb get result", err, mlog.Field("domain", domain), mlog.Field("fresh", fresh))
}()
cachedPolicy, err := lookup(ctx, domain)
if err != nil && errors.Is(err, ErrNotFound) {
// We don't have a policy for this domain, not even a record that we tried recently
// and should backoff. So attempt to fetch policy.
nctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
record, p, err := mtasts.Get(nctx, resolver, domain)
if err != nil {
switch {
case errors.Is(err, mtasts.ErrNoRecord) || errors.Is(err, mtasts.ErrMultipleRecords) || errors.Is(err, mtasts.ErrRecordSyntax) || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax):
// Remote is not doing MTA-STS, continue below. ../rfc/8461:333 ../rfc/8461:574
default:
// Interpret as temporary error, e.g. mtasts.ErrDNS, try again later.
return nil, false, fmt.Errorf("lookup up mta-sts policy: %w", err)
}
}
// Insert policy into database. If we could not fetch the policy itself, we back
// off for 5 minutes. ../rfc/8461:555
if err == nil || errors.Is(err, mtasts.ErrNoPolicy) || errors.Is(err, mtasts.ErrPolicyFetch) || errors.Is(err, mtasts.ErrPolicySyntax) {
var recordID string
if record != nil {
recordID = record.ID
}
if err := Upsert(domain, recordID, p); err != nil {
log.Errorx("inserting policy into cache, continuing", err)
}
}
return p, true, nil
} else if err != nil && errors.Is(err, ErrBackoff) {
// ../rfc/8461:552
// We recently failed to fetch a policy, act as if MTA-STS is not implemented.
return nil, false, nil
} else if err != nil {
return nil, false, fmt.Errorf("looking up mta-sts policy in cache: %w", err)
}
// Policy was found in database. Check in DNS it is still fresh.
policy = &cachedPolicy.Policy
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
record, _, _, err := mtasts.LookupRecord(ctx, resolver, domain)
if err != nil {
if !errors.Is(err, mtasts.ErrNoRecord) {
// Could be a temporary DNS or configuration error.
log.Errorx("checking for freshness of cached mta-sts dns txt record for domain, continuing with previously cached policy", err)
}
return policy, false, nil
} else if record.ID == cachedPolicy.RecordID {
return policy, true, nil
}
// New policy should be available.
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
defer cancel()
p, _, err := mtasts.FetchPolicy(ctx, domain)
if err != nil {
log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)
return policy, false, nil
}
if err := Upsert(domain, record.ID, p); err != nil {
log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
}
return p, true, nil
}

158
mtastsdb/db_test.go Normal file
View File

@ -0,0 +1,158 @@
package mtastsdb
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
func tcheckf(t *testing.T, err error, format string, args ...any) {
if err != nil {
t.Fatalf("%s: %s", fmt.Sprintf(format, args...), err)
}
}
func TestDB(t *testing.T) {
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
mox.Conf.Static.DataDir = "."
dbpath := mox.DataDirPath("mtasts.db")
os.MkdirAll(filepath.Dir(dbpath), 0770)
os.Remove(dbpath)
defer os.Remove(dbpath)
if err := Init(false); err != nil {
t.Fatalf("init database: %s", err)
}
defer Close()
ctx := context.Background()
// Mock time.
now := time.Now().Round(0)
timeNow = func() time.Time { return now }
defer func() { timeNow = time.Now }()
if p, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
t.Fatalf("expected not found, got %v, %#v", err, p)
}
policy1 := mtasts.Policy{
Version: "STSv1",
Mode: mtasts.ModeTesting,
MX: []mtasts.STSMX{
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
{Domain: dns.Domain{ASCII: "mx2.example.com"}},
{Domain: dns.Domain{ASCII: "mx.backup-example.com"}},
},
MaxAgeSeconds: 1296000,
}
if err := Upsert(dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
t.Fatalf("upsert record: %s", err)
}
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
t.Fatalf("lookup after insert: %s", err)
} else if !reflect.DeepEqual(got.Policy, policy1) {
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
}
policy2 := mtasts.Policy{
Version: "STSv1",
Mode: mtasts.ModeEnforce,
MX: []mtasts.STSMX{
{Domain: dns.Domain{ASCII: "mx1.example.com"}},
},
MaxAgeSeconds: 360000,
}
if err := Upsert(dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
t.Fatalf("upsert record: %s", err)
}
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
t.Fatalf("lookup after insert: %s", err)
} else if !reflect.DeepEqual(got.Policy, policy2) {
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
}
// Check if database holds expected record.
records, err := PolicyRecords(context.Background())
tcheckf(t, err, "policyrecords")
expRecords := []PolicyRecord{
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
}
records[0].Policy = mtasts.Policy{}
expRecords[0].Policy = mtasts.Policy{}
if !reflect.DeepEqual(records, expRecords) {
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
}
if err := Upsert(dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
t.Fatalf("upsert record: %s", err)
}
records, err = PolicyRecords(context.Background())
tcheckf(t, err, "policyrecords")
expRecords = []PolicyRecord{
{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
}
if !reflect.DeepEqual(records, expRecords) {
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
}
if _, err := lookup(context.Background(), dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
t.Fatalf("got %#v, expected ErrBackoff", err)
}
resolver := dns.MockResolver{
TXT: map[string][]string{
"_mta-sts.example.com.": {"v=STSv1; id=124"},
"_mta-sts.other.example.com.": {"v=STSv1; id=1"},
"_mta-sts.temperror.example.com.": {""},
},
Fail: map[dns.Mockreq]struct{}{
{Type: "txt", Name: "_mta-sts.temperror.example.com."}: {},
},
}
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
t.Helper()
p, fresh, err := Get(context.Background(), resolver, dns.Domain{ASCII: domain})
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
t.Fatalf("got err %v, expected %v", err, expErr)
}
if !reflect.DeepEqual(p, expPolicy) || fresh != expFresh {
t.Fatalf("got policy %#v, fresh %v, expected %#v, %v", p, fresh, expPolicy, expFresh)
}
}
testGet("example.com", &policy2, true, nil)
testGet("other.example.com", nil, false, nil) // Back off, already in database.
testGet("absent.example.com", nil, true, nil) // No MTA-STS.
testGet("temperror.example.com", nil, false, mtasts.ErrDNS)
// Force refetch of policy, that will fail.
mtasts.HTTPClient.Transport = &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return nil, fmt.Errorf("bad")
},
}
defer func() {
mtasts.HTTPClient.Transport = nil
}()
resolver.TXT["_mta-sts.example.com."] = []string{"v=STSv1; id=125"}
testGet("example.com", &policy2, false, nil)
// Cached policy but no longer a DNS record.
delete(resolver.TXT, "_mta-sts.example.com.")
testGet("example.com", &policy2, false, nil)
}

176
mtastsdb/refresh.go Normal file
View File

@ -0,0 +1,176 @@
package mtastsdb
import (
"context"
"errors"
"fmt"
mathrand "math/rand"
"runtime/debug"
"time"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/metrics"
"github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
func refresh() int {
interval := 24 * time.Hour
ticker := time.NewTicker(interval)
defer ticker.Stop()
var refreshed int
// Pro-actively refresh policies every 24 hours. ../rfc/8461:583
for {
ticker.Reset(interval)
ctx := context.WithValue(mox.Context, mlog.CidKey, mox.Cid())
n, err := refresh1(ctx, dns.StrictResolver{Pkg: "mtastsdb"}, time.Sleep)
if err != nil {
xlog.WithContext(ctx).Errorx("periodic refresh of cached mtasts policies", err)
}
if n > 0 {
refreshed += n
}
select {
case <-mox.Shutdown:
return refreshed
case <-ticker.C:
}
}
}
// refresh policies that have not been updated in the past 12 hours and remove
// policies not used for 180 days. We start with the first domain immediately, so
// an admin can see any (configuration) issues that are logged. We spread the
// refreshes evenly over the next 3 hours, randomizing the domains, and we add some
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
// refresh doesn't mess up the timing.
func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
db, err := database()
if err != nil {
return 0, err
}
now := timeNow()
qdel := bstore.QueryDB[PolicyRecord](db)
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
if _, err := qdel.Delete(); err != nil {
return 0, fmt.Errorf("deleting old unused policies: %s", err)
}
qup := bstore.QueryDB[PolicyRecord](db)
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
prs, err := qup.List()
if err != nil {
return 0, fmt.Errorf("querying policies to refresh: %s", err)
}
if len(prs) == 0 {
// Nothing to do.
return 0, nil
}
// Randomize list.
rand := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
for i := range prs {
if i == 0 {
continue
}
j := rand.Intn(i + 1)
prs[i], prs[j] = prs[j], prs[i]
}
// Launch goroutine with the refresh.
xlog.WithContext(ctx).Debug("will refresh mta-sts policies over next 3 hours", mlog.Field("count", len(prs)))
start := timeNow()
for i, pr := range prs {
go refreshDomain(ctx, db, resolver, pr)
if i < len(prs)-1 {
interval := 3 * int64(time.Hour) / int64(len(prs)-1)
extra := time.Duration(rand.Int63n(interval) - interval/2)
next := start.Add(time.Duration(int64(i+1)*interval) + extra)
d := next.Sub(timeNow())
if d > 0 {
sleep(d)
}
}
}
return len(prs), nil
}
func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr PolicyRecord) {
log := xlog.WithContext(ctx)
defer func() {
x := recover()
if x != nil {
// Should not happen, but make sure errors don't take down the application.
log.Error("refresh1", mlog.Field("panic", x))
debug.PrintStack()
metrics.PanicInc("mtastsdb")
}
}()
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
d, err := dns.ParseDomain(pr.Domain)
if err != nil {
log.Errorx("refreshing mta-sts policy: parsing policy domain", err, mlog.Field("domain", d))
return
}
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
if err == nil && record.ID == pr.RecordID {
qup := bstore.QueryDB[PolicyRecord](db)
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
now := timeNow()
update := PolicyRecord{
LastUpdate: now,
ValidEnd: now.Add(time.Duration(pr.MaxAgeSeconds) * time.Second),
}
if n, err := qup.UpdateNonzero(update); err != nil {
log.Errorx("updating refreshed, unmodified policy in database", err)
} else if n != 1 {
log.Info("expected to update 1 policy after refresh", mlog.Field("count", n))
}
return
}
// ../rfc/8461:587
if err != nil && pr.Mode == mtasts.ModeNone {
return
} else if err != nil {
log.Errorx("looking up mta-sts record for domain", err, mlog.Field("domain", d))
// Try to fetch new policy. It could be just DNS that is down. We don't want to let our policy expire.
}
p, _, err := mtasts.FetchPolicy(ctx, d)
if err != nil {
if !errors.Is(err, mtasts.ErrNoPolicy) || pr.Mode != mtasts.ModeNone {
log.Errorx("refreshing mtasts policy for domain", err, mlog.Field("domain", d))
}
return
}
now := timeNow()
update := map[string]any{
"LastUpdate": now,
"ValidEnd": now.Add(time.Duration(p.MaxAgeSeconds) * time.Second),
"Backoff": false,
"Policy": *p,
}
if record != nil {
update["RecordID"] = record.ID
}
qup := bstore.QueryDB[PolicyRecord](db)
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
if n, err := qup.UpdateFields(update); err != nil {
log.Errorx("updating refreshed, modified policy in database", err)
} else if n != 1 {
log.Info("updating refreshed, did not update 1 policy", mlog.Field("count", n))
}
}

231
mtastsdb/refresh_test.go Normal file
View File

@ -0,0 +1,231 @@
package mtastsdb
import (
"context"
"crypto/ed25519"
cryptorand "crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/dns"
"github.com/mjl-/mox/mox-"
"github.com/mjl-/mox/mtasts"
)
func TestRefresh(t *testing.T) {
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
mox.Conf.Static.DataDir = "."
dbpath := mox.DataDirPath("mtasts.db")
os.MkdirAll(filepath.Dir(dbpath), 0770)
os.Remove(dbpath)
defer os.Remove(dbpath)
if err := Init(false); err != nil {
t.Fatalf("init database: %s", err)
}
defer Close()
db, err := database()
if err != nil {
t.Fatalf("database: %s", err)
}
cert := fakeCert(t, false)
defer func() {
mtasts.HTTPClient.Transport = nil
}()
insert := func(domain string, validEnd, lastUpdate, lastUse time.Time, backoff bool, recordID string, mode mtasts.Mode, maxAge int, mx string) {
t.Helper()
mxd, err := dns.ParseDomain(mx)
if err != nil {
t.Fatalf("parsing mx domain %q: %s", mx, err)
}
policy := mtasts.Policy{
Version: "STSv1",
Mode: mode,
MX: []mtasts.STSMX{{Wildcard: false, Domain: mxd}},
MaxAgeSeconds: maxAge,
Extensions: nil,
}
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
if err := db.Insert(&pr); err != nil {
t.Fatalf("insert policy: %s", err)
}
}
now := time.Now()
// Updated just now.
insert("mox.example", now.Add(24*time.Hour), now, now, false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
// To be removed.
insert("stale.mox.example", now.Add(-time.Hour), now, now.Add(-181*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
// To be refreshed, same id.
insert("refresh.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
// To be refreshed and succeed.
insert("policyok.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
// To be refreshed and fail to fetch.
insert("policybad.mox.example", now.Add(7*24*time.Hour), now.Add(-24*time.Hour), now.Add(-179*24*time.Hour), false, "1", mtasts.ModeEnforce, 3600, "mx.mox.example.com")
resolver := dns.MockResolver{
TXT: map[string][]string{
"_mta-sts.refresh.mox.example.": {"v=STSv1; id=1"},
"_mta-sts.policyok.mox.example.": {"v=STSv1; id=2"},
"_mta-sts.policybad.mox.example.": {"v=STSv1; id=2"},
},
}
pool := x509.NewCertPool()
pool.AddCert(cert.Leaf)
l := newPipeListener()
defer l.Close()
go func() {
mux := &http.ServeMux{}
mux.HandleFunc("/.well-known/mta-sts.txt", func(w http.ResponseWriter, r *http.Request) {
if r.Host == "mta-sts.policybad.mox.example" {
w.WriteHeader(500)
return
}
fmt.Fprintf(w, "version: STSv1\nmode: enforce\nmx: mx.mox.example.com\nmax_age: 3600\n")
})
s := &http.Server{
Handler: mux,
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
},
ErrorLog: log.New(io.Discard, "", 0),
}
s.ServeTLS(l, "", "")
}()
mtasts.HTTPClient.Transport = &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return l.Dial()
},
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
}
slept := 0
sleep := func(d time.Duration) {
slept++
interval := 3 * time.Hour / 2
if d < time.Duration(slept)*interval-interval/2 || d > time.Duration(slept)*interval+interval/2 {
t.Fatalf("bad sleep duration %v", d)
}
}
if n, err := refresh1(context.Background(), resolver, sleep); err != nil || n != 3 {
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
}
if slept != 2 {
t.Fatalf("bad sleeps, %d instead of 2", slept)
}
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
// Should not do any more refreshes and return immediately.
q := bstore.QueryDB[PolicyRecord](db)
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
if _, err := q.Delete(); err != nil {
t.Fatalf("delete record that would be refreshed: %v", err)
}
mox.Context = context.Background()
mox.Shutdown = make(chan struct{})
close(mox.Shutdown)
n := refresh()
if n != 0 {
t.Fatalf("refresh found unexpected work, n %d", n)
}
mox.Shutdown = make(chan struct{})
}
type pipeListener struct {
sync.Mutex
closed bool
C chan net.Conn
}
var _ net.Listener = &pipeListener{}
func newPipeListener() *pipeListener { return &pipeListener{C: make(chan net.Conn)} }
func (l *pipeListener) Dial() (net.Conn, error) {
l.Lock()
defer l.Unlock()
if l.closed {
return nil, errors.New("closed")
}
c, s := net.Pipe()
l.C <- s
return c, nil
}
func (l *pipeListener) Accept() (net.Conn, error) {
conn := <-l.C
if conn == nil {
return nil, io.EOF
}
return conn, nil
}
func (l *pipeListener) Close() error {
l.Lock()
defer l.Unlock()
if !l.closed {
l.closed = true
close(l.C)
}
return nil
}
func (l *pipeListener) Addr() net.Addr { return pipeAddr{} }
type pipeAddr struct{}
func (a pipeAddr) Network() string { return "pipe" }
func (a pipeAddr) String() string { return "pipe" }
func fakeCert(t *testing.T, expired bool) tls.Certificate {
notAfter := time.Now()
if expired {
notAfter = notAfter.Add(-time.Hour)
} else {
notAfter = notAfter.Add(time.Hour)
}
privKey := ed25519.NewKeyFromSeed(make([]byte, ed25519.SeedSize)) // Fake key, don't use this for real!
template := &x509.Certificate{
SerialNumber: big.NewInt(1), // Required field...
DNSNames: []string{"mta-sts.policybad.mox.example", "mta-sts.policyok.mox.example"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: notAfter,
}
localCertBuf, err := x509.CreateCertificate(cryptorand.Reader, template, template, privKey.Public(), privKey)
if err != nil {
t.Fatalf("making certificate: %s", err)
}
cert, err := x509.ParseCertificate(localCertBuf)
if err != nil {
t.Fatalf("parsing generated certificate: %s", err)
}
c := tls.Certificate{
Certificate: [][]byte{localCertBuf},
PrivateKey: privKey,
Leaf: cert,
}
return c
}