From 394bdef39d91ab551a75ae4df263b79b415a82bd Mon Sep 17 00:00:00 2001 From: Mechiel Lukkien Date: Wed, 26 Feb 2025 14:40:47 +0100 Subject: [PATCH] In storage consistency checks, verify the junk filter has the expected word counts Fix up a test or two. Simplify the XOR logic when we train the junk filter: Only if junk or nonjunk is set, but not when both (or none are set). i.e. when the values aren't the same. Locking the account when we do consistency checks prevents spurious test failures that may have been introduced in the previous commit. --- junk/filter.go | 22 +++++----- smtpserver/alias_test.go | 7 +--- smtpserver/server_test.go | 38 +++++++++++++---- store/account.go | 88 ++++++++++++++++++++++++++++++++++++--- store/train.go | 4 +- 5 files changed, 128 insertions(+), 31 deletions(-) diff --git a/junk/filter.go b/junk/filter.go index 9093367..494d494 100644 --- a/junk/filter.go +++ b/junk/filter.go @@ -39,8 +39,8 @@ type word struct { Spam uint32 } -type wordscore struct { - Word string +type Wordscore struct { + Word string `bstore:"typename wordscore"` Ham uint32 Spam uint32 } @@ -57,7 +57,7 @@ type Params struct { RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."` } -var DBTypes = []any{wordscore{}} // Stored in DB. +var DBTypes = []any{Wordscore{}} // Stored in DB. type Filter struct { Params @@ -142,7 +142,7 @@ func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomP bloom: bloom, } err = f.db.Read(ctx, func(tx *bstore.Tx) error { - wc := wordscore{Word: "-"} + wc := Wordscore{Word: "-"} err := tx.Get(&wc) f.hams = wc.Ham f.spams = wc.Spam @@ -277,11 +277,11 @@ func (f *Filter) Save() error { f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed))) // start := time.Now() if f.isNew { - if err := f.db.HintAppend(true, wordscore{}); err != nil { + if err := f.db.HintAppend(true, Wordscore{}); err != nil { f.log.Errorx("hint appendonly", err) } else { defer func() { - err := f.db.HintAppend(false, wordscore{}) + err := f.db.HintAppend(false, Wordscore{}) f.log.Check(err, "restoring append hint") }() } @@ -289,17 +289,17 @@ func (f *Filter) Save() error { err := f.db.Write(context.Background(), func(tx *bstore.Tx) error { update := func(w string, ham, spam uint32) error { if f.isNew { - return tx.Insert(&wordscore{w, ham, spam}) + return tx.Insert(&Wordscore{w, ham, spam}) } - wc := wordscore{w, 0, 0} + wc := Wordscore{w, 0, 0} err := tx.Get(&wc) if err == bstore.ErrAbsent { - return tx.Insert(&wordscore{w, ham, spam}) + return tx.Insert(&Wordscore{w, ham, spam}) } else if err != nil { return err } - return tx.Update(&wordscore{w, ham, spam}) + return tx.Update(&Wordscore{w, ham, spam}) } if err := update("-", f.hams, f.spams); err != nil { return fmt.Errorf("storing total ham/spam message count: %s", err) @@ -331,7 +331,7 @@ func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]wo err := db.Read(ctx, func(tx *bstore.Tx) error { for _, w := range l { - wc := wordscore{Word: w} + wc := Wordscore{Word: w} if err := tx.Get(&wc); err == nil { dst[w] = word{wc.Ham, wc.Spam} } diff --git a/smtpserver/alias_test.go b/smtpserver/alias_test.go index a47005f..5efcbcc 100644 --- a/smtpserver/alias_test.go +++ b/smtpserver/alias_test.go @@ -6,8 +6,6 @@ import ( "strings" "testing" - "github.com/mjl-/bstore" - "github.com/mjl-/mox/dns" "github.com/mjl-/mox/smtp" "github.com/mjl-/mox/smtpclient" @@ -279,10 +277,7 @@ test email }) // Mark message as junk. - q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) - n, err := q.UpdateFields(map[string]any{"Junk": true}) - tcheck(t, err, "mark as junk") - tcompare(t, n, 1) + ts.xops.MessageFlagsAdd(ctxbg, pkglog, ts.acc, []int64{1}, []string{"$Junk"}) ts.run(func(client *smtpclient.Client) { mailFrom := "mjl@mox.example" diff --git a/smtpserver/server_test.go b/smtpserver/server_test.go index d32dd90..35dc15a 100644 --- a/smtpserver/server_test.go +++ b/smtpserver/server_test.go @@ -39,6 +39,7 @@ import ( "github.com/mjl-/mox/store" "github.com/mjl-/mox/subjectpass" "github.com/mjl-/mox/tlsrptdb" + "github.com/mjl-/mox/webops" ) var ctxbg = context.Background() @@ -99,6 +100,7 @@ type testserver struct { dnsbls []dns.Domain tlsmode smtpclient.TLSMode tlspkix bool + xops webops.XOps } const password0 = "te\u0301st \u00a0\u2002\u200a" // NFD and various unicode spaces. @@ -109,6 +111,21 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test log := mlog.New("smtpserver", nil) + checkf := func(ctx context.Context, err error, format string, args ...any) { + tcheck(t, err, fmt.Sprintf(format, args...)) + } + xops := webops.XOps{ + DBWrite: func(ctx context.Context, acc *store.Account, fn func(tx *bstore.Tx)) { + err := acc.DB.Write(ctx, func(tx *bstore.Tx) error { + fn(tx) + return nil + }) + tcheck(t, err, "db write") + }, + Checkf: checkf, + Checkuserf: checkf, + } + ts := testserver{ t: t, cid: 1, @@ -117,6 +134,7 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test serverConfig: &tls.Config{ Certificates: []tls.Certificate{fakeCert(t, false)}, }, + xops: xops, } // Ensure session keys, for tests that check resume and authentication. @@ -622,7 +640,7 @@ func TestDelivery(t *testing.T) { } func tinsertmsg(t *testing.T, acc *store.Account, mailbox string, m *store.Message, msg string) { - mf, err := store.CreateMessageTemp(pkglog, "queue-dsn") + mf, err := store.CreateMessageTemp(pkglog, "insertmsg") tcheck(t, err, "temp message") defer os.Remove(mf.Name()) defer mf.Close() @@ -651,7 +669,7 @@ func tretrain(t *testing.T, acc *store.Account) { q := bstore.QueryDB[store.Message](ctxbg, acc.DB) q.FilterEqual("Expunged", false) q.FilterFn(func(m store.Message) bool { - return m.Flags.Junk || m.Flags.Notjunk + return m.Flags.Junk != m.Flags.Notjunk }) msgs, err := q.List() tcheck(t, err, "fetch messages") @@ -739,10 +757,14 @@ func TestSpam(t *testing.T) { }) // Mark the messages as having good reputation. - q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) - q.FilterEqual("Expunged", false) - _, err := q.UpdateFields(map[string]any{"Junk": false, "Notjunk": true}) - tcheck(t, err, "update junkiness") + var ids []int64 + err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).FilterEqual("Expunged", false).ForEach(func(m store.Message) error { + ids = append(ids, m.ID) + return nil + }) + tcheck(t, err, "get message ids") + ts.xops.MessageFlagsClear(ctxbg, pkglog, ts.acc, ids, []string{"$Junk"}) + ts.xops.MessageFlagsAdd(ctxbg, pkglog, ts.acc, ids, []string{"$NotJunk"}) // Message should now be accepted. ts.run(func(client *smtpclient.Client) { @@ -760,7 +782,7 @@ func TestSpam(t *testing.T) { // Undo dmarc pass, mark messages as junk, and train the filter. resolver.TXT = nil - q = bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) + q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) q.FilterEqual("Expunged", false) _, err = q.UpdateFields(map[string]any{"Junk": true, "Notjunk": false}) tcheck(t, err, "update junkiness") @@ -853,6 +875,7 @@ happens to come from forwarding mail server. n, err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).UpdateFields(map[string]any{"Junk": true, "MsgFromValidated": true}) tcheck(t, err, "marking messages as junk") tcompare(t, n, 10) + tretrain(t, ts.acc) // Next delivery will fail, with negative "message From" signal. err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(msgBad)), strings.NewReader(msgBad), false, false, false) @@ -946,6 +969,7 @@ func TestDMARCSent(t *testing.T) { nm := m nm.Junk = true tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) + nm = m nm.Notjunk = true tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) diff --git a/store/account.go b/store/account.go index 8f76f71..1b88269 100644 --- a/store/account.go +++ b/store/account.go @@ -37,6 +37,7 @@ import ( "log/slog" "os" "path/filepath" + "reflect" "runtime/debug" "slices" "sort" @@ -53,6 +54,7 @@ import ( "github.com/mjl-/mox/config" "github.com/mjl-/mox/dns" + "github.com/mjl-/mox/junk" "github.com/mjl-/mox/message" "github.com/mjl-/mox/metrics" "github.com/mjl-/mox/mlog" @@ -654,7 +656,7 @@ func (m Message) LoadPart(r io.ReaderAt) (message.Part, error) { func (m Message) NeedsTraining() bool { untrain := m.TrainedJunk != nil untrainJunk := untrain && *m.TrainedJunk - train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk) + train := m.Junk != m.Notjunk trainJunk := m.Junk return untrain != train || untrain && train && untrainJunk != trainJunk } @@ -1306,6 +1308,7 @@ func (a *Account) Close() error { // - Message ModSeq > 0, CreateSeq > 0, CreateSeq <= ModSeq. // - All messages have a nonzero ThreadID, and no cycles in ThreadParentID, and parent messages the same ThreadParentIDs tail. // - Annotations must have ModSeq > 0, CreateSeq > 0, ModSeq >= CreateSeq. +// - Recalculate junk filter (words and counts) and check they are the same. func (a *Account) CheckConsistency() error { var uidErrors []string // With a limit, could be many. var modseqErrors []string // With limit. @@ -1315,16 +1318,18 @@ func (a *Account) CheckConsistency() error { var threadAncestorErrors []string // With limit. var errmsgs []string - err := a.DB.Read(context.Background(), func(tx *bstore.Tx) error { + ctx := context.Background() + log := mlog.New("store", nil) + + a.Lock() + defer a.Unlock() + err := a.DB.Read(ctx, func(tx *bstore.Tx) error { nuv := NextUIDValidity{ID: 1} err := tx.Get(&nuv) if err != nil { return fmt.Errorf("fetching next uid validity: %v", err) } - // All message id's from database. For checking for unexpected files afterwards. - messageIDs := map[int64]struct{}{} - mailboxes := map[int64]Mailbox{} err = bstore.QueryTx[Mailbox](tx).ForEach(func(mb Mailbox) error { mailboxes[mb.ID] = mb @@ -1368,6 +1373,29 @@ func (a *Account) CheckConsistency() error { return fmt.Errorf("checking mailbox annotations: %v", err) } + // All message id's from database. For checking for unexpected files afterwards. + messageIDs := map[int64]struct{}{} + + // If configured, we'll be building up the junk filter for the messages, to compare + // against the on-disk junk filter. + var jf *junk.Filter + conf, _ := a.Conf() + if conf.JunkFilter != nil { + random := make([]byte, 16) + cryptorand.Read(random) + dbpath := filepath.Join(mox.DataDirPath("tmp"), fmt.Sprintf("junkfilter-check-%x.db", random)) + bloompath := filepath.Join(mox.DataDirPath("tmp"), fmt.Sprintf("junkfilter-check-%x.bloom", random)) + os.MkdirAll(filepath.Dir(dbpath), 0700) + defer os.Remove(dbpath) + defer os.Remove(bloompath) + jf, err = junk.NewFilter(ctx, log, conf.JunkFilter.Params, dbpath, bloompath) + if err != nil { + return fmt.Errorf("new junk filter: %v", err) + } + defer jf.Close() + } + var ntrained int + counts := map[int64]MailboxCounts{} err = bstore.QueryTx[Message](tx).ForEach(func(m Message) error { mc := counts[m.MailboxID] @@ -1387,6 +1415,7 @@ func (a *Account) CheckConsistency() error { if m.Expunged { return nil } + messageIDs[m.ID] = struct{}{} p := a.MessagePath(m.ID) st, err := os.Stat(p) @@ -1419,6 +1448,16 @@ func (a *Account) CheckConsistency() error { break } } + + if jf != nil { + if m.Junk != m.Notjunk { + ntrained++ + if _, err := a.TrainMessage(ctx, log, jf, m); err != nil { + return fmt.Errorf("train message: %v", err) + } + } + } + return nil }) if err != nil { @@ -1470,6 +1509,45 @@ func (a *Account) CheckConsistency() error { errmsgs = append(errmsgs, errmsg) } + // Compare on-disk junk filter with our recalculated filter. + if jf != nil { + load := func(f *junk.Filter) (map[junk.Wordscore]struct{}, error) { + words := map[junk.Wordscore]struct{}{} + err := bstore.QueryDB[junk.Wordscore](ctx, f.DB()).ForEach(func(w junk.Wordscore) error { + if w.Ham != 0 || w.Spam != 0 { + words[w] = struct{}{} + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("read junk filter wordscores: %v", err) + } + return words, nil + } + if err := jf.Save(); err != nil { + return fmt.Errorf("save recalculated junk filter: %v", err) + } + wordsExp, err := load(jf) + if err != nil { + return fmt.Errorf("read recalculated junk filter: %v", err) + } + + ajf, _, err := a.OpenJunkFilter(ctx, log) + if err != nil { + return fmt.Errorf("open account junk filter: %v", err) + } + defer ajf.Close() + wordsGot, err := load(ajf) + if err != nil { + return fmt.Errorf("read account junk filter: %v", err) + } + + if !reflect.DeepEqual(wordsGot, wordsExp) { + errmsg := fmt.Sprintf("unexpected values in junk filter, trained %d of %d\ngot:\n%v\nexpected:\n%v", ntrained, len(messageIDs), wordsGot, wordsExp) + errmsgs = append(errmsgs, errmsg) + } + } + return nil }) if err != nil { diff --git a/store/train.go b/store/train.go index 8064e92..d40440f 100644 --- a/store/train.go +++ b/store/train.go @@ -90,7 +90,7 @@ func (a *Account) RetrainMessages(ctx context.Context, log mlog.Log, tx *bstore. func (a *Account) RetrainMessage(ctx context.Context, log mlog.Log, tx *bstore.Tx, jf *junk.Filter, m *Message, absentOK bool) error { untrain := m.TrainedJunk != nil untrainJunk := untrain && *m.TrainedJunk - train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk) + train := m.Junk != m.Notjunk trainJunk := m.Junk if !untrain && !train || (untrain && train && untrainJunk == trainJunk) { @@ -144,7 +144,7 @@ func (a *Account) RetrainMessage(ctx context.Context, log mlog.Log, tx *bstore.T // TrainMessage trains the junk filter based on the current m.Junk/m.Notjunk flags, // disregarding m.TrainedJunk and not updating that field. func (a *Account) TrainMessage(ctx context.Context, log mlog.Log, jf *junk.Filter, m Message) (bool, error) { - if !m.Junk && !m.Notjunk || (m.Junk && m.Notjunk) { + if m.Junk == m.Notjunk { return false, nil }