diff --git a/junk.go b/junk.go index aa00bfe..b6e825c 100644 --- a/junk.go +++ b/junk.go @@ -129,10 +129,14 @@ func cmdJunkCheck(c *cmd) { } }() - prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), args[0]) + result, err := f.ClassifyMessagePath(context.Background(), args[0]) xcheckf(err, "testing mail") - fmt.Printf("%.6f\n", prob) + sig := "significant" + if !result.Significant { + sig = "not significant" + } + fmt.Printf("%.6f, %s\n", result.Probability, sig) } func cmdJunkTest(c *cmd) { @@ -159,21 +163,21 @@ func cmdJunkTest(c *cmd) { xcheckf(err, "readdir %q", dir) for _, fi := range files { path := filepath.Join(dir, fi.Name()) - prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path) + result, err := f.ClassifyMessagePath(context.Background(), path) if err != nil { log.Printf("classify message %q: %s", path, err) continue } - if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold { + if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold { ok++ } else { bad++ } - if ham && prob > a.spamThreshold { - fmt.Printf("ham %q: %.4f\n", path, prob) + if ham && result.Probability > a.spamThreshold { + fmt.Printf("ham %q: %.4f\n", path, result.Probability) } - if !ham && prob < a.spamThreshold { - fmt.Printf("spam %q: %.4f\n", path, prob) + if !ham && result.Probability < a.spamThreshold { + fmt.Printf("spam %q: %.4f\n", path, result.Probability) } } return ok, bad @@ -251,22 +255,22 @@ messages are shuffled, with optional random seed.` testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) { for _, name := range files { path := filepath.Join(dir, name) - prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path) + result, err := f.ClassifyMessagePath(context.Background(), path) if err != nil { // log.Infof("%s: %s", path, err) malformed++ continue } - if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold { + if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold { ok++ } else { bad++ } - if ham && prob > a.spamThreshold { - fmt.Printf("ham %q: %.4f\n", path, prob) + if ham && result.Probability > a.spamThreshold { + fmt.Printf("ham %q: %.4f\n", path, result.Probability) } - if !ham && prob < a.spamThreshold { - fmt.Printf("spam %q: %.4f\n", path, prob) + if !ham && result.Probability < a.spamThreshold { + fmt.Printf("spam %q: %.4f\n", path, result.Probability) } } return @@ -367,21 +371,19 @@ func cmdJunkPlay(c *cmd) { var words map[string]struct{} path := filepath.Join(msg.dir, msg.filename) if !msg.sent { - var prob float64 - var err error - prob, words, _, _, err = f.ClassifyMessagePath(context.Background(), path) + result, err := f.ClassifyMessagePath(context.Background(), path) if err != nil { nbad++ return } if msg.ham { - if prob < a.spamThreshold { + if result.Probability < a.spamThreshold { nhamok++ } else { nhambad++ } } else { - if prob > a.spamThreshold { + if result.Probability > a.spamThreshold { nspamok++ } else { nspambad++ diff --git a/junk/filter.go b/junk/filter.go index d37e53c..af2d9bf 100644 --- a/junk/filter.go +++ b/junk/filter.go @@ -351,9 +351,9 @@ type WordScore struct { } // ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words. -func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, hams, spams []WordScore, rerr error) { +func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (Result, error) { if f.closed { - return 0, nil, nil, errClosed + return Result{}, errClosed } var hamHigh float64 = 0 @@ -391,7 +391,7 @@ func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) ( fetched := map[string]word{} if len(lookupWords) > 0 { if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil { - return 0, nil, nil, err + return Result{}, err } for w, c := range fetched { delete(expect, w) @@ -486,18 +486,34 @@ func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) ( f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam)) prob := 1 / (1 + math.Pow(math.E, eta)) - return prob, topHam, topSpam, nil + + // We want at least some positive signals, otherwise a few negative signals can + // mark incoming messages as spam too easily. If we have no negative signals, more + // messages will be classified as ham and accepted. This is fine, the user will + // classify it such, and retrain the filter. We mostly want to avoid rejecting too + // much when there isn't enough signal. + significant := f.hams >= 50 + + return Result{prob, significant, words, topHam, topSpam}, nil +} + +// Result is a successful classification, whether positive or negative. +type Result struct { + Probability float64 // Between 0 (ham) and 1 (spam). + Significant bool // If true, enough classified words are available to base decisions on. + Words map[string]struct{} + Hams, Spams []WordScore } // ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file. -func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { +func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (Result, error) { if f.closed { - return 0, nil, nil, nil, errClosed + return Result{}, errClosed } mf, err := os.Open(path) if err != nil { - return 0, nil, nil, nil, err + return Result{}, err } defer func() { err := mf.Close() @@ -505,17 +521,17 @@ func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probabil }() fi, err := mf.Stat() if err != nil { - return 0, nil, nil, nil, err + return Result{}, err } return f.ClassifyMessageReader(ctx, mf, fi.Size()) } -func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { +func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (Result, error) { m, err := message.EnsurePart(f.log.Logger, false, mf, size) if err != nil && errors.Is(err, message.ErrBadContentType) { // Invalid content-type header is a sure sign of spam. //f.log.Infox("parsing content", err) - return 1, nil, nil, nil, nil + return Result{Probability: 1, Significant: true}, nil } return f.ClassifyMessage(ctx, m) } @@ -523,15 +539,12 @@ func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size // ClassifyMessage parses the mail message in r and returns the spam probability // (between 0 and 1), along with the tokenized words found in the message, and the // ham and spam words and their scores used. -func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { - var err error - words, err = f.ParseMessage(m) +func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (Result, error) { + words, err := f.ParseMessage(m) if err != nil { - return 0, nil, nil, nil, err + return Result{}, err } - - probability, hams, spams, err = f.ClassifyWords(ctx, words) - return probability, words, hams, spams, err + return f.ClassifyWords(ctx, words) } // Train adds the words of a single message to the filter. diff --git a/junk/filter_test.go b/junk/filter_test.go index 6aea0ef..404f249 100644 --- a/junk/filter_test.go +++ b/junk/filter_test.go @@ -78,16 +78,16 @@ func TestFilter(t *testing.T) { return } - prob, _, _, _, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0])) + result, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0])) tcheck(t, err, "classify ham message") - if prob > 0.1 { - t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob) + if result.Probability > 0.1 { + t.Fatalf("trained ham file has prob %v, expected <= 0.1", result.Probability) } - prob, _, _, _, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0])) + result, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0])) tcheck(t, err, "classify spam message") - if prob < 0.9 { - t.Fatalf("trained spam file has prob %v, expected > 0.9", prob) + if result.Probability < 0.9 { + t.Fatalf("trained spam file has prob %v, expected > 0.9", result.Probability) } err = f.Close() @@ -145,18 +145,18 @@ func TestFilter(t *testing.T) { // Classify and verify. _, err = hamf.Seek(0, 0) tcheck(t, err, "seek ham message") - prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) + result, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) tcheck(t, err, "classify ham") - if prob > 0.1 { - t.Fatalf("got prob %v, expected <= 0.1", prob) + if result.Probability > 0.1 { + t.Fatalf("got prob %v, expected <= 0.1", result.Probability) } _, err = spamf.Seek(0, 0) tcheck(t, err, "seek spam message") - prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) + result, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) tcheck(t, err, "classify spam") - if prob < 0.9 { - t.Fatalf("got prob %v, expected >= 0.9", prob) + if result.Probability < 0.9 { + t.Fatalf("got prob %v, expected >= 0.9", result.Probability) } // Untrain ham & spam. @@ -185,18 +185,18 @@ func TestFilter(t *testing.T) { // Classify again, should be unknown. _, err = hamf.Seek(0, 0) tcheck(t, err, "seek ham message") - prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) + result, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) tcheck(t, err, "classify ham") - if math.Abs(prob-0.5) > 0.1 { - t.Fatalf("got prob %v, expected 0.5 +-0.1", prob) + if math.Abs(result.Probability-0.5) > 0.1 { + t.Fatalf("got prob %v, expected 0.5 +-0.1", result.Probability) } _, err = spamf.Seek(0, 0) tcheck(t, err, "seek spam message") - prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) + result, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) tcheck(t, err, "classify spam") - if math.Abs(prob-0.5) > 0.1 { - t.Fatalf("got prob %v, expected 0.5 +-0.1", prob) + if math.Abs(result.Probability-0.5) > 0.1 { + t.Fatalf("got prob %v, expected 0.5 +-0.1", result.Probability) } err = f.Close() diff --git a/smtpserver/alias_test.go b/smtpserver/alias_test.go index 166394b..dad2a84 100644 --- a/smtpserver/alias_test.go +++ b/smtpserver/alias_test.go @@ -230,7 +230,6 @@ test email `, "\n", "\r\n") ts.run(func(err error, client *smtpclient.Client) { - t.Helper() mailFrom := "mjl@mox.example" rcptTo := []string{"private@mox.example", "móx@mox.example"} if err == nil { @@ -239,11 +238,10 @@ test email } ts.smtpErr(err, nil) - ts.checkCount("Inbox", 0) // Not receiving for mjl@ due to msgfrom, and not móx@ due to rcpt to. + ts.checkCount("Inbox", 1) // Receiving once. For explicit móx@ recipient, not for mjl@ due to msgfrom, and another again for móx@ due to rcpt to. }) ts.run(func(err error, client *smtpclient.Client) { - t.Helper() mailFrom := "mjl@mox.example" rcptTo := "private@mox.example" if err == nil { @@ -251,7 +249,7 @@ test email } ts.smtpErr(err, nil) - ts.checkCount("Inbox", 1) // Only receiving for móx@mox.example, not mjl@. + ts.checkCount("Inbox", 2) // Only receiving 1 new message compared to previous, for móx@mox.example, not mjl@. }) msg = strings.ReplaceAll(`From: diff --git a/smtpserver/analyze.go b/smtpserver/analyze.go index 01030cb..1a68f93 100644 --- a/smtpserver/analyze.go +++ b/smtpserver/analyze.go @@ -528,7 +528,7 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver err := f.Close() log.Check(err, "closing junkfilter") }() - contentProb, _, hams, spams, err := f.ClassifyMessageReader(ctx, store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size) + result, err := f.ClassifyMessageReader(ctx, store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size) if err != nil { log.Errorx("testing for spam", err) addReasonText("classify message error: %v", err) @@ -587,11 +587,12 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver reason = reasonJunkContentStrict thresholdRemark = " (stricter due to recipient address not in to/cc header)" } - accept = contentProb <= threshold - junkSubjectpass = contentProb < threshold-0.2 + accept = result.Probability <= threshold || (!result.Significant && !suspiciousIPrevFail) + junkSubjectpass = result.Probability < threshold-0.2 log.Info("content analyzed", slog.Bool("accept", accept), - slog.Float64("contentprob", contentProb), + slog.Float64("contentprob", result.Probability), + slog.Bool("contentsignificant", result.Significant), slog.Bool("subjectpass", junkSubjectpass)) s := "content: " @@ -600,9 +601,12 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver } else { s += "junk" } - s += fmt.Sprintf(", spamscore %.2f, threshold %.2f%s", contentProb, threshold, thresholdRemark) + if !result.Significant { + s += " (not significant)" + } + s += fmt.Sprintf(", spamscore %.2f, threshold %.2f%s", result.Probability, threshold, thresholdRemark) s += " (ham words: " - for i, w := range hams { + for i, w := range result.Hams { if i > 0 { s += ", " } @@ -613,7 +617,7 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver s += fmt.Sprintf("%s %.3f", word, w.Score) } s += "), (spam words: " - for i, w := range spams { + for i, w := range result.Spams { if i > 0 { s += ", " } diff --git a/smtpserver/server_test.go b/smtpserver/server_test.go index 3d22cb5..103475a 100644 --- a/smtpserver/server_test.go +++ b/smtpserver/server_test.go @@ -670,6 +670,8 @@ func TestSpam(t *testing.T) { for i := 0; i < 3; i++ { nm := m tinsertmsg(t, ts.acc, "Inbox", &nm, deliverMessage) + nm = m + tinsertmsg(t, ts.acc, "mjl2", &nm, deliverMessage) } // Delivery from sender with bad reputation should fail. @@ -922,16 +924,22 @@ func TestDMARCSent(t *testing.T) { // Update DNS for an SPF pass, and DMARC pass. resolver.TXT["example.org."] = []string{"v=spf1 ip4:127.0.0.10 -all"} - // Insert spammy messages not related to the test message. + // Insert hammy & spammy messages not related to the test message. m := store.Message{ MailFrom: "remote@test.example", RcptToLocalpart: smtp.Localpart("mjl"), RcptToDomain: "mox.example", - Flags: store.Flags{Seen: true, Junk: true}, + Flags: store.Flags{Seen: true}, Size: int64(len(deliverMessage)), } - for i := 0; i < 3; i++ { + // We need at least 50 ham messages for the junk filter to become significant. We + // offset it with negative messages for mediocre score. + for i := 0; i < 50; i++ { nm := m + nm.Junk = true + tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) + nm = m + nm.Notjunk = true tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) } tretrain(t, ts.acc)