update to latest bstore (with support for an index on a []string: Message.DKIMDomains), and cyclic data types (to be used for Message.Part soon); also adds a context.Context to database operations.

This commit is contained in:
Mechiel Lukkien
2023-05-22 14:40:36 +02:00
parent f6ed860ccb
commit e81930ba20
58 changed files with 1970 additions and 1035 deletions

View File

@ -11,6 +11,7 @@ package junk
// todo: perhaps: whether anchor text in links in html are different from the url
import (
"context"
"errors"
"fmt"
"io"
@ -108,7 +109,7 @@ func (f *Filter) Close() error {
return err
}
func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
func OpenFilter(ctx context.Context, log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
var bloom *Bloom
if loadBloom {
var err error
@ -122,7 +123,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
}
}
db, err := openDB(dbPath)
db, err := openDB(ctx, dbPath)
if err != nil {
return nil, fmt.Errorf("open database: %s", err)
}
@ -137,7 +138,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
db: db,
bloom: bloom,
}
err = f.db.Read(func(tx *bstore.Tx) error {
err = f.db.Read(ctx, func(tx *bstore.Tx) error {
wc := wordscore{Word: "-"}
err := tx.Get(&wc)
f.hams = wc.Ham
@ -156,7 +157,7 @@ func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloo
// filter is marked as new until the first save, will be done automatically if
// TrainDirs is called. If the bloom and/or database files exist, an error is
// returned.
func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
func NewFilter(ctx context.Context, log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
var err error
if _, err := os.Stat(bloomPath); err == nil {
return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
@ -182,7 +183,7 @@ func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter,
err = bf.Close()
log.Check(err, "closing bloomfilter file")
db, err := newDB(log, dbPath)
db, err := newDB(ctx, log, dbPath)
if err != nil {
xerr := os.Remove(bloomPath)
log.Check(xerr, "removing bloom filter file after db init error")
@ -216,7 +217,7 @@ func openBloom(path string) (*Bloom, error) {
return NewBloom(buf, bloomK)
}
func newDB(log *mlog.Log, path string) (db *bstore.DB, rerr error) {
func newDB(ctx context.Context, log *mlog.Log, path string) (db *bstore.DB, rerr error) {
// Remove any existing files.
os.Remove(path)
@ -227,18 +228,18 @@ func newDB(log *mlog.Log, path string) (db *bstore.DB, rerr error) {
}
}()
db, err := bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
db, err := bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
if err != nil {
return nil, fmt.Errorf("open new database: %w", err)
}
return db, nil
}
func openDB(path string) (*bstore.DB, error) {
func openDB(ctx context.Context, path string) (*bstore.DB, error) {
if _, err := os.Stat(path); err != nil {
return nil, fmt.Errorf("stat db file: %w", err)
}
return bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
return bstore.Open(ctx, path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
}
// Save stores modifications, e.g. from training, to the database and bloom
@ -280,7 +281,7 @@ func (f *Filter) Save() error {
}()
}
}
err := f.db.Write(func(tx *bstore.Tx) error {
err := f.db.Write(context.Background(), func(tx *bstore.Tx) error {
update := func(w string, ham, spam uint32) error {
if f.isNew {
return tx.Insert(&wordscore{w, ham, spam})
@ -318,12 +319,12 @@ func (f *Filter) Save() error {
return nil
}
func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
func loadWords(ctx context.Context, db *bstore.DB, l []string, dst map[string]word) error {
sort.Slice(l, func(i, j int) bool {
return l[i] < l[j]
})
err := db.Read(func(tx *bstore.Tx) error {
err := db.Read(ctx, func(tx *bstore.Tx) error {
for _, w := range l {
wc := wordscore{Word: w}
if err := tx.Get(&wc); err == nil {
@ -339,7 +340,7 @@ func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
}
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
func (f *Filter) ClassifyWords(ctx context.Context, words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
if f.closed {
return 0, 0, 0, errClosed
}
@ -380,7 +381,7 @@ func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64,
// Fetch words from database.
fetched := map[string]word{}
if len(lookupWords) > 0 {
if err := loadWords(f.db, lookupWords, fetched); err != nil {
if err := loadWords(ctx, f.db, lookupWords, fetched); err != nil {
return 0, 0, 0, err
}
for w, c := range fetched {
@ -477,7 +478,7 @@ func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64,
}
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
func (f *Filter) ClassifyMessagePath(path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
func (f *Filter) ClassifyMessagePath(ctx context.Context, path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
if f.closed {
return 0, nil, 0, 0, errClosed
}
@ -494,35 +495,35 @@ func (f *Filter) ClassifyMessagePath(path string) (probability float64, words ma
if err != nil {
return 0, nil, 0, 0, err
}
return f.ClassifyMessageReader(mf, fi.Size())
return f.ClassifyMessageReader(ctx, mf, fi.Size())
}
func (f *Filter) ClassifyMessageReader(mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
func (f *Filter) ClassifyMessageReader(ctx context.Context, mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
m, err := message.EnsurePart(mf, size)
if err != nil && errors.Is(err, message.ErrBadContentType) {
// Invalid content-type header is a sure sign of spam.
//f.log.Infox("parsing content", err)
return 1, nil, 0, 0, nil
}
return f.ClassifyMessage(m)
return f.ClassifyMessage(ctx, m)
}
// ClassifyMessage parses the mail message in r and returns the spam probability
// (between 0 and 1), along with the tokenized words found in the message, and the
// number of recognized ham and spam words.
func (f *Filter) ClassifyMessage(m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
func (f *Filter) ClassifyMessage(ctx context.Context, m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
var err error
words, err = f.ParseMessage(m)
if err != nil {
return 0, nil, 0, 0, err
}
probability, nham, nspam, err = f.ClassifyWords(words)
probability, nham, nspam, err = f.ClassifyWords(ctx, words)
return probability, words, nham, nspam, err
}
// Train adds the words of a single message to the filter.
func (f *Filter) Train(ham bool, words map[string]struct{}) error {
func (f *Filter) Train(ctx context.Context, ham bool, words map[string]struct{}) error {
if err := f.ensureBloom(); err != nil {
return err
}
@ -539,7 +540,7 @@ func (f *Filter) Train(ham bool, words map[string]struct{}) error {
}
}
if err := f.loadCache(lwords); err != nil {
if err := f.loadCache(ctx, lwords); err != nil {
return err
}
@ -563,34 +564,34 @@ func (f *Filter) Train(ham bool, words map[string]struct{}) error {
return nil
}
func (f *Filter) TrainMessage(r io.ReaderAt, size int64, ham bool) error {
func (f *Filter) TrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
p, _ := message.EnsurePart(r, size)
words, err := f.ParseMessage(p)
if err != nil {
return fmt.Errorf("parsing mail contents: %v", err)
}
return f.Train(ham, words)
return f.Train(ctx, ham, words)
}
func (f *Filter) UntrainMessage(r io.ReaderAt, size int64, ham bool) error {
func (f *Filter) UntrainMessage(ctx context.Context, r io.ReaderAt, size int64, ham bool) error {
p, _ := message.EnsurePart(r, size)
words, err := f.ParseMessage(p)
if err != nil {
return fmt.Errorf("parsing mail contents: %v", err)
}
return f.Untrain(ham, words)
return f.Untrain(ctx, ham, words)
}
func (f *Filter) loadCache(lwords []string) error {
func (f *Filter) loadCache(ctx context.Context, lwords []string) error {
if len(lwords) == 0 {
return nil
}
return loadWords(f.db, lwords, f.cache)
return loadWords(ctx, f.db, lwords, f.cache)
}
// Untrain adjusts the filter to undo a previous training of the words.
func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
func (f *Filter) Untrain(ctx context.Context, ham bool, words map[string]struct{}) error {
if err := f.ensureBloom(); err != nil {
return err
}
@ -602,7 +603,7 @@ func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
lwords = append(lwords, w)
}
}
if err := f.loadCache(lwords); err != nil {
if err := f.loadCache(ctx, lwords); err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package junk
import (
"context"
"fmt"
"math"
"os"
@ -10,6 +11,8 @@ import (
"github.com/mjl-/mox/mlog"
)
var ctxbg = context.Background()
func tcheck(t *testing.T, err error, msg string) {
t.Helper()
if err != nil {
@ -43,12 +46,12 @@ func TestFilter(t *testing.T) {
bloomPath := "../testdata/junk/filter.bloom"
os.Remove(dbPath)
os.Remove(bloomPath)
f, err := NewFilter(log, params, dbPath, bloomPath)
f, err := NewFilter(ctxbg, log, params, dbPath, bloomPath)
tcheck(t, err, "new filter")
err = f.Close()
tcheck(t, err, "close filter")
f, err = OpenFilter(log, params, dbPath, bloomPath, true)
f, err = OpenFilter(ctxbg, log, params, dbPath, bloomPath, true)
tcheck(t, err, "open filter")
// Ensure these dirs exist. Developers should bring their own ham/spam example
@ -75,13 +78,13 @@ func TestFilter(t *testing.T) {
return
}
prob, _, _, _, err := f.ClassifyMessagePath(filepath.Join(hamdir, hamfiles[0]))
prob, _, _, _, err := f.ClassifyMessagePath(ctxbg, filepath.Join(hamdir, hamfiles[0]))
tcheck(t, err, "classify ham message")
if prob > 0.1 {
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
}
prob, _, _, _, err = f.ClassifyMessagePath(filepath.Join(spamdir, spamfiles[0]))
prob, _, _, _, err = f.ClassifyMessagePath(ctxbg, filepath.Join(spamdir, spamfiles[0]))
tcheck(t, err, "classify spam message")
if prob < 0.9 {
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
@ -94,7 +97,7 @@ func TestFilter(t *testing.T) {
// classified as ham/spam. Then we untrain to see they are no longer classified.
os.Remove(dbPath)
os.Remove(bloomPath)
f, err = NewFilter(log, params, dbPath, bloomPath)
f, err = NewFilter(ctxbg, log, params, dbPath, bloomPath)
tcheck(t, err, "open filter")
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
@ -112,18 +115,18 @@ func TestFilter(t *testing.T) {
spamsize := spamstat.Size()
// Train each message twice, to prevent single occurrences from being ignored.
err = f.TrainMessage(hamf, hamsize, true)
err = f.TrainMessage(ctxbg, hamf, hamsize, true)
tcheck(t, err, "train ham message")
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.TrainMessage(hamf, hamsize, true)
err = f.TrainMessage(ctxbg, hamf, hamsize, true)
tcheck(t, err, "train ham message")
err = f.TrainMessage(spamf, spamsize, false)
err = f.TrainMessage(ctxbg, spamf, spamsize, false)
tcheck(t, err, "train spam message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.TrainMessage(spamf, spamsize, true)
err = f.TrainMessage(ctxbg, spamf, spamsize, true)
tcheck(t, err, "train spam message")
if !f.modified {
@ -142,7 +145,7 @@ func TestFilter(t *testing.T) {
// Classify and verify.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
tcheck(t, err, "classify ham")
if prob > 0.1 {
t.Fatalf("got prob %v, expected <= 0.1", prob)
@ -150,7 +153,7 @@ func TestFilter(t *testing.T) {
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
tcheck(t, err, "classify spam")
if prob < 0.9 {
t.Fatalf("got prob %v, expected >= 0.9", prob)
@ -159,20 +162,20 @@ func TestFilter(t *testing.T) {
// Untrain ham & spam.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.UntrainMessage(hamf, hamsize, true)
err = f.UntrainMessage(ctxbg, hamf, hamsize, true)
tcheck(t, err, "untrain ham message")
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.UntrainMessage(hamf, spamsize, true)
err = f.UntrainMessage(ctxbg, hamf, spamsize, true)
tcheck(t, err, "untrain ham message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.UntrainMessage(spamf, spamsize, true)
err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
tcheck(t, err, "untrain spam message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.UntrainMessage(spamf, spamsize, true)
err = f.UntrainMessage(ctxbg, spamf, spamsize, true)
tcheck(t, err, "untrain spam message")
if !f.modified {
@ -182,7 +185,7 @@ func TestFilter(t *testing.T) {
// Classify again, should be unknown.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, hamf, hamsize)
tcheck(t, err, "classify ham")
if math.Abs(prob-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
@ -190,7 +193,7 @@ func TestFilter(t *testing.T) {
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
prob, _, _, _, err = f.ClassifyMessageReader(ctxbg, spamf, spamsize)
tcheck(t, err, "classify spam")
if math.Abs(prob-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)

View File

@ -23,7 +23,7 @@ func FuzzParseMessage(f *testing.F) {
os.Remove(dbPath)
os.Remove(bloomPath)
params := Params{Twograms: true}
jf, err := NewFilter(xlog, params, dbPath, bloomPath)
jf, err := NewFilter(ctxbg, xlog, params, dbPath, bloomPath)
if err != nil {
f.Fatalf("new filter: %v", err)
}