commit cb229cb6cf
Author: Mechiel Lukkien
Date:   2023-01-30 14:27:06 +01:00

1256 changed files with 491723 additions and 0 deletions

junk/bloom.go (new file)

@@ -0,0 +1,165 @@
package junk
import (
"errors"
"os"
"golang.org/x/crypto/blake2b"
)
// see https://en.wikipedia.org/wiki/Bloom_filter
var errWidth = errors.New("k * width must not exceed 256 bits, and width must not exceed 32")
var errPowerOfTwo = errors.New("data not a power of two")
// Bloom is a bloom filter.
type Bloom struct {
data []byte
k int // Number of bits we store/lookup in the bloom filter per value.
w int // Number of bits needed to address a single bit position.
modified bool
}
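// bloomWidth returns the number of bits needed to address a single bit position
// in a filter file of fileSize bytes, i.e. log2 of the number of bits in the file.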
func bloomWidth(fileSize int) int {
w := 0
for bits := uint32(fileSize * 8); bits > 1; bits >>= 1 {
w++
}
return w
}
// BloomValid returns an error if the bloom file parameters are not correct.
func BloomValid(fileSize int, k int) error {
_, err := bloomValid(fileSize, k)
return err
}
func bloomValid(fileSize, k int) (int, error) {
w := bloomWidth(fileSize)
if 1<<w != fileSize*8 {
return 0, errPowerOfTwo
}
if k*w > 256 || w > 32 {
return 0, errWidth
}
return w, nil
}
// NewBloom returns a bloom filter with given initial data.
//
// The number of bits in data must be a power of 2.
// K is the number of "hashes" (bits) to store/lookup for each value stored.
// Width is calculated as the number of bits needed to represent a single bit/hash
// position in the data.
//
// For each value stored/looked up, a hash over the value is calculated. The hash
// is split into "k" values that are "width" bits wide, each used to lookup a bit.
// K * width must not exceed 256.
func NewBloom(data []byte, k int) (*Bloom, error) {
w, err := bloomValid(len(data), k)
if err != nil {
return nil, err
}
return &Bloom{
data: data,
k: k,
w: w,
}, nil
}
// Add adds the value to the filter, setting k bits derived from its hash.
func (b *Bloom) Add(s string) {
h := hash([]byte(s), b.w)
for i := 0; i < b.k; i++ {
b.set(h.nextPos())
}
}
// Has reports whether the value is probably present in the filter. False
// positives are possible, false negatives are not.
func (b *Bloom) Has(s string) bool {
h := hash([]byte(s), b.w)
for i := 0; i < b.k; i++ {
if !b.has(h.nextPos()) {
return false
}
}
return true
}
func (b *Bloom) Bytes() []byte {
return b.data
}
func (b *Bloom) Modified() bool {
return b.modified
}
// Ones returns the number of bits set to one in the filter data.
func (b *Bloom) Ones() (n int) {
for _, d := range b.data {
for i := 0; i < 8; i++ {
if d&1 != 0 {
n++
}
d >>= 1
}
}
return n
}
// Write writes the filter data to the file at path and, on success, clears the
// modified flag.
func (b *Bloom) Write(path string) error {
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0660)
if err != nil {
return err
}
if _, err := f.Write(b.data); err != nil {
f.Close()
return err
}
if err := f.Close(); err != nil {
return err
}
b.modified = false
return nil
}
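// has reports whether the bit at position p is set. Positions address the most
// significant bit of each byte first.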
func (b *Bloom) has(p int) bool {
v := b.data[p>>3] >> (7 - (p & 7))
return v&1 != 0
}
func (b *Bloom) set(p int) {
by := p >> 3
bi := p & 0x7
var v byte = 1 << (7 - bi)
if b.data[by]&v == 0 {
b.data[by] |= v
b.modified = true
}
}
type bits struct {
width int // Number of bits for each position.
buf []byte // Remaining bytes to use for next position.
cur uint64 // Bits to read next position from. Replenished from buf.
ncur int // Number of bits available in cur. We consume the highest bits first.
}
func hash(v []byte, width int) *bits {
buf := blake2b.Sum256(v)
return &bits{width: width, buf: buf[:]}
}
// nextPos returns the next bit position.
func (b *bits) nextPos() (v int) {
if b.width > b.ncur {
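// Replenish cur from buf one byte at a time. cur holds at most 64 bits;
// the highest bits are consumed first.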
for len(b.buf) > 0 && b.ncur < 64-8 {
b.cur <<= 8
b.cur |= uint64(b.buf[0])
b.ncur += 8
b.buf = b.buf[1:]
}
}
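// Take the next "width" bits, highest remaining bits first.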
v = int((b.cur >> (b.ncur - b.width)) & ((1 << b.width) - 1))
b.ncur -= b.width
return v
}
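
For reference, a minimal sketch of how this bloom filter API fits together. The file path and parameter values below are made up for illustration:

	package main

	import (
		"log"

		"github.com/mjl-/mox/junk"
	)

	func main() {
		// 256 bytes is 2048 bits, a power of two, giving width 11. With k=5,
		// k*width = 55, within the 256 bits of the blake2b hash.
		b, err := junk.NewBloom(make([]byte, 256), 5)
		if err != nil {
			log.Fatalf("new bloom: %v", err)
		}
		b.Add("example")
		if !b.Has("example") {
			log.Fatal("added value should be present")
		}
		if b.Modified() {
			// Hypothetical path.
			if err := b.Write("example.bloom"); err != nil {
				log.Fatalf("writing bloom file: %v", err)
			}
		}
	}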

junk/bloom_test.go (new file)

@@ -0,0 +1,136 @@
package junk
import (
"fmt"
"testing"
)
func TestBloom(t *testing.T) {
if err := BloomValid(3, 10); err == nil {
t.Fatalf("missing error for invalid bloom filter size")
}
_, err := NewBloom(make([]byte, 3), 10)
if err == nil {
t.Fatalf("missing error for invalid bloom filter size")
}
b, err := NewBloom(make([]byte, 256), 5)
if err != nil {
t.Fatalf("newbloom: %s", err)
}
absent := func(v string) {
t.Helper()
if b.Has(v) {
t.Fatalf("should be absent: %q", v)
}
}
present := func(v string) {
t.Helper()
if !b.Has(v) {
t.Fatalf("should be present: %q", v)
}
}
absent("test")
if b.Modified() {
t.Fatalf("bloom filter already modified?")
}
b.Add("test")
present("test")
present("test")
words := []string{}
for i := 'a'; i <= 'z'; i++ {
words = append(words, fmt.Sprintf("%c", i))
}
for _, w := range words {
absent(w)
b.Add(w)
present(w)
}
for _, w := range words {
present(w)
}
if !b.Modified() {
t.Fatalf("bloom filter was not modified?")
}
//log.Infof("ones: %d, m %d", b.Ones(), len(b.Bytes())*8)
}
func TestBits(t *testing.T) {
b := &bits{width: 1, buf: []byte{0xff, 0xff}}
for i := 0; i < 16; i++ {
if b.nextPos() != 1 {
t.Fatalf("pos not 1")
}
}
b = &bits{width: 2, buf: []byte{0xff, 0xff}}
for i := 0; i < 8; i++ {
if b.nextPos() != 0b11 {
t.Fatalf("pos not 0b11")
}
}
b = &bits{width: 1, buf: []byte{0b10101010, 0b10101010}}
for i := 0; i < 16; i++ {
if b.nextPos() != ((i + 1) % 2) {
t.Fatalf("bad pos")
}
}
b = &bits{width: 2, buf: []byte{0b10101010, 0b10101010}}
for i := 0; i < 8; i++ {
if b.nextPos() != 0b10 {
t.Fatalf("pos not 0b10")
}
}
}
func TestSet(t *testing.T) {
b := &Bloom{
data: []byte{
0b10101010,
0b00000000,
0b11111111,
0b01010101,
},
}
for i := 0; i < 8; i++ {
v := b.has(i)
if v != (i%2 == 0) {
t.Fatalf("bad has")
}
}
for i := 8; i < 16; i++ {
if b.has(i) {
t.Fatalf("bad has")
}
}
for i := 16; i < 24; i++ {
if !b.has(i) {
t.Fatalf("bad has")
}
}
for i := 24; i < 32; i++ {
v := b.has(i)
if v != (i%2 != 0) {
t.Fatalf("bad has")
}
}
}
func TestOnes(t *testing.T) {
ones := func(b *Bloom, x int) {
t.Helper()
n := b.Ones()
if n != x {
t.Fatalf("ones: got %d, expected %d", n, x)
}
}
ones(&Bloom{data: []byte{0b10101010}}, 4)
ones(&Bloom{data: []byte{0b01010101}}, 4)
ones(&Bloom{data: []byte{0b11111111}}, 8)
ones(&Bloom{data: []byte{0b00000000}}, 0)
}

junk/filter.go (new file)

@@ -0,0 +1,726 @@
// Package junk implements a Bayesian spam filter.
//
// A message can be parsed into words. Words (or pairs or triplets) can be used
// to train the filter or to classify the message as ham or spam. Training
// records the words in the database as ham/spam. Classifying consists of
// calculating the ham/spam probability by combining the words in the message
// with their ham/spam status.
package junk
// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
// todo: perhaps: check whether anchor text in html links differs from the link url
import (
"errors"
"fmt"
"io"
"math"
"os"
"sort"
"time"
"github.com/mjl-/bstore"
"github.com/mjl-/mox/message"
"github.com/mjl-/mox/mlog"
)
var (
xlog = mlog.New("junk")
errBadContentType = errors.New("bad content-type") // sure sign of spam
errClosed = errors.New("filter is closed")
)
type word struct {
Ham uint32
Spam uint32
}
type wordscore struct {
Word string
Ham uint32
Spam uint32
}
// Params holds parameters for the filter. The n-gram options are used while
// parsing a message, both for training and for classifying; the others apply
// when classifying.
type Params struct {
Onegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
Twograms bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
Threegrams bool `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
MaxPower float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
TopWords int `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
RareWords int `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
// Filter is a Bayesian spam filter, backed by a database with per-word ham/spam
// counts and a bloom filter for quickly skipping unknown words.
type Filter struct {
Params
log *mlog.Log // For logging cid.
closed bool
modified bool // Whether any modifications are pending. Cleared by Save.
hams, spams uint32 // Message count, stored in db under word "-".
cache map[string]word // Words read from database or during training.
changed map[string]word // Words modified during training.
dbPath, bloomPath string
db *bstore.DB // Always open on a filter.
bloom *Bloom // Only opened when writing.
isNew bool // Set for new filters until their first sync to disk. For faster writing.
}
func (f *Filter) ensureBloom() error {
if f.bloom != nil {
return nil
}
var err error
f.bloom, err = openBloom(f.bloomPath)
return err
}
// Close first saves the filter if it has modifications, then closes the database
// connection and releases the bloom filter.
func (f *Filter) Close() error {
if f.closed {
return errClosed
}
var err error
if f.modified {
err = f.Save()
}
if err != nil {
f.db.Close()
} else {
err = f.db.Close()
}
*f = Filter{log: f.log, closed: true}
return err
}
// OpenFilter opens a previously created filter database and bloom filter file.
// If loadBloom is false, the bloom filter file is only validated, and is loaded
// on first use.
func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
var bloom *Bloom
if loadBloom {
var err error
bloom, err = openBloom(bloomPath)
if err != nil {
return nil, err
}
} else if fi, err := os.Stat(bloomPath); err == nil {
if err := BloomValid(int(fi.Size()), bloomK); err != nil {
return nil, fmt.Errorf("bloom: %s", err)
}
}
db, err := openDB(dbPath)
if err != nil {
return nil, fmt.Errorf("open database: %s", err)
}
f := &Filter{
Params: params,
log: log,
cache: map[string]word{},
changed: map[string]word{},
dbPath: dbPath,
bloomPath: bloomPath,
db: db,
bloom: bloom,
}
err = f.db.Read(func(tx *bstore.Tx) error {
wc := wordscore{Word: "-"}
err := tx.Get(&wc)
f.hams = wc.Ham
f.spams = wc.Spam
return err
})
if err != nil {
f.Close()
return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
}
return f, nil
}
// NewFilter creates a new filter with empty bloom filter and database files. The
// filter is marked as new until the first save, which happens automatically if
// TrainDirs is called. If the bloom and/or database files already exist, an
// error is returned.
func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
var err error
if _, err := os.Stat(bloomPath); err == nil {
return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
} else if _, err := os.Stat(dbPath); err == nil {
return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
}
bloomSizeBytes := 4 * 1024 * 1024
if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
return nil, fmt.Errorf("bloom: %s", err)
}
bf, err := os.Create(bloomPath)
if err != nil {
return nil, fmt.Errorf("creating bloom file: %w", err)
}
if err := bf.Truncate(int64(bloomSizeBytes)); err != nil {
bf.Close()
os.Remove(bloomPath)
return nil, fmt.Errorf("making empty bloom filter: %s", err)
}
bf.Close()
db, err := newDB(dbPath)
if err != nil {
os.Remove(bloomPath)
os.Remove(dbPath)
return nil, fmt.Errorf("open database: %s", err)
}
words := map[string]word{} // f.changed is set to new map after training
f := &Filter{
Params: params,
log: log,
modified: true, // Ensure ham/spam message count is added for new filter.
cache: words,
changed: words,
dbPath: dbPath,
bloomPath: bloomPath,
db: db,
isNew: true,
}
return f, nil
}
const bloomK = 10
func openBloom(path string) (*Bloom, error) {
buf, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading bloom file: %w", err)
}
return NewBloom(buf, bloomK)
}
func newDB(path string) (db *bstore.DB, rerr error) {
// Remove any existing files.
os.Remove(path)
defer func() {
if rerr != nil {
if db != nil {
db.Close()
}
db = nil
os.Remove(path)
}
}()
db, err := bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
if err != nil {
return nil, fmt.Errorf("open new database: %w", err)
}
return db, nil
}
func openDB(path string) (*bstore.DB, error) {
if _, err := os.Stat(path); err != nil {
return nil, fmt.Errorf("stat db file: %w", err)
}
return bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
}
// Save stores modifications, e.g. from training, to the database and bloom
// filter files.
func (f *Filter) Save() error {
if f.closed {
return errClosed
}
if !f.modified {
return nil
}
if f.bloom != nil && f.bloom.Modified() {
if err := f.bloom.Write(f.bloomPath); err != nil {
return fmt.Errorf("writing bloom filter: %w", err)
}
}
// We need to insert sequentially for reasonable performance.
words := make([]string, len(f.changed))
i := 0
for w := range f.changed {
words[i] = w
i++
}
sort.Slice(words, func(i, j int) bool {
return words[i] < words[j]
})
f.log.Info("inserting words in junkfilter db", mlog.Field("words", len(f.changed)))
// start := time.Now()
if f.isNew {
if err := f.db.HintAppend(true, wordscore{}); err != nil {
f.log.Errorx("hint appendonly", err)
} else {
defer f.db.HintAppend(false, wordscore{})
}
}
err := f.db.Write(func(tx *bstore.Tx) error {
update := func(w string, ham, spam uint32) error {
if f.isNew {
return tx.Insert(&wordscore{w, ham, spam})
}
wc := wordscore{w, 0, 0}
err := tx.Get(&wc)
if err == bstore.ErrAbsent {
return tx.Insert(&wordscore{w, ham, spam})
} else if err != nil {
return err
}
return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
}
if err := update("-", f.hams, f.spams); err != nil {
return fmt.Errorf("storing total ham/spam message count: %s", err)
}
for _, w := range words {
c := f.changed[w]
if err := update(w, c.Ham, c.Spam); err != nil {
return fmt.Errorf("updating ham/spam count: %s", err)
}
}
return nil
})
if err != nil {
return fmt.Errorf("updating database: %w", err)
}
f.changed = map[string]word{}
f.modified = false
f.isNew = false
// f.log.Info("wrote filter to db", mlog.Field("duration", time.Since(start)))
return nil
}
func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
sort.Slice(l, func(i, j int) bool {
return l[i] < l[j]
})
err := db.Read(func(tx *bstore.Tx) error {
for _, w := range l {
wc := wordscore{Word: w}
if err := tx.Get(&wc); err == nil {
dst[w] = word{wc.Ham, wc.Spam}
}
}
return nil
})
if err != nil {
return fmt.Errorf("fetching words: %s", err)
}
return nil
}
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
if f.closed {
return 0, 0, 0, errClosed
}
type xword struct {
Word string
R float64
}
var hamHigh float64 = 0
var spamLow float64 = 1
var topHam []xword
var topSpam []xword
// Find words that should be in the database.
lookupWords := []string{}
expect := map[string]struct{}{}
unknowns := map[string]struct{}{}
totalUnknown := 0
for w := range words {
if f.bloom != nil && !f.bloom.Has(w) {
totalUnknown++
if len(unknowns) < 50 {
unknowns[w] = struct{}{}
}
continue
}
if _, ok := f.cache[w]; ok {
continue
}
lookupWords = append(lookupWords, w)
expect[w] = struct{}{}
}
if len(unknowns) > 0 {
f.log.Debug("unknown words in bloom filter, showing max 50", mlog.Field("words", unknowns), mlog.Field("totalunknown", totalUnknown), mlog.Field("totalwords", len(words)))
}
// Fetch words from database.
fetched := map[string]word{}
if len(lookupWords) > 0 {
if err := loadWords(f.db, lookupWords, fetched); err != nil {
return 0, 0, 0, err
}
for w, c := range fetched {
delete(expect, w)
f.cache[w] = c
}
f.log.Debug("unknown words in db", mlog.Field("words", expect), mlog.Field("totalunknown", len(expect)), mlog.Field("totalwords", len(words)))
}
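// Compute a spaminess ratio for each known word, and track the most extreme
// hammy and spammy words, at most TopWords each.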
for w := range words {
c, ok := f.cache[w]
if !ok {
continue
}
var wS, wH float64
if f.spams > 0 {
wS = float64(c.Spam) / float64(f.spams)
}
if f.hams > 0 {
wH = float64(c.Ham) / float64(f.hams)
}
r := wS / (wS + wH)
if r < f.MaxPower {
r = f.MaxPower
} else if r >= 1-f.MaxPower {
r = 1 - f.MaxPower
}
if c.Ham+c.Spam <= uint32(f.RareWords) {
// Reduce the power of rare words.
r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
}
if math.Abs(0.5-r) < f.IgnoreWords {
continue
}
if r < 0.5 {
if len(topHam) >= f.TopWords && r > hamHigh {
continue
}
topHam = append(topHam, xword{w, r})
if r > hamHigh {
hamHigh = r
}
} else if r > 0.5 {
if len(topSpam) >= f.TopWords && r < spamLow {
continue
}
topSpam = append(topSpam, xword{w, r})
if r < spamLow {
spamLow = r
}
}
}
sort.Slice(topHam, func(i, j int) bool {
a, b := topHam[i], topHam[j]
if a.R == b.R {
return len(a.Word) > len(b.Word)
}
return a.R < b.R
})
sort.Slice(topSpam, func(i, j int) bool {
a, b := topSpam[i], topSpam[j]
if a.R == b.R {
return len(a.Word) > len(b.Word)
}
return a.R > b.R
})
nham = f.TopWords
if nham > len(topHam) {
nham = len(topHam)
}
nspam = f.TopWords
if nspam > len(topSpam) {
nspam = len(topSpam)
}
topHam = topHam[:nham]
topSpam = topSpam[:nspam]
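// Combine the selected word probabilities with the usual naive Bayes log-odds
// calculation: eta is the sum of ln(1-p) - ln(p) over the words, and the spam
// probability is 1/(1+e^eta).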
var eta float64
for _, x := range topHam {
eta += math.Log(1-x.R) - math.Log(x.R)
}
for _, x := range topSpam {
eta += math.Log(1-x.R) - math.Log(x.R)
}
f.log.Debug("top words", mlog.Field("hams", topHam), mlog.Field("spams", topSpam))
prob := 1 / (1 + math.Exp(eta))
return prob, len(topHam), len(topSpam), nil
}
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
func (f *Filter) ClassifyMessagePath(path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
if f.closed {
return 0, nil, 0, 0, errClosed
}
mf, err := os.Open(path)
if err != nil {
return 0, nil, 0, 0, err
}
defer mf.Close()
fi, err := mf.Stat()
if err != nil {
return 0, nil, 0, 0, err
}
return f.ClassifyMessageReader(mf, fi.Size())
}
// ClassifyMessageReader is like ClassifyMessage, but first parses the message
// from the reader. A message with an invalid content-type header is classified
// as spam directly.
func (f *Filter) ClassifyMessageReader(mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
m, err := message.EnsurePart(mf, size)
if err != nil && errors.Is(err, message.ErrBadContentType) {
// Invalid content-type header is a sure sign of spam.
//f.log.Infox("parsing content", err)
return 1, nil, 0, 0, nil
}
return f.ClassifyMessage(m)
}
// ClassifyMessage parses the mail message in r and returns the spam probability
// (between 0 and 1), along with the tokenized words found in the message, and the
// number of recognized ham and spam words.
func (f *Filter) ClassifyMessage(m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
var err error
words, err = f.ParseMessage(m)
if err != nil {
return 0, nil, 0, 0, err
}
probability, nham, nspam, err = f.ClassifyWords(words)
return probability, words, nham, nspam, err
}
// Train adds the words of a single message to the filter.
func (f *Filter) Train(ham bool, words map[string]struct{}) error {
if err := f.ensureBloom(); err != nil {
return err
}
var lwords []string
for w := range words {
if !f.bloom.Has(w) {
f.bloom.Add(w)
continue
}
if _, ok := f.cache[w]; !ok {
lwords = append(lwords, w)
}
}
if err := f.loadCache(lwords); err != nil {
return err
}
f.modified = true
if ham {
f.hams++
} else {
f.spams++
}
for w := range words {
c := f.cache[w]
if ham {
c.Ham++
} else {
c.Spam++
}
f.cache[w] = c
f.changed[w] = c
}
return nil
}
// TrainMessage parses the message in r and trains the filter with its words, as
// ham (true) or spam (false).
func (f *Filter) TrainMessage(r io.ReaderAt, size int64, ham bool) error {
p, _ := message.EnsurePart(r, size)
words, err := f.ParseMessage(p)
if err != nil {
return fmt.Errorf("parsing mail contents: %v", err)
}
return f.Train(ham, words)
}
// UntrainMessage parses the message in r and undoes an earlier training with its
// words.
func (f *Filter) UntrainMessage(r io.ReaderAt, size int64, ham bool) error {
p, _ := message.EnsurePart(r, size)
words, err := f.ParseMessage(p)
if err != nil {
return fmt.Errorf("parsing mail contents: %v", err)
}
return f.Untrain(ham, words)
}
func (f *Filter) loadCache(lwords []string) error {
if len(lwords) == 0 {
return nil
}
return loadWords(f.db, lwords, f.cache)
}
// Untrain adjusts the filter to undo a previous training of the words.
func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
if err := f.ensureBloom(); err != nil {
return err
}
// Lookup any words from the db that aren't in the cache and put them in the cache for modification.
var lwords []string
for w := range words {
if _, ok := f.cache[w]; !ok {
lwords = append(lwords, w)
}
}
if err := f.loadCache(lwords); err != nil {
return err
}
// Modify the message count.
f.modified = true
if ham {
f.hams--
} else {
f.spams--
}
// Decrease the word counts.
for w := range words {
c, ok := f.cache[w]
if !ok {
continue
}
if ham {
c.Ham--
} else {
c.Spam--
}
f.cache[w] = c
f.changed[w] = c
}
return nil
}
// TrainDir parses mail messages from files and trains the filter.
func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
if f.closed {
return 0, 0, errClosed
}
if err := f.ensureBloom(); err != nil {
return 0, 0, err
}
for _, name := range files {
p := fmt.Sprintf("%s/%s", dir, name)
valid, words, err := f.tokenizeMail(p)
if err != nil {
// f.log.Infox("tokenizing mail", err, mlog.Field("path", p))
malformed++
continue
}
if !valid {
continue
}
n++
for w := range words {
if !f.bloom.Has(w) {
f.bloom.Add(w)
continue
}
c := f.cache[w]
f.modified = true
if ham {
c.Ham++
} else {
c.Spam++
}
f.cache[w] = c
f.changed[w] = c
}
}
return
}
// TrainDirs trains and saves a filter with mail messages from different types
// of directories.
func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
if f.closed {
return errClosed
}
var err error
var start time.Time
var hamMalformed, sentMalformed, spamMalformed uint32
start = time.Now()
f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
if err != nil {
return err
}
tham := time.Since(start)
var sent uint32
start = time.Now()
if sentDir != "" {
sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
if err != nil {
return err
}
}
tsent := time.Since(start)
start = time.Now()
f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
if err != nil {
return err
}
tspam := time.Since(start)
hams := f.hams
f.hams += sent
if err := f.Save(); err != nil {
return fmt.Errorf("saving filter: %s", err)
}
dbSize := f.fileSize(f.dbPath)
bloomSize := f.fileSize(f.bloomPath)
fields := []mlog.Pair{
mlog.Field("hams", hams),
mlog.Field("hamTime", tham),
mlog.Field("hamMalformed", hamMalformed),
mlog.Field("sent", sent),
mlog.Field("sentTime", tsent),
mlog.Field("sentMalformed", sentMalformed),
mlog.Field("spams", f.spams),
mlog.Field("spamTime", tspam),
mlog.Field("spamMalformed", spamMalformed),
mlog.Field("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
mlog.Field("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
mlog.Field("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
}
xlog.Print("training done", fields...)
return nil
}
func (f *Filter) fileSize(p string) int {
fi, err := os.Stat(p)
if err != nil {
f.log.Infox("stat", err, mlog.Field("path", p))
return 0
}
return int(fi.Size())
}
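
Tying the pieces together, a minimal sketch of creating, training and using a filter. Paths, parameter values and the message file below are made up for illustration:

	package main

	import (
		"log"
		"os"

		"github.com/mjl-/mox/junk"
		"github.com/mjl-/mox/mlog"
	)

	func main() {
		params := junk.Params{Onegrams: true, Twograms: true, MaxPower: 0.1, TopWords: 10, RareWords: 2}
		// NewFilter refuses to overwrite existing files; hypothetical paths.
		f, err := junk.NewFilter(mlog.New("example"), params, "filter.db", "filter.bloom")
		if err != nil {
			log.Fatalf("new filter: %v", err)
		}
		defer f.Close() // Close saves pending modifications.

		mf, err := os.Open("spam.eml") // Hypothetical message file.
		if err != nil {
			log.Fatalf("open message: %v", err)
		}
		defer mf.Close()
		fi, err := mf.Stat()
		if err != nil {
			log.Fatalf("stat message: %v", err)
		}
		if err := f.TrainMessage(mf, fi.Size(), false); err != nil { // ham=false: spam.
			log.Fatalf("train: %v", err)
		}

		if _, err := mf.Seek(0, 0); err != nil {
			log.Fatalf("seek: %v", err)
		}
		prob, _, _, _, err := f.ClassifyMessageReader(mf, fi.Size())
		if err != nil {
			log.Fatalf("classify: %v", err)
		}
		log.Printf("spam probability: %.3f", prob)
	}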

junk/filter_test.go (new file)

@@ -0,0 +1,201 @@
package junk
import (
"fmt"
"math"
"os"
"path/filepath"
"testing"
"github.com/mjl-/mox/mlog"
)
func tcheck(t *testing.T, err error, msg string) {
t.Helper()
if err != nil {
t.Fatalf("%s: %s", msg, err)
}
}
func tlistdir(t *testing.T, name string) []string {
t.Helper()
l, err := os.ReadDir(name)
tcheck(t, err, "readdir")
names := make([]string, len(l))
for i, e := range l {
names[i] = e.Name()
}
return names
}
func TestFilter(t *testing.T) {
log := mlog.New("junk")
params := Params{
Onegrams: true,
Twograms: true,
Threegrams: false,
MaxPower: 0.1,
TopWords: 10,
IgnoreWords: 0.1,
RareWords: 1,
}
dbPath := "../testdata/junk/filter.db"
bloomPath := "../testdata/junk/filter.bloom"
os.Remove(dbPath)
os.Remove(bloomPath)
f, err := NewFilter(log, params, dbPath, bloomPath)
tcheck(t, err, "new filter")
err = f.Close()
tcheck(t, err, "close filter")
f, err = OpenFilter(log, params, dbPath, bloomPath, true)
tcheck(t, err, "open filter")
// Ensure these dirs exist. Developers should bring their own ham/spam example
// emails.
os.MkdirAll("../testdata/train/ham", 0770)
os.MkdirAll("../testdata/train/spam", 0770)
hamdir := "../testdata/train/ham"
spamdir := "../testdata/train/spam"
hamfiles := tlistdir(t, hamdir)
if len(hamfiles) > 100 {
hamfiles = hamfiles[:100]
}
spamfiles := tlistdir(t, spamdir)
if len(spamfiles) > 100 {
spamfiles = spamfiles[:100]
}
err = f.TrainDirs(hamdir, "", spamdir, hamfiles, nil, spamfiles)
tcheck(t, err, "train dirs")
if len(hamfiles) == 0 || len(spamfiles) == 0 {
fmt.Println("no ham and/or spam messages; skipping classification checks, add messages to testdata/train/ham and testdata/train/spam")
return
}
prob, _, _, _, err := f.ClassifyMessagePath(filepath.Join(hamdir, hamfiles[0]))
tcheck(t, err, "classify ham message")
if prob > 0.1 {
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
}
prob, _, _, _, err = f.ClassifyMessagePath(filepath.Join(spamdir, spamfiles[0]))
tcheck(t, err, "classify spam message")
if prob < 0.9 {
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
}
err = f.Close()
tcheck(t, err, "close filter")
// Start again with empty filter. We'll train a few messages and check they are
// classified as ham/spam. Then we untrain to see they are no longer classified.
os.Remove(dbPath)
os.Remove(bloomPath)
f, err = NewFilter(log, params, dbPath, bloomPath)
tcheck(t, err, "open filter")
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
tcheck(t, err, "open hamfile")
defer hamf.Close()
hamstat, err := hamf.Stat()
tcheck(t, err, "stat hamfile")
hamsize := hamstat.Size()
spamf, err := os.Open(filepath.Join(spamdir, spamfiles[0]))
tcheck(t, err, "open spamfile")
defer spamf.Close()
spamstat, err := spamf.Stat()
tcheck(t, err, "stat spamfile")
spamsize := spamstat.Size()
// Train each message twice, to prevent single occurrences from being ignored.
err = f.TrainMessage(hamf, hamsize, true)
tcheck(t, err, "train ham message")
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.TrainMessage(hamf, hamsize, true)
tcheck(t, err, "train ham message")
err = f.TrainMessage(spamf, spamsize, false)
tcheck(t, err, "train spam message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.TrainMessage(spamf, spamsize, false)
tcheck(t, err, "train spam message")
if !f.modified {
t.Fatalf("filter not modified after training")
}
if !f.bloom.Modified() {
t.Fatalf("bloom filter not modified after training")
}
err = f.Save()
tcheck(t, err, "save filter")
if f.modified || f.bloom.Modified() {
t.Fatalf("filter or bloom filter still modified after save")
}
// Classify and verify.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
tcheck(t, err, "classify ham")
if prob > 0.1 {
t.Fatalf("got prob %v, expected <= 0.1", prob)
}
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
tcheck(t, err, "classify spam")
if prob < 0.9 {
t.Fatalf("got prob %v, expected >= 0.9", prob)
}
// Untrain ham & spam.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.UntrainMessage(hamf, hamsize, true)
tcheck(t, err, "untrain ham message")
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
err = f.UntrainMessage(hamf, hamsize, true)
tcheck(t, err, "untrain ham message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.UntrainMessage(spamf, spamsize, false)
tcheck(t, err, "untrain spam message")
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
err = f.UntrainMessage(spamf, spamsize, false)
tcheck(t, err, "untrain spam message")
if !f.modified {
t.Fatalf("filter not modified after untraining")
}
// Classify again, should be unknown.
_, err = hamf.Seek(0, 0)
tcheck(t, err, "seek ham message")
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
tcheck(t, err, "classify ham")
if math.Abs(prob-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
}
_, err = spamf.Seek(0, 0)
tcheck(t, err, "seek spam message")
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
tcheck(t, err, "classify spam")
if math.Abs(prob-0.5) > 0.1 {
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
}
err = f.Close()
tcheck(t, err, "close filter")
}

junk/parse.go (new file)

@@ -0,0 +1,323 @@
package junk
// see https://en.wikipedia.org/wiki/Naive_Bayes_spam_filtering
// - todo: better html parsing?
// - todo: try reading text in pdf?
// - todo: try to detect language, have words per language? can be in the same dictionary. currently my dictionary is biased towards treating english as spam.
import (
"bufio"
"fmt"
"io"
"os"
"strings"
"unicode"
"golang.org/x/net/html"
"github.com/mjl-/mox/message"
)
func (f *Filter) tokenizeMail(path string) (bool, map[string]struct{}, error) {
mf, err := os.Open(path)
if err != nil {
return false, nil, err
}
defer mf.Close()
fi, err := mf.Stat()
if err != nil {
return false, nil, err
}
p, _ := message.EnsurePart(mf, fi.Size())
words, err := f.ParseMessage(p)
return true, words, err
}
// ParseMessage reads a mail and returns a map with words.
func (f *Filter) ParseMessage(p message.Part) (map[string]struct{}, error) {
metaWords := map[string]struct{}{}
textWords := map[string]struct{}{}
htmlWords := map[string]struct{}{}
hdrs, err := p.Header()
if err != nil {
return nil, fmt.Errorf("parsing headers: %v", err)
}
// Add words from the header, annotated with <field>+":".
// todo: add whether header is dkim-verified?
for k, l := range hdrs {
for _, h := range l {
switch k {
case "From", "To", "Cc", "Bcc", "Reply-To", "Subject", "Sender", "Return-Path":
// case "Subject", "To":
default:
continue
}
words := map[string]struct{}{}
f.tokenizeText(strings.NewReader(h), words)
for w := range words {
if len(w) <= 3 {
continue
}
metaWords[k+":"+w] = struct{}{}
}
}
}
if err := f.mailParse(p, metaWords, textWords, htmlWords); err != nil {
return nil, fmt.Errorf("parsing message: %w", err)
}
for w := range metaWords {
textWords[w] = struct{}{}
}
for w := range htmlWords {
textWords[w] = struct{}{}
}
return textWords, nil
}
// mailParse looks through the mail for the first text and html parts, and tokenizes their words.
func (f *Filter) mailParse(p message.Part, metaWords, textWords, htmlWords map[string]struct{}) error {
ct := p.MediaType + "/" + p.MediaSubType
if ct == "TEXT/HTML" {
err := f.tokenizeHTML(p.Reader(), metaWords, htmlWords)
// log.Printf("html parsed, words %v", htmlWords)
return err
}
if ct == "" || strings.HasPrefix(ct, "TEXT/") {
err := f.tokenizeText(p.Reader(), textWords)
// log.Printf("text parsed, words %v", textWords)
return err
}
if p.Message != nil {
// Nested message, happens for forwarding.
if err := p.SetMessageReaderAt(); err != nil {
return fmt.Errorf("setting reader on nested message: %w", err)
}
return f.mailParse(*p.Message, metaWords, textWords, htmlWords)
}
for _, sp := range p.Parts {
if err := f.mailParse(sp, metaWords, textWords, htmlWords); err != nil {
return err
}
}
return nil
}
func looksRandom(s string) bool {
// Random-looking strings, e.g. 2fvu9stm9yxhnlu: ASCII only, with long stretches of consonants.
stretch := 0
const consonants = "bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ23456789" // 0 and 1 may be used as o and l/i
stretches := 0
for _, c := range s {
if c >= 0x80 {
return false
}
if strings.ContainsRune(consonants, c) {
stretch++
continue
}
if stretch >= 6 {
stretches++
}
stretch = 0
}
if stretch >= 6 {
stretches++
}
return stretches > 0
}
func looksNumeric(s string) bool {
s = strings.TrimPrefix(s, "0x") // Hexadecimal.
var digits, hex, other, digitstretch, maxdigitstretch int
for _, c := range s {
if c >= '0' && c <= '9' {
digits++
digitstretch++
continue
} else if c >= 'a' && c <= 'f' || c >= 'A' && c <= 'F' {
hex++
} else {
other++
}
if digitstretch > maxdigitstretch {
maxdigitstretch = digitstretch
}
}
if digitstretch > maxdigitstretch {
maxdigitstretch = digitstretch
}
return maxdigitstretch >= 4 || other == 0 && maxdigitstretch >= 3
}
func (f *Filter) tokenizeText(r io.Reader, words map[string]struct{}) error {
b := &strings.Builder{}
var prev string
var prev2 string
add := func() {
defer b.Reset()
if b.Len() <= 2 {
return
}
s := b.String()
s = strings.Trim(s, "'")
var nondigit bool
for _, c := range s {
if !unicode.IsDigit(c) {
nondigit = true
break
}
}
if !(nondigit && len(s) > 2) {
return
}
if looksRandom(s) {
return
}
if looksNumeric(s) {
return
}
// todo: do something for URLs, parse them? keep their domain only?
if f.Threegrams && prev2 != "" && prev != "" {
words[prev2+" "+prev+" "+s] = struct{}{}
}
if f.Twograms && prev != "" {
words[prev+" "+s] = struct{}{}
}
if f.Onegrams {
words[s] = struct{}{}
}
prev2 = prev
prev = s
}
br := bufio.NewReader(r)
peekLetter := func() bool {
c, _, err := br.ReadRune()
br.UnreadRune()
return err == nil && unicode.IsLetter(c)
}
for {
c, _, err := br.ReadRune()
if err == io.EOF {
break
}
if err != nil {
return err
}
if !unicode.IsLetter(c) && !unicode.IsDigit(c) && (c != '\'' || b.Len() > 0 && peekLetter()) {
add()
} else {
b.WriteRune(unicode.ToLower(c))
}
}
add()
return nil
}
// tokenizeHTML parses html, and tokenizes its text into words.
func (f *Filter) tokenizeHTML(r io.Reader, meta, words map[string]struct{}) error {
htmlReader := &htmlTextReader{
t: html.NewTokenizer(r),
meta: map[string]struct{}{},
}
return f.tokenizeText(htmlReader, words)
}
type htmlTextReader struct {
t *html.Tokenizer
meta map[string]struct{}
tagStack []string
buf []byte
err error
}
func (r *htmlTextReader) Read(buf []byte) (n int, err error) {
// todo: deal with invalid html better. the tokenizer is just tokenizing, we need to fix up the nesting etc. eg, rules say some elements close certain open elements.
// todo: deal with inline elements? they shouldn't cause a word break.
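// give copies as much of nbuf into buf as fits, stashing the remainder in
// r.buf for the next Read call.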
give := func(nbuf []byte) (int, error) {
n := len(buf)
if n > len(nbuf) {
n = len(nbuf)
}
copy(buf, nbuf[:n])
nbuf = nbuf[n:]
if len(nbuf) < cap(r.buf) {
r.buf = r.buf[:len(nbuf)]
} else {
r.buf = make([]byte, len(nbuf), 3*len(nbuf)/2)
}
copy(r.buf, nbuf)
return n, nil
}
if len(r.buf) > 0 {
return give(r.buf)
}
if r.err != nil {
return 0, r.err
}
for {
switch r.t.Next() {
case html.ErrorToken:
r.err = r.t.Err()
return 0, r.err
case html.TextToken:
if len(r.tagStack) > 0 {
switch r.tagStack[len(r.tagStack)-1] {
case "script", "style", "svg":
continue
}
}
buf := r.t.Text()
if len(buf) > 0 {
return give(buf)
}
case html.StartTagToken:
tagBuf, moreAttr := r.t.TagName()
tag := string(tagBuf)
//log.Printf("tag %q %v", tag, r.tagStack)
if tag == "img" && moreAttr {
var key, val []byte
for moreAttr {
key, val, moreAttr = r.t.TagAttr()
if string(key) == "alt" && len(val) > 0 {
return give(val)
}
}
}
// Empty elements, https://developer.mozilla.org/en-US/docs/Glossary/Empty_element
switch tag {
case "area", "base", "br", "col", "embed", "hr", "img", "input", "link", "meta", "param", "source", "track", "wbr":
continue
}
r.tagStack = append(r.tagStack, tag)
case html.EndTagToken:
// log.Printf("tag pop %v", r.tagStack)
if len(r.tagStack) > 0 {
r.tagStack = r.tagStack[:len(r.tagStack)-1]
}
case html.SelfClosingTagToken:
case html.CommentToken:
case html.DoctypeToken:
}
}
}
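
A small sketch of what the tokenizer produces. tokenizeText is unexported, so this would live in package junk, e.g. as an example test; the input text is made up:

	package junk

	import (
		"fmt"
		"strings"
	)

	func Example_tokenizeText() {
		f := &Filter{Params: Params{Onegrams: true, Twograms: true}}
		words := map[string]struct{}{}
		// Tokens of at most two characters, all-digit tokens and
		// random/numeric-looking tokens are dropped; the rest is lowercased.
		err := f.tokenizeText(strings.NewReader("Hello there, order #1234 now!"), words)
		if err != nil {
			fmt.Println("tokenize:", err)
		}
		for w := range words {
			fmt.Println(w)
		}
		// Map iteration order varies; the words include "hello", "there",
		// "order" and "now", plus two-grams such as "hello there" and
		// "order now".
	}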

junk/parse_test.go (new file)

@@ -0,0 +1,33 @@
package junk
import (
"os"
"path/filepath"
"testing"
)
func FuzzParseMessage(f *testing.F) {
f.Add("")
add := func(p string) {
buf, err := os.ReadFile(p)
if err != nil {
f.Fatalf("reading file %q: %v", p, err)
}
f.Add(string(buf))
}
add("../testdata/junk/parse.eml")
add("../testdata/junk/parse2.eml")
add("../testdata/junk/parse3.eml")
dbPath := "../testdata/junk/parse.db"
bloomPath := "../testdata/junk/parse.bloom"
os.Remove(dbPath)
os.Remove(bloomPath)
params := Params{Twograms: true}
jf, err := NewFilter(xlog, params, dbPath, bloomPath)
if err != nil {
f.Fatalf("new filter: %v", err)
}
f.Fuzz(func(t *testing.T, s string) {
// The seeds added above are message file contents, so write the fuzzed
// data to a file and tokenize that, rather than treating it as a path.
p := filepath.Join(t.TempDir(), "message.eml")
if err := os.WriteFile(p, []byte(s), 0600); err != nil {
t.Fatalf("writing message file: %v", err)
}
jf.tokenizeMail(p)
})
}