do not use results from junk filter if we have less than 50 positive classifications to base the decision on

useful for new accounts. we don't want to start rejecting incoming messages for
having a score near 0.5 because of too little training material. we err on the
side of allowing messages in. the user will mark them as junk, training the
filter. once enough non-junk has come in, we'll start the actual filtering.

for issue #64 by x8x, and i've also seen this concern on matrix
This commit is contained in:
Mechiel Lukkien 2025-01-23 22:55:50 +01:00
parent 8fac9f862b
commit 6aa2139a54
No known key found for this signature in database
6 changed files with 93 additions and 68 deletions

40
junk.go
View File

@ -129,10 +129,14 @@ func cmdJunkCheck(c *cmd) {
} }
}() }()
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), args[0]) result, err := f.ClassifyMessagePath(context.Background(), args[0])
xcheckf(err, "testing mail") xcheckf(err, "testing mail")
fmt.Printf("%.6f\n", prob) sig := "significant"
if !result.Significant {
sig = "not significant"
}
fmt.Printf("%.6f, %s\n", result.Probability, sig)
} }
func cmdJunkTest(c *cmd) { func cmdJunkTest(c *cmd) {
@ -159,21 +163,21 @@ func cmdJunkTest(c *cmd) {
xcheckf(err, "readdir %q", dir) xcheckf(err, "readdir %q", dir)
for _, fi := range files { for _, fi := range files {
path := filepath.Join(dir, fi.Name()) path := filepath.Join(dir, fi.Name())
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path) result, err := f.ClassifyMessagePath(context.Background(), path)
if err != nil { if err != nil {
log.Printf("classify message %q: %s", path, err) log.Printf("classify message %q: %s", path, err)
continue continue
} }
if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold { if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold {
ok++ ok++
} else { } else {
bad++ bad++
} }
if ham && prob > a.spamThreshold { if ham && result.Probability > a.spamThreshold {
fmt.Printf("ham %q: %.4f\n", path, prob) fmt.Printf("ham %q: %.4f\n", path, result.Probability)
} }
if !ham && prob < a.spamThreshold { if !ham && result.Probability < a.spamThreshold {
fmt.Printf("spam %q: %.4f\n", path, prob) fmt.Printf("spam %q: %.4f\n", path, result.Probability)
} }
} }
return ok, bad return ok, bad
@ -251,22 +255,22 @@ messages are shuffled, with optional random seed.`
testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) { testDir := func(dir string, files []string, ham bool) (ok, bad, malformed int) {
for _, name := range files { for _, name := range files {
path := filepath.Join(dir, name) path := filepath.Join(dir, name)
prob, _, _, _, err := f.ClassifyMessagePath(context.Background(), path) result, err := f.ClassifyMessagePath(context.Background(), path)
if err != nil { if err != nil {
// log.Infof("%s: %s", path, err) // log.Infof("%s: %s", path, err)
malformed++ malformed++
continue continue
} }
if ham && prob < a.spamThreshold || !ham && prob > a.spamThreshold { if ham && result.Probability < a.spamThreshold || !ham && result.Probability > a.spamThreshold {
ok++ ok++
} else { } else {
bad++ bad++
} }
if ham && prob > a.spamThreshold { if ham && result.Probability > a.spamThreshold {
fmt.Printf("ham %q: %.4f\n", path, prob) fmt.Printf("ham %q: %.4f\n", path, result.Probability)
} }
if !ham && prob < a.spamThreshold { if !ham && result.Probability < a.spamThreshold {
fmt.Printf("spam %q: %.4f\n", path, prob) fmt.Printf("spam %q: %.4f\n", path, result.Probability)
} }
} }
return return
@ -367,21 +371,19 @@ func cmdJunkPlay(c *cmd) {
var words map[string]struct{} var words map[string]struct{}
path := filepath.Join(msg.dir, msg.filename) path := filepath.Join(msg.dir, msg.filename)
if !msg.sent { if !msg.sent {
var prob float64 result, err := f.ClassifyMessagePath(context.Background(), path)
var err error
prob, words, _, _, err = f.ClassifyMessagePath(context.Background(), path)
if err != nil { if err != nil {
nbad++ nbad++
return return
} }
if msg.ham { if msg.ham {
if prob < a.spamThreshold { if result.Probability < a.spamThreshold {
nhamok++ nhamok++
} else { } else {
nhambad++ nhambad++
} }
} else { } else {
if prob > a.spamThreshold { if result.Probability > a.spamThreshold {
nspamok++ nspamok++
} else { } else {
nspambad++ nspambad++

View File

@ -351,9 +351,9 @@ type WordScore struct {
} }
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words. // ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, hams, spams []WordScore, rerr error) { func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (Result, error) {
if f.closed { if f.closed {
return 0, nil, nil, errClosed return Result{}, errClosed
} }
var hamHigh float64 = 0 var hamHigh float64 = 0
@ -391,7 +391,7 @@ func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (
fetched := map[string]word{} fetched := map[string]word{}
if len(lookupWords) > 0 { if len(lookupWords) > 0 {
if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil { if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
return 0, nil, nil, err return Result{}, err
} }
for w, c := range fetched { for w, c := range fetched {
delete(expect, w) delete(expect, w)
@ -486,18 +486,34 @@ func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (
f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam)) f.log.Debug("top words", slog.Any("hams", topHam), slog.Any("spams", topSpam))
prob := 1 / (1 + math.Pow(math.E, eta)) prob := 1 / (1 + math.Pow(math.E, eta))
return prob, topHam, topSpam, nil
// We want at least some positive signals, otherwise a few negative signals can
// mark incoming messages as spam too easily. If we have no negative signals, more
// messages will be classified as ham and accepted. This is fine, the user will
// classify it such, and retrain the filter. We mostly want to avoid rejecting too
// much when there isn't enough signal.
significant := f.hams >= 50
return Result{prob, significant, words, topHam, topSpam}, nil
}
// Result is a successful classification, whether positive or negative.
type Result struct {
Probability float64 // Between 0 (ham) and 1 (spam).
Significant bool // If true, enough classified words are available to base decisions on.
Words map[string]struct{}
Hams, Spams []WordScore
} }
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file. // ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (Result, error) {
if f.closed { if f.closed {
return 0, nil, nil, nil, errClosed return Result{}, errClosed
} }
mf, err := os.Open(path) mf, err := os.Open(path)
if err != nil { if err != nil {
return 0, nil, nil, nil, err return Result{}, err
} }
defer func() { defer func() {
err := mf.Close() err := mf.Close()
@ -505,17 +521,17 @@ func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probabil
}() }()
fi, err := mf.Stat() fi, err := mf.Stat()
if err != nil { if err != nil {
return 0, nil, nil, nil, err return Result{}, err
} }
return f.ClassifyMessageReader(ctx, mf, fi.Size()) return f.ClassifyMessageReader(ctx, mf, fi.Size())
} }
func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (Result, error) {
m, err := message.EnsurePart(f.log.Logger, false, mf, size) m, err := message.EnsurePart(f.log.Logger, false, mf, size)
if err != nil && errors.Is(err, message.ErrBadContentType) { if err != nil && errors.Is(err, message.ErrBadContentType) {
// Invalid content-type header is a sure sign of spam. // Invalid content-type header is a sure sign of spam.
//f.log.Infox("parsing content", err) //f.log.Infox("parsing content", err)
return 1, nil, nil, nil, nil return Result{Probability: 1, Significant: true}, nil
} }
return f.ClassifyMessage(ctx, m) return f.ClassifyMessage(ctx, m)
} }
@ -523,15 +539,12 @@ func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size
// ClassifyMessage parses the mail message in r and returns the spam probability // ClassifyMessage parses the mail message in r and returns the spam probability
// (between 0 and 1), along with the tokenized words found in the message, and the // (between 0 and 1), along with the tokenized words found in the message, and the
// ham and spam words and their scores used. // ham and spam words and their scores used.
func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, hams, spams []WordScore, rerr error) { func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (Result, error) {
var err error words, err := f.ParseMessage(m)
words, err = f.ParseMessage(m)
if err != nil { if err != nil {
return 0, nil, nil, nil, err return Result{}, err
} }
return f.ClassifyWords(ctx, words)
probability, hams, spams, err = f.ClassifyWords(ctx, words)
return probability, words, hams, spams, err
} }
// Train adds the words of a single message to the filter. // Train adds the words of a single message to the filter.

View File

@ -78,16 +78,16 @@ func TestFilter(t *testing.T) {
return return
} }
prob, _, _, _, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0])) result, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0]))
tcheck(t, err, "classify ham message") tcheck(t, err, "classify ham message")
if prob > 0.1 { if result.Probability > 0.1 {
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob) t.Fatalf("trained ham file has prob %v, expected <= 0.1", result.Probability)
} }
prob, _, _, _, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0])) result, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0]))
tcheck(t, err, "classify spam message") tcheck(t, err, "classify spam message")
if prob < 0.9 { if result.Probability < 0.9 {
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob) t.Fatalf("trained spam file has prob %v, expected > 0.9", result.Probability)
} }
err = f.Close() err = f.Close()
@ -145,18 +145,18 @@ func TestFilter(t *testing.T) {
// Classify and verify. // Classify and verify.
_, err = hamf.Seek(0, 0) _, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message") tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) result, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
tcheck(t, err, "classify ham") tcheck(t, err, "classify ham")
if prob > 0.1 { if result.Probability > 0.1 {
t.Fatalf("got prob %v, expected <= 0.1", prob) t.Fatalf("got prob %v, expected <= 0.1", result.Probability)
} }
_, err = spamf.Seek(0, 0) _, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message") tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) result, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
tcheck(t, err, "classify spam") tcheck(t, err, "classify spam")
if prob < 0.9 { if result.Probability < 0.9 {
t.Fatalf("got prob %v, expected >= 0.9", prob) t.Fatalf("got prob %v, expected >= 0.9", result.Probability)
} }
// Untrain ham & spam. // Untrain ham & spam.
@ -185,18 +185,18 @@ func TestFilter(t *testing.T) {
// Classify again, should be unknown. // Classify again, should be unknown.
_, err = hamf.Seek(0, 0) _, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message") tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize) result, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
tcheck(t, err, "classify ham") tcheck(t, err, "classify ham")
if math.Abs(prob-0.5) > 0.1 { if math.Abs(result.Probability-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob) t.Fatalf("got prob %v, expected 0.5 +-0.1", result.Probability)
} }
_, err = spamf.Seek(0, 0) _, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message") tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize) result, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
tcheck(t, err, "classify spam") tcheck(t, err, "classify spam")
if math.Abs(prob-0.5) > 0.1 { if math.Abs(result.Probability-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob) t.Fatalf("got prob %v, expected 0.5 +-0.1", result.Probability)
} }
err = f.Close() err = f.Close()

View File

@ -230,7 +230,6 @@ test email
`, "\n", "\r\n") `, "\n", "\r\n")
ts.run(func(err error, client *smtpclient.Client) { ts.run(func(err error, client *smtpclient.Client) {
t.Helper()
mailFrom := "mjl@mox.example" mailFrom := "mjl@mox.example"
rcptTo := []string{"private@mox.example", "móx@mox.example"} rcptTo := []string{"private@mox.example", "móx@mox.example"}
if err == nil { if err == nil {
@ -239,11 +238,10 @@ test email
} }
ts.smtpErr(err, nil) ts.smtpErr(err, nil)
ts.checkCount("Inbox", 0) // Not receiving for mjl@ due to msgfrom, and not móx@ due to rcpt to. ts.checkCount("Inbox", 1) // Receiving once. For explicit móx@ recipient, not for mjl@ due to msgfrom, and another again for móx@ due to rcpt to.
}) })
ts.run(func(err error, client *smtpclient.Client) { ts.run(func(err error, client *smtpclient.Client) {
t.Helper()
mailFrom := "mjl@mox.example" mailFrom := "mjl@mox.example"
rcptTo := "private@mox.example" rcptTo := "private@mox.example"
if err == nil { if err == nil {
@ -251,7 +249,7 @@ test email
} }
ts.smtpErr(err, nil) ts.smtpErr(err, nil)
ts.checkCount("Inbox", 1) // Only receiving for móx@mox.example, not mjl@. ts.checkCount("Inbox", 2) // Only receiving 1 new message compared to previous, for móx@mox.example, not mjl@.
}) })
msg = strings.ReplaceAll(`From: <private@mox.example> msg = strings.ReplaceAll(`From: <private@mox.example>

View File

@ -528,7 +528,7 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver
err := f.Close() err := f.Close()
log.Check(err, "closing junkfilter") log.Check(err, "closing junkfilter")
}() }()
contentProb, _, hams, spams, err := f.ClassifyMessageReader(ctx, store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size) result, err := f.ClassifyMessageReader(ctx, store.FileMsgReader(d.m.MsgPrefix, d.dataFile), d.m.Size)
if err != nil { if err != nil {
log.Errorx("testing for spam", err) log.Errorx("testing for spam", err)
addReasonText("classify message error: %v", err) addReasonText("classify message error: %v", err)
@ -587,11 +587,12 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver
reason = reasonJunkContentStrict reason = reasonJunkContentStrict
thresholdRemark = " (stricter due to recipient address not in to/cc header)" thresholdRemark = " (stricter due to recipient address not in to/cc header)"
} }
accept = contentProb <= threshold accept = result.Probability <= threshold || (!result.Significant && !suspiciousIPrevFail)
junkSubjectpass = contentProb < threshold-0.2 junkSubjectpass = result.Probability < threshold-0.2
log.Info("content analyzed", log.Info("content analyzed",
slog.Bool("accept", accept), slog.Bool("accept", accept),
slog.Float64("contentprob", contentProb), slog.Float64("contentprob", result.Probability),
slog.Bool("contentsignificant", result.Significant),
slog.Bool("subjectpass", junkSubjectpass)) slog.Bool("subjectpass", junkSubjectpass))
s := "content: " s := "content: "
@ -600,9 +601,12 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver
} else { } else {
s += "junk" s += "junk"
} }
s += fmt.Sprintf(", spamscore %.2f, threshold %.2f%s", contentProb, threshold, thresholdRemark) if !result.Significant {
s += " (not significant)"
}
s += fmt.Sprintf(", spamscore %.2f, threshold %.2f%s", result.Probability, threshold, thresholdRemark)
s += " (ham words: " s += " (ham words: "
for i, w := range hams { for i, w := range result.Hams {
if i > 0 { if i > 0 {
s += ", " s += ", "
} }
@ -613,7 +617,7 @@ func analyze(ctx context.Context, log mlog.Log, resolver dns.Resolver, d deliver
s += fmt.Sprintf("%s %.3f", word, w.Score) s += fmt.Sprintf("%s %.3f", word, w.Score)
} }
s += "), (spam words: " s += "), (spam words: "
for i, w := range spams { for i, w := range result.Spams {
if i > 0 { if i > 0 {
s += ", " s += ", "
} }

View File

@ -670,6 +670,8 @@ func TestSpam(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
nm := m nm := m
tinsertmsg(t, ts.acc, "Inbox", &nm, deliverMessage) tinsertmsg(t, ts.acc, "Inbox", &nm, deliverMessage)
nm = m
tinsertmsg(t, ts.acc, "mjl2", &nm, deliverMessage)
} }
// Delivery from sender with bad reputation should fail. // Delivery from sender with bad reputation should fail.
@ -922,16 +924,22 @@ func TestDMARCSent(t *testing.T) {
// Update DNS for an SPF pass, and DMARC pass. // Update DNS for an SPF pass, and DMARC pass.
resolver.TXT["example.org."] = []string{"v=spf1 ip4:127.0.0.10 -all"} resolver.TXT["example.org."] = []string{"v=spf1 ip4:127.0.0.10 -all"}
// Insert spammy messages not related to the test message. // Insert hammy & spammy messages not related to the test message.
m := store.Message{ m := store.Message{
MailFrom: "remote@test.example", MailFrom: "remote@test.example",
RcptToLocalpart: smtp.Localpart("mjl"), RcptToLocalpart: smtp.Localpart("mjl"),
RcptToDomain: "mox.example", RcptToDomain: "mox.example",
Flags: store.Flags{Seen: true, Junk: true}, Flags: store.Flags{Seen: true},
Size: int64(len(deliverMessage)), Size: int64(len(deliverMessage)),
} }
for i := 0; i < 3; i++ { // We need at least 50 ham messages for the junk filter to become significant. We
// offset it with negative messages for mediocre score.
for i := 0; i < 50; i++ {
nm := m nm := m
nm.Junk = true
tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage)
nm = m
nm.Notjunk = true
tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage) tinsertmsg(t, ts.acc, "Archive", &nm, deliverMessage)
} }
tretrain(t, ts.acc) tretrain(t, ts.acc)