diff --git a/junk/filter.go b/junk/filter.go index 9093367..494d494 100644 --- a/junk/filter.go +++ b/junk/filter.go @@ -39,8 +39,8 @@ type word struct { Spam uint32 } -type wordscore struct { - Word string +type Wordscore struct { + Word string `bstore:"typename wordscore"` Ham uint32 Spam uint32 } @@ -57,7 +57,7 @@ type Params struct { RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."` } -var DBTypes = []any{wordscore{}} // Stored in DB. +var DBTypes = []any{Wordscore{}} // Stored in DB. type Filter struct { Params @@ -142,7 +142,7 @@ func OpenFilter(ctx context.Context, log mlog.Log, params Params, dbPath, bloomP bloom: bloom, } err = f.db.Read(ctx, func(tx *bstore.Tx) error { - wc := wordscore{Word: "-"} + wc := Wordscore{Word: "-"} err := tx.Get(&wc) f.hams = wc.Ham f.spams = wc.Spam @@ -277,11 +277,11 @@ func (f *Filter) Save() error { f.log.Debug("inserting words in junkfilter db", slog.Any("words", len(f.changed))) // start := time.Now() if f.isNew { - if err := f.db.HintAppend(true, wordscore{}); err != nil { + if err := f.db.HintAppend(true, Wordscore{}); err != nil { f.log.Errorx("hint appendonly", err) } else { defer func() { - err := f.db.HintAppend(false, wordscore{}) + err := f.db.HintAppend(false, Wordscore{}) f.log.Check(err, "restoring append hint") }() } @@ -289,17 +289,17 @@ func (f *Filter) Save() error { err := f.db.Write(context.Background(), func(tx *bstore.Tx) error { update := func(w string, ham, spam uint32) error { if f.isNew { - return tx.Insert(&wordscore{w, ham, spam}) + return tx.Insert(&Wordscore{w, ham, spam}) } - wc := wordscore{w, 0, 0} + wc := Wordscore{w, 0, 0} err := tx.Get(&wc) if err == bstore.ErrAbsent { - return tx.Insert(&wordscore{w, ham, spam}) + return tx.Insert(&Wordscore{w, ham, spam}) } else if err != nil { return err } - return tx.Update(&wordscore{w, ham, spam}) + return tx.Update(&Wordscore{w, ham, spam}) } if err := update("-", f.hams, f.spams); err != nil { return fmt.Errorf("storing total ham/spam message count: %s", err) @@ -331,7 +331,7 @@ func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]wo err := db.Read(ctx, func(tx *bstore.Tx) error { for _, w := range l { - wc := wordscore{Word: w} + wc := Wordscore{Word: w} if err := tx.Get(&wc); err == nil { dst[w] = word{wc.Ham, wc.Spam} } diff --git a/smtpserver/alias_test.go b/smtpserver/alias_test.go index a47005f..5efcbcc 100644 --- a/smtpserver/alias_test.go +++ b/smtpserver/alias_test.go @@ -6,8 +6,6 @@ import ( "strings" "testing" - "github.com/mjl-/bstore" - "github.com/mjl-/mox/dns" "github.com/mjl-/mox/smtp" "github.com/mjl-/mox/smtpclient" @@ -279,10 +277,7 @@ test email }) // Mark message as junk. - q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) - n, err := q.UpdateFields(map[string]any{"Junk": true}) - tcheck(t, err, "mark as junk") - tcompare(t, n, 1) + ts.xops.MessageFlagsAdd(ctxbg, pkglog, ts.acc, []int64{1}, []string{"$Junk"}) ts.run(func(client *smtpclient.Client) { mailFrom := "mjl@mox.example" diff --git a/smtpserver/server_test.go b/smtpserver/server_test.go index d32dd90..35dc15a 100644 --- a/smtpserver/server_test.go +++ b/smtpserver/server_test.go @@ -39,6 +39,7 @@ import ( "github.com/mjl-/mox/store" "github.com/mjl-/mox/subjectpass" "github.com/mjl-/mox/tlsrptdb" + "github.com/mjl-/mox/webops" ) var ctxbg = context.Background() @@ -99,6 +100,7 @@ type testserver struct { dnsbls []dns.Domain tlsmode smtpclient.TLSMode tlspkix bool + xops webops.XOps } const password0 = "te\u0301st \u00a0\u2002\u200a" // NFD and various unicode spaces. @@ -109,6 +111,21 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test log := mlog.New("smtpserver", nil) + checkf := func(ctx context.Context, err error, format string, args ...any) { + tcheck(t, err, fmt.Sprintf(format, args...)) + } + xops := webops.XOps{ + DBWrite: func(ctx context.Context, acc *store.Account, fn func(tx *bstore.Tx)) { + err := acc.DB.Write(ctx, func(tx *bstore.Tx) error { + fn(tx) + return nil + }) + tcheck(t, err, "db write") + }, + Checkf: checkf, + Checkuserf: checkf, + } + ts := testserver{ t: t, cid: 1, @@ -117,6 +134,7 @@ func newTestServer(t *testing.T, configPath string, resolver dns.Resolver) *test serverConfig: &tls.Config{ Certificates: []tls.Certificate{fakeCert(t, false)}, }, + xops: xops, } // Ensure session keys, for tests that check resume and authentication. @@ -622,7 +640,7 @@ func TestDelivery(t *testing.T) { } func tinsertmsg(t *testing.T, acc *store.Account, mailbox string, m *store.Message, msg string) { - mf, err := store.CreateMessageTemp(pkglog, "queue-dsn") + mf, err := store.CreateMessageTemp(pkglog, "insertmsg") tcheck(t, err, "temp message") defer os.Remove(mf.Name()) defer mf.Close() @@ -651,7 +669,7 @@ func tretrain(t *testing.T, acc *store.Account) { q := bstore.QueryDB[store.Message](ctxbg, acc.DB) q.FilterEqual("Expunged", false) q.FilterFn(func(m store.Message) bool { - return m.Flags.Junk || m.Flags.Notjunk + return m.Flags.Junk != m.Flags.Notjunk }) msgs, err := q.List() tcheck(t, err, "fetch messages") @@ -739,10 +757,14 @@ func TestSpam(t *testing.T) { }) // Mark the messages as having good reputation. - q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) - q.FilterEqual("Expunged", false) - _, err := q.UpdateFields(map[string]any{"Junk": false, "Notjunk": true}) - tcheck(t, err, "update junkiness") + var ids []int64 + err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).FilterEqual("Expunged", false).ForEach(func(m store.Message) error { + ids = append(ids, m.ID) + return nil + }) + tcheck(t, err, "get message ids") + ts.xops.MessageFlagsClear(ctxbg, pkglog, ts.acc, ids, []string{"$Junk"}) + ts.xops.MessageFlagsAdd(ctxbg, pkglog, ts.acc, ids, []string{"$NotJunk"}) // Message should now be accepted. ts.run(func(client *smtpclient.Client) { @@ -760,7 +782,7 @@ func TestSpam(t *testing.T) { // Undo dmarc pass, mark messages as junk, and train the filter. resolver.TXT = nil - q = bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) + q := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB) q.FilterEqual("Expunged", false) _, err = q.UpdateFields(map[string]any{"Junk": true, "Notjunk": false}) tcheck(t, err, "update junkiness") @@ -853,6 +875,7 @@ happens to come from forwarding mail server. n, err := bstore.QueryDB[store.Message](ctxbg, ts.acc.DB).UpdateFields(map[string]any{"Junk": true, "MsgFromValidated": true}) tcheck(t, err, "marking messages as junk") tcompare(t, n, 10) + tretrain(t, ts.acc) // Next delivery will fail, with negative "message From" signal. err = client.Deliver(ctxbg, mailFrom, rcptTo, int64(len(msgBad)), strings.NewReader(msgBad), false, false, false) @@ -946,6 +969,7 @@ func TestDMARCSent(t *testing.T) { nm := m nm.Junk = true tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) + nm = m nm.Notjunk = true tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) diff --git a/store/account.go b/store/account.go index 8f76f71..1b88269 100644 --- a/store/account.go +++ b/store/account.go @@ -37,6 +37,7 @@ import ( "log/slog" "os" "path/filepath" + "reflect" "runtime/debug" "slices" "sort" @@ -53,6 +54,7 @@ import ( "github.com/mjl-/mox/config" "github.com/mjl-/mox/dns" + "github.com/mjl-/mox/junk" "github.com/mjl-/mox/message" "github.com/mjl-/mox/metrics" "github.com/mjl-/mox/mlog" @@ -654,7 +656,7 @@ func (m Message) LoadPart(r io.ReaderAt) (message.Part, error) { func (m Message) NeedsTraining() bool { untrain := m.TrainedJunk != nil untrainJunk := untrain && *m.TrainedJunk - train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk) + train := m.Junk != m.Notjunk trainJunk := m.Junk return untrain != train || untrain && train && untrainJunk != trainJunk } @@ -1306,6 +1308,7 @@ func (a *Account) Close() error { // - Message ModSeq > 0, CreateSeq > 0, CreateSeq <= ModSeq. // - All messages have a nonzero ThreadID, and no cycles in ThreadParentID, and parent messages the same ThreadParentIDs tail. // - Annotations must have ModSeq > 0, CreateSeq > 0, ModSeq >= CreateSeq. +// - Recalculate junk filter (words and counts) and check they are the same. func (a *Account) CheckConsistency() error { var uidErrors []string // With a limit, could be many. var modseqErrors []string // With limit. @@ -1315,16 +1318,18 @@ func (a *Account) CheckConsistency() error { var threadAncestorErrors []string // With limit. var errmsgs []string - err := a.DB.Read(context.Background(), func(tx *bstore.Tx) error { + ctx := context.Background() + log := mlog.New("store", nil) + + a.Lock() + defer a.Unlock() + err := a.DB.Read(ctx, func(tx *bstore.Tx) error { nuv := NextUIDValidity{ID: 1} err := tx.Get(&nuv) if err != nil { return fmt.Errorf("fetching next uid validity: %v", err) } - // All message id's from database. For checking for unexpected files afterwards. - messageIDs := map[int64]struct{}{} - mailboxes := map[int64]Mailbox{} err = bstore.QueryTx[Mailbox](tx).ForEach(func(mb Mailbox) error { mailboxes[mb.ID] = mb @@ -1368,6 +1373,29 @@ func (a *Account) CheckConsistency() error { return fmt.Errorf("checking mailbox annotations: %v", err) } + // All message id's from database. For checking for unexpected files afterwards. + messageIDs := map[int64]struct{}{} + + // If configured, we'll be building up the junk filter for the messages, to compare + // against the on-disk junk filter. + var jf *junk.Filter + conf, _ := a.Conf() + if conf.JunkFilter != nil { + random := make([]byte, 16) + cryptorand.Read(random) + dbpath := filepath.Join(mox.DataDirPath("tmp"), fmt.Sprintf("junkfilter-check-%x.db", random)) + bloompath := filepath.Join(mox.DataDirPath("tmp"), fmt.Sprintf("junkfilter-check-%x.bloom", random)) + os.MkdirAll(filepath.Dir(dbpath), 0700) + defer os.Remove(dbpath) + defer os.Remove(bloompath) + jf, err = junk.NewFilter(ctx, log, conf.JunkFilter.Params, dbpath, bloompath) + if err != nil { + return fmt.Errorf("new junk filter: %v", err) + } + defer jf.Close() + } + var ntrained int + counts := map[int64]MailboxCounts{} err = bstore.QueryTx[Message](tx).ForEach(func(m Message) error { mc := counts[m.MailboxID] @@ -1387,6 +1415,7 @@ func (a *Account) CheckConsistency() error { if m.Expunged { return nil } + messageIDs[m.ID] = struct{}{} p := a.MessagePath(m.ID) st, err := os.Stat(p) @@ -1419,6 +1448,16 @@ func (a *Account) CheckConsistency() error { break } } + + if jf != nil { + if m.Junk != m.Notjunk { + ntrained++ + if _, err := a.TrainMessage(ctx, log, jf, m); err != nil { + return fmt.Errorf("train message: %v", err) + } + } + } + return nil }) if err != nil { @@ -1470,6 +1509,45 @@ func (a *Account) CheckConsistency() error { errmsgs = append(errmsgs, errmsg) } + // Compare on-disk junk filter with our recalculated filter. + if jf != nil { + load := func(f *junk.Filter) (map[junk.Wordscore]struct{}, error) { + words := map[junk.Wordscore]struct{}{} + err := bstore.QueryDB[junk.Wordscore](ctx, f.DB()).ForEach(func(w junk.Wordscore) error { + if w.Ham != 0 || w.Spam != 0 { + words[w] = struct{}{} + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("read junk filter wordscores: %v", err) + } + return words, nil + } + if err := jf.Save(); err != nil { + return fmt.Errorf("save recalculated junk filter: %v", err) + } + wordsExp, err := load(jf) + if err != nil { + return fmt.Errorf("read recalculated junk filter: %v", err) + } + + ajf, _, err := a.OpenJunkFilter(ctx, log) + if err != nil { + return fmt.Errorf("open account junk filter: %v", err) + } + defer ajf.Close() + wordsGot, err := load(ajf) + if err != nil { + return fmt.Errorf("read account junk filter: %v", err) + } + + if !reflect.DeepEqual(wordsGot, wordsExp) { + errmsg := fmt.Sprintf("unexpected values in junk filter, trained %d of %d\ngot:\n%v\nexpected:\n%v", ntrained, len(messageIDs), wordsGot, wordsExp) + errmsgs = append(errmsgs, errmsg) + } + } + return nil }) if err != nil { diff --git a/store/train.go b/store/train.go index 8064e92..d40440f 100644 --- a/store/train.go +++ b/store/train.go @@ -90,7 +90,7 @@ func (a *Account) RetrainMessages(ctx context.Context, log mlog.Log, tx *bstore. func (a *Account) RetrainMessage(ctx context.Context, log mlog.Log, tx *bstore.Tx, jf *junk.Filter, m *Message, absentOK bool) error { untrain := m.TrainedJunk != nil untrainJunk := untrain && *m.TrainedJunk - train := m.Junk || m.Notjunk && !(m.Junk && m.Notjunk) + train := m.Junk != m.Notjunk trainJunk := m.Junk if !untrain && !train || (untrain && train && untrainJunk == trainJunk) { @@ -144,7 +144,7 @@ func (a *Account) RetrainMessage(ctx context.Context, log mlog.Log, tx *bstore.T // TrainMessage trains the junk filter based on the current m.Junk/m.Notjunk flags, // disregarding m.TrainedJunk and not updating that field. func (a *Account) TrainMessage(ctx context.Context, log mlog.Log, jf *junk.Filter, m Message) (bool, error) { - if !m.Junk && !m.Notjunk || (m.Junk && m.Notjunk) { + if m.Junk == m.Notjunk { return false, nil }