mirror of
https://github.com/mjl-/mox.git
synced 2025-07-12 17:44:35 +03:00
update to latest bstore (with support for an index on a []string: Message.DKIMDomains), and cyclic data types (to be used for Message.Part soon); also adds a context.Context to database operations.
This commit is contained in:
@ -63,13 +63,13 @@ var (
|
||||
var mtastsDB *bstore.DB
|
||||
var mutex sync.Mutex
|
||||
|
||||
func database() (rdb *bstore.DB, rerr error) {
|
||||
func database(ctx context.Context) (rdb *bstore.DB, rerr error) {
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
if mtastsDB == nil {
|
||||
p := mox.DataDirPath("mtasts.db")
|
||||
os.MkdirAll(filepath.Dir(p), 0770)
|
||||
db, err := bstore.Open(p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
|
||||
db, err := bstore.Open(ctx, p, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, PolicyRecord{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -81,7 +81,7 @@ func database() (rdb *bstore.DB, rerr error) {
|
||||
// Init opens the database and starts a goroutine that refreshes policies in
|
||||
// the database, and keeps doing so periodically.
|
||||
func Init(refresher bool) error {
|
||||
_, err := database()
|
||||
_, err := database(mox.Shutdown)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -110,7 +110,7 @@ func Close() {
|
||||
// Only non-expired records are returned.
|
||||
func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||
log := xlog.WithContext(ctx)
|
||||
db, err := database()
|
||||
db, err := database(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -119,7 +119,7 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||
return nil, fmt.Errorf("empty domain")
|
||||
}
|
||||
now := timeNow()
|
||||
q := bstore.QueryDB[PolicyRecord](db)
|
||||
q := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||
q.FilterNonzero(PolicyRecord{Domain: domain.Name()})
|
||||
q.FilterGreater("ValidEnd", now)
|
||||
pr, err := q.Get()
|
||||
@ -130,7 +130,7 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||
}
|
||||
|
||||
pr.LastUse = now
|
||||
if err := db.Update(&pr); err != nil {
|
||||
if err := db.Update(ctx, &pr); err != nil {
|
||||
log.Errorx("marking cached mta-sts policy as used in database", err)
|
||||
}
|
||||
if pr.Backoff {
|
||||
@ -141,13 +141,13 @@ func lookup(ctx context.Context, domain dns.Domain) (*PolicyRecord, error) {
|
||||
|
||||
// Upsert adds the policy to the database, overwriting an existing policy for the domain.
|
||||
// Policy can be nil, indicating a failure to fetch the policy.
|
||||
func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||
db, err := database()
|
||||
func Upsert(ctx context.Context, domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||
db, err := database(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.Write(func(tx *bstore.Tx) error {
|
||||
return db.Write(ctx, func(tx *bstore.Tx) error {
|
||||
pr := PolicyRecord{Domain: domain.Name()}
|
||||
err := tx.Get(&pr)
|
||||
if err != nil && err != bstore.ErrAbsent {
|
||||
@ -185,11 +185,11 @@ func Upsert(domain dns.Domain, recordID string, policy *mtasts.Policy) error {
|
||||
// PolicyRecords returns all policies in the database, sorted descending by last
|
||||
// use, domain.
|
||||
func PolicyRecords(ctx context.Context) ([]PolicyRecord, error) {
|
||||
db, err := database()
|
||||
db, err := database(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bstore.QueryDB[PolicyRecord](db).SortDesc("LastUse", "Domain").List()
|
||||
return bstore.QueryDB[PolicyRecord](ctx, db).SortDesc("LastUse", "Domain").List()
|
||||
}
|
||||
|
||||
// Get retrieves an MTA-STS policy for domain and whether it is fresh.
|
||||
@ -244,7 +244,7 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||
if record != nil {
|
||||
recordID = record.ID
|
||||
}
|
||||
if err := Upsert(domain, recordID, p); err != nil {
|
||||
if err := Upsert(ctx, domain, recordID, p); err != nil {
|
||||
log.Errorx("inserting policy into cache, continuing", err)
|
||||
}
|
||||
}
|
||||
@ -259,9 +259,9 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||
|
||||
// Policy was found in database. Check in DNS it is still fresh.
|
||||
policy = &cachedPolicy.Policy
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
nctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, domain)
|
||||
record, _, _, err := mtasts.LookupRecord(nctx, resolver, domain)
|
||||
if err != nil {
|
||||
if !errors.Is(err, mtasts.ErrNoRecord) {
|
||||
// Could be a temporary DNS or configuration error.
|
||||
@ -271,15 +271,16 @@ func Get(ctx context.Context, resolver dns.Resolver, domain dns.Domain) (policy
|
||||
} else if record.ID == cachedPolicy.RecordID {
|
||||
return policy, true, nil
|
||||
}
|
||||
|
||||
// New policy should be available.
|
||||
ctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
nctx, cancel = context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
p, _, err := mtasts.FetchPolicy(ctx, domain)
|
||||
p, _, err := mtasts.FetchPolicy(nctx, domain)
|
||||
if err != nil {
|
||||
log.Errorx("fetching updated policy for domain, continuing with previously cached policy", err)
|
||||
return policy, false, nil
|
||||
}
|
||||
if err := Upsert(domain, record.ID, p); err != nil {
|
||||
if err := Upsert(ctx, domain, record.ID, p); err != nil {
|
||||
log.Errorx("inserting refreshed policy into cache, continuing with fresh policy", err)
|
||||
}
|
||||
return p, true, nil
|
||||
|
@ -1,7 +1,6 @@
|
||||
package mtastsdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@ -24,6 +23,7 @@ func tcheckf(t *testing.T, err error, format string, args ...any) {
|
||||
}
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
mox.Shutdown = ctxbg
|
||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||
mox.Conf.Static.DataDir = "."
|
||||
|
||||
@ -37,14 +37,12 @@ func TestDB(t *testing.T) {
|
||||
}
|
||||
defer Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Mock time.
|
||||
now := time.Now().Round(0)
|
||||
timeNow = func() time.Time { return now }
|
||||
defer func() { timeNow = time.Now }()
|
||||
|
||||
if p, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
|
||||
if p, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != ErrNotFound {
|
||||
t.Fatalf("expected not found, got %v, %#v", err, p)
|
||||
}
|
||||
|
||||
@ -58,10 +56,10 @@ func TestDB(t *testing.T) {
|
||||
},
|
||||
MaxAgeSeconds: 1296000,
|
||||
}
|
||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
|
||||
if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "123", &policy1); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
if got, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
t.Fatalf("lookup after insert: %s", err)
|
||||
} else if !reflect.DeepEqual(got.Policy, policy1) {
|
||||
t.Fatalf("mismatch between inserted and retrieved: got %#v, want %#v", got, policy1)
|
||||
@ -75,17 +73,17 @@ func TestDB(t *testing.T) {
|
||||
},
|
||||
MaxAgeSeconds: 360000,
|
||||
}
|
||||
if err := Upsert(dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
|
||||
if err := Upsert(ctxbg, dns.Domain{ASCII: "example.com"}, "124", &policy2); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
if got, err := lookup(ctx, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
if got, err := lookup(ctxbg, dns.Domain{ASCII: "example.com"}); err != nil {
|
||||
t.Fatalf("lookup after insert: %s", err)
|
||||
} else if !reflect.DeepEqual(got.Policy, policy2) {
|
||||
t.Fatalf("mismatch between inserted and retrieved: got %v, want %v", got, policy2)
|
||||
}
|
||||
|
||||
// Check if database holds expected record.
|
||||
records, err := PolicyRecords(context.Background())
|
||||
records, err := PolicyRecords(ctxbg)
|
||||
tcheckf(t, err, "policyrecords")
|
||||
expRecords := []PolicyRecord{
|
||||
{"example.com", now, now.Add(time.Duration(policy2.MaxAgeSeconds) * time.Second), now, now, false, "124", policy2},
|
||||
@ -96,10 +94,10 @@ func TestDB(t *testing.T) {
|
||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||
}
|
||||
|
||||
if err := Upsert(dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
|
||||
if err := Upsert(ctxbg, dns.Domain{ASCII: "other.example.com"}, "", nil); err != nil {
|
||||
t.Fatalf("upsert record: %s", err)
|
||||
}
|
||||
records, err = PolicyRecords(context.Background())
|
||||
records, err = PolicyRecords(ctxbg)
|
||||
tcheckf(t, err, "policyrecords")
|
||||
expRecords = []PolicyRecord{
|
||||
{"other.example.com", now, now.Add(5 * 60 * time.Second), now, now, true, "", mtasts.Policy{Mode: mtasts.ModeNone, MaxAgeSeconds: 5 * 60}},
|
||||
@ -109,7 +107,7 @@ func TestDB(t *testing.T) {
|
||||
t.Fatalf("records mismatch, got %#v, expected %#v", records, expRecords)
|
||||
}
|
||||
|
||||
if _, err := lookup(context.Background(), dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
|
||||
if _, err := lookup(ctxbg, dns.Domain{ASCII: "other.example.com"}); err != ErrBackoff {
|
||||
t.Fatalf("got %#v, expected ErrBackoff", err)
|
||||
}
|
||||
|
||||
@ -126,7 +124,7 @@ func TestDB(t *testing.T) {
|
||||
|
||||
testGet := func(domain string, expPolicy *mtasts.Policy, expFresh bool, expErr error) {
|
||||
t.Helper()
|
||||
p, fresh, err := Get(context.Background(), resolver, dns.Domain{ASCII: domain})
|
||||
p, fresh, err := Get(ctxbg, resolver, dns.Domain{ASCII: domain})
|
||||
if (err == nil) != (expErr == nil) || err != nil && !errors.Is(err, expErr) {
|
||||
t.Fatalf("got err %v, expected %v", err, expErr)
|
||||
}
|
||||
|
@ -52,19 +52,19 @@ func refresh() int {
|
||||
// jitter to the timing. Each refresh is done in a new goroutine, so a single slow
|
||||
// refresh doesn't mess up the timing.
|
||||
func refresh1(ctx context.Context, resolver dns.Resolver, sleep func(d time.Duration)) (int, error) {
|
||||
db, err := database()
|
||||
db, err := database(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
now := timeNow()
|
||||
qdel := bstore.QueryDB[PolicyRecord](db)
|
||||
qdel := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||
qdel.FilterLess("LastUse", now.Add(-180*24*time.Hour))
|
||||
if _, err := qdel.Delete(); err != nil {
|
||||
return 0, fmt.Errorf("deleting old unused policies: %s", err)
|
||||
}
|
||||
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||
qup.FilterLess("LastUpdate", now.Add(-12*time.Hour))
|
||||
prs, err := qup.List()
|
||||
if err != nil {
|
||||
@ -127,7 +127,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
|
||||
log.Debug("refreshing mta-sts policy for domain", mlog.Field("domain", d))
|
||||
record, _, _, err := mtasts.LookupRecord(ctx, resolver, d)
|
||||
if err == nil && record.ID == pr.RecordID {
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||
now := timeNow()
|
||||
update := PolicyRecord{
|
||||
@ -166,7 +166,7 @@ func refreshDomain(ctx context.Context, db *bstore.DB, resolver dns.Resolver, pr
|
||||
if record != nil {
|
||||
update["RecordID"] = record.ID
|
||||
}
|
||||
qup := bstore.QueryDB[PolicyRecord](db)
|
||||
qup := bstore.QueryDB[PolicyRecord](ctx, db)
|
||||
qup.FilterNonzero(PolicyRecord{Domain: pr.Domain, LastUpdate: pr.LastUpdate})
|
||||
if n, err := qup.UpdateFields(update); err != nil {
|
||||
log.Errorx("updating refreshed, modified policy in database", err)
|
||||
|
@ -26,7 +26,10 @@ import (
|
||||
"github.com/mjl-/mox/mtasts"
|
||||
)
|
||||
|
||||
var ctxbg = context.Background()
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
mox.Shutdown = ctxbg
|
||||
mox.ConfigStaticPath = "../testdata/mtasts/fake.conf"
|
||||
mox.Conf.Static.DataDir = "."
|
||||
|
||||
@ -40,7 +43,7 @@ func TestRefresh(t *testing.T) {
|
||||
}
|
||||
defer Close()
|
||||
|
||||
db, err := database()
|
||||
db, err := database(ctxbg)
|
||||
if err != nil {
|
||||
t.Fatalf("database: %s", err)
|
||||
}
|
||||
@ -66,7 +69,7 @@ func TestRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
pr := PolicyRecord{domain, time.Time{}, validEnd, lastUpdate, lastUse, backoff, recordID, policy}
|
||||
if err := db.Insert(&pr); err != nil {
|
||||
if err := db.Insert(ctxbg, &pr); err != nil {
|
||||
t.Fatalf("insert policy: %s", err)
|
||||
}
|
||||
}
|
||||
@ -132,7 +135,7 @@ func TestRefresh(t *testing.T) {
|
||||
t.Fatalf("bad sleep duration %v", d)
|
||||
}
|
||||
}
|
||||
if n, err := refresh1(context.Background(), resolver, sleep); err != nil || n != 3 {
|
||||
if n, err := refresh1(ctxbg, resolver, sleep); err != nil || n != 3 {
|
||||
t.Fatalf("refresh1: err %s, n %d, expected no error, 3", err, n)
|
||||
}
|
||||
if slept != 2 {
|
||||
@ -141,19 +144,19 @@ func TestRefresh(t *testing.T) {
|
||||
time.Sleep(time.Second / 10) // Give goroutine time to write result, before we cleanup the database.
|
||||
|
||||
// Should not do any more refreshes and return immediately.
|
||||
q := bstore.QueryDB[PolicyRecord](db)
|
||||
q := bstore.QueryDB[PolicyRecord](ctxbg, db)
|
||||
q.FilterNonzero(PolicyRecord{Domain: "policybad.mox.example"})
|
||||
if _, err := q.Delete(); err != nil {
|
||||
t.Fatalf("delete record that would be refreshed: %v", err)
|
||||
}
|
||||
mox.Context = context.Background()
|
||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
||||
mox.Context = ctxbg
|
||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||
mox.ShutdownCancel()
|
||||
n := refresh()
|
||||
if n != 0 {
|
||||
t.Fatalf("refresh found unexpected work, n %d", n)
|
||||
}
|
||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(context.Background())
|
||||
mox.Shutdown, mox.ShutdownCancel = context.WithCancel(ctxbg)
|
||||
}
|
||||
|
||||
type pipeListener struct {
|
||||
|
Reference in New Issue
Block a user