mirror of
https://github.com/mjl-/mox.git
synced 2025-07-12 18:24:35 +03:00
mox!
This commit is contained in:
165
junk/bloom.go
Normal file
165
junk/bloom.go
Normal file
@ -0,0 +1,165 @@
|
||||
package junk
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/crypto/blake2b"
|
||||
)
|
||||
|
||||
// see https://en.wikipedia.org/wiki/Bloom_filter
|
||||
|
||||
// errWidth is returned when the combined k*width exceeds 256 bits (the digest
// size) or the width exceeds 32 bits.
var errWidth = errors.New("k and width wider than 256 bits and width not more than 32")

// errPowerOfTwo is returned when the filter data size in bits is not a power of two.
var errPowerOfTwo = errors.New("data not a power of two")
|
||||
|
||||
// Bloom is a bloom filter.
type Bloom struct {
	data     []byte
	k        int  // Number of bits we store/lookup in the bloom filter per value.
	w        int  // Number of bits needed to address a single bit position.
	modified bool // Whether any bit was set since creation or the last Write.
}
|
||||
|
||||
// bloomWidth returns the number of bits needed to address a single bit
// position in a filter of fileSize bytes, i.e. floor(log2(fileSize*8)).
func bloomWidth(fileSize int) int {
	w := 0
	// Count in uint64: the original uint32(fileSize*8) would truncate for
	// file sizes of 512MB and up, yielding a wrong (too small) width.
	for bits := uint64(fileSize) * 8; bits > 1; bits >>= 1 {
		w++
	}
	return w
}
|
||||
|
||||
// BloomValid returns an error if the bloom file parameters are not correct.
|
||||
func BloomValid(fileSize int, k int) error {
|
||||
_, err := bloomValid(fileSize, k)
|
||||
return err
|
||||
}
|
||||
|
||||
func bloomValid(fileSize, k int) (int, error) {
|
||||
w := bloomWidth(fileSize)
|
||||
if 1<<w != fileSize*8 {
|
||||
return 0, errPowerOfTwo
|
||||
}
|
||||
if k*w > 256 || w > 32 {
|
||||
return 0, errWidth
|
||||
}
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// NewBloom returns a bloom filter with given initial data.
//
// The number of bits in data must be a power of 2.
// K is the number of "hashes" (bits) to store/lookup for each value stored.
// Width is calculated as the number of bits needed to represent a single bit/hash
// position in the data.
//
// For each value stored/looked up, a hash over the value is calculated. The hash
// is split into "k" values that are "width" bits wide, each used to lookup a bit.
// K * width must not exceed 256.
func NewBloom(data []byte, k int) (*Bloom, error) {
	w, err := bloomValid(len(data), k)
	if err != nil {
		return nil, err
	}

	return &Bloom{
		data: data,
		k:    k,
		w:    w,
	}, nil
}
|
||||
|
||||
func (b *Bloom) Add(s string) {
|
||||
h := hash([]byte(s), b.w)
|
||||
for i := 0; i < b.k; i++ {
|
||||
b.set(h.nextPos())
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bloom) Has(s string) bool {
|
||||
h := hash([]byte(s), b.w)
|
||||
for i := 0; i < b.k; i++ {
|
||||
if !b.has(h.nextPos()) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Bytes returns the underlying filter data. The slice is not a copy; callers
// must not modify it.
func (b *Bloom) Bytes() []byte {
	return b.data
}
|
||||
|
||||
// Modified returns whether the filter changed since creation or the last
// successful Write.
func (b *Bloom) Modified() bool {
	return b.modified
}
|
||||
|
||||
// Ones returns the number of ones.
|
||||
func (b *Bloom) Ones() (n int) {
|
||||
for _, d := range b.data {
|
||||
for i := 0; i < 8; i++ {
|
||||
if d&1 != 0 {
|
||||
n++
|
||||
}
|
||||
d >>= 1
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (b *Bloom) Write(path string) error {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0660)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := f.Write(b.data); err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
b.modified = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bloom) has(p int) bool {
|
||||
v := b.data[p>>3] >> (7 - (p & 7))
|
||||
return v&1 != 0
|
||||
}
|
||||
|
||||
func (b *Bloom) set(p int) {
|
||||
by := p >> 3
|
||||
bi := p & 0x7
|
||||
var v byte = 1 << (7 - bi)
|
||||
if b.data[by]&v == 0 {
|
||||
b.data[by] |= v
|
||||
b.modified = true
|
||||
}
|
||||
}
|
||||
|
||||
// bits yields successive fixed-width bit positions from a hash digest.
type bits struct {
	width int    // Number of bits for each position.
	buf   []byte // Remaining bytes to use for next position.
	cur   uint64 // Bits to read next position from. Replenished from buf.
	ncur  int    // Number of bits available in cur. We consume the highest bits first.
}
|
||||
|
||||
// hash returns a bit-position reader over the BLAKE2b-256 digest of v, with
// each position width bits wide.
func hash(v []byte, width int) *bits {
	buf := blake2b.Sum256(v)
	return &bits{width: width, buf: buf[:]}
}
|
||||
|
||||
// nextPos returns the next bit position.
func (b *bits) nextPos() (v int) {
	// Replenish cur from buf when fewer than width bits remain. Callers must
	// not consume more bits than buf holds in total; for Bloom usage this is
	// ensured by bloomValid's k*width <= 256 check against the 256-bit digest.
	if b.width > b.ncur {
		for len(b.buf) > 0 && b.ncur < 64-8 {
			b.cur <<= 8
			b.cur |= uint64(b.buf[0])
			b.ncur += 8
			b.buf = b.buf[1:]
		}
	}
	// Take the highest width bits not yet consumed.
	v = int((b.cur >> (b.ncur - b.width)) & ((1 << b.width) - 1))
	b.ncur -= b.width
	return v
}
|
136
junk/bloom_test.go
Normal file
136
junk/bloom_test.go
Normal file
@ -0,0 +1,136 @@
|
||||
package junk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestBloom exercises parameter validation and add/lookup behavior of the
// bloom filter.
func TestBloom(t *testing.T) {
	// 3 bytes is not a power-of-two number of bits.
	if err := BloomValid(3, 10); err == nil {
		t.Fatalf("missing error for invalid bloom filter size")
	}

	_, err := NewBloom(make([]byte, 3), 10)
	if err == nil {
		t.Fatalf("missing error for invalid bloom filter size")
	}

	b, err := NewBloom(make([]byte, 256), 5)
	if err != nil {
		t.Fatalf("newbloom: %s", err)
	}

	// absent fails the test when v is reported as present.
	absent := func(v string) {
		t.Helper()
		if b.Has(v) {
			t.Fatalf("should be absent: %q", v)
		}
	}

	// present fails the test when v is not found.
	present := func(v string) {
		t.Helper()
		if !b.Has(v) {
			t.Fatalf("should be present: %q", v)
		}
	}

	absent("test")
	if b.Modified() {
		t.Fatalf("bloom filter already modified?")
	}
	b.Add("test")
	present("test")
	present("test")
	// Add all single-letter words and verify lookups before and after.
	words := []string{}
	for i := 'a'; i <= 'z'; i++ {
		words = append(words, fmt.Sprintf("%c", i))
	}
	for _, w := range words {
		absent(w)
		b.Add(w)
		present(w)
	}
	for _, w := range words {
		present(w)
	}
	if !b.Modified() {
		t.Fatalf("bloom filter was not modified?")
	}

	//log.Infof("ones: %d, m %d", b.Ones(), len(b.Bytes())*8)
}
|
||||
|
||||
// TestBits checks that nextPos consumes the hash bytes high-bit first in
// width-sized chunks.
func TestBits(t *testing.T) {
	// All-ones input: every position is all ones regardless of width.
	b := &bits{width: 1, buf: []byte{0xff, 0xff}}
	for i := 0; i < 16; i++ {
		if b.nextPos() != 1 {
			t.Fatalf("pos not 1")
		}
	}
	b = &bits{width: 2, buf: []byte{0xff, 0xff}}
	for i := 0; i < 8; i++ {
		if b.nextPos() != 0b11 {
			t.Fatalf("pos not 0b11")
		}
	}

	// Alternating bits: width-1 positions alternate 1,0,...; width-2 positions
	// are always 0b10.
	b = &bits{width: 1, buf: []byte{0b10101010, 0b10101010}}
	for i := 0; i < 16; i++ {
		if b.nextPos() != ((i + 1) % 2) {
			t.Fatalf("bad pos")
		}
	}
	b = &bits{width: 2, buf: []byte{0b10101010, 0b10101010}}
	for i := 0; i < 8; i++ {
		if b.nextPos() != 0b10 {
			t.Fatalf("pos not 0b10")
		}
	}
}
|
||||
|
||||
// TestSet checks bit addressing in has: bit position 0 is the most
// significant bit of the first byte.
func TestSet(t *testing.T) {
	b := &Bloom{
		data: []byte{
			0b10101010,
			0b00000000,
			0b11111111,
			0b01010101,
		},
	}
	// First byte: even positions set.
	for i := 0; i < 8; i++ {
		v := b.has(i)
		if v != (i%2 == 0) {
			t.Fatalf("bad has")
		}
	}
	// Second byte: all clear.
	for i := 8; i < 16; i++ {
		if b.has(i) {
			t.Fatalf("bad has")
		}
	}
	// Third byte: all set.
	for i := 16; i < 24; i++ {
		if !b.has(i) {
			t.Fatalf("bad has")
		}
	}
	// Fourth byte: odd positions set.
	for i := 24; i < 32; i++ {
		v := b.has(i)
		if v != (i%2 != 0) {
			t.Fatalf("bad has")
		}
	}
}
|
||||
|
||||
func TestOnes(t *testing.T) {
|
||||
ones := func(b *Bloom, x int) {
|
||||
t.Helper()
|
||||
n := b.Ones()
|
||||
if n != x {
|
||||
t.Fatalf("ones: got %d, expected %d", n, x)
|
||||
}
|
||||
}
|
||||
ones(&Bloom{data: []byte{0b10101010}}, 4)
|
||||
ones(&Bloom{data: []byte{0b01010101}}, 4)
|
||||
ones(&Bloom{data: []byte{0b11111111}}, 8)
|
||||
ones(&Bloom{data: []byte{0b00000000}}, 0)
|
||||
}
|
726
junk/filter.go
Normal file
726
junk/filter.go
Normal file
@ -0,0 +1,726 @@
|
||||
// Package junk implements a bayesian spam filter.
|
||||
//
|
||||
// A message can be parsed into words. Words (or pairs or triplets) can be used
|
||||
// to train the filter or to classify the message as ham or spam. Training
|
||||
// records the words in the database as ham/spam. Classifying consists of
|
||||
// calculating the ham/spam probability by combining the words in the message
|
||||
// with their ham/spam status.
|
||||
package junk
|
||||
|
||||
// todo: look at inverse chi-square function? see https://www.linuxjournal.com/article/6467
|
||||
// todo: perhaps: whether anchor text in links in html are different from the url
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"os"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/mjl-/bstore"
|
||||
|
||||
"github.com/mjl-/mox/message"
|
||||
"github.com/mjl-/mox/mlog"
|
||||
)
|
||||
|
||||
var (
	// xlog is the package-level logger.
	xlog = mlog.New("junk")

	errBadContentType = errors.New("bad content-type") // sure sign of spam
	errClosed         = errors.New("filter is closed")
)
|
||||
|
||||
// word holds the in-memory ham and spam counts for a single word (or word
// combination).
type word struct {
	Ham  uint32
	Spam uint32
}
|
||||
|
||||
// wordscore is the database record for a word with its ham and spam counts.
// The special word "-" holds the total ham/spam message counts.
type wordscore struct {
	Word string
	Ham  uint32
	Spam uint32
}
|
||||
|
||||
// Params holds parameters for the filter. Most are at test-time. The first are
// used during parsing and training.
type Params struct {
	Onegrams    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for single words."`
	Twograms    bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each two consecutive words."`
	Threegrams  bool    `sconf:"optional" sconf-doc:"Track ham/spam ranking for each three consecutive words."`
	MaxPower    float64 `sconf-doc:"Maximum power a word (combination) can have. If spaminess is 0.99, and max power is 0.1, spaminess of the word will be set to 0.9. Similar for ham words."`
	TopWords    int     `sconf-doc:"Number of most spammy/hammy words to use for calculating probability. E.g. 10."`
	IgnoreWords float64 `sconf:"optional" sconf-doc:"Ignore words that are this much away from 0.5 haminess/spaminess. E.g. 0.1, causing word (combinations) of 0.4 to 0.6 to be ignored."`
	RareWords   int     `sconf:"optional" sconf-doc:"Occurrences in word database until a word is considered rare and its influence in calculating probability reduced. E.g. 1 or 2."`
}
|
||||
|
||||
// Filter is a bayesian spam filter backed by a word database and a bloom
// filter of known words.
type Filter struct {
	Params

	log      *mlog.Log // For logging cid.
	closed   bool
	modified bool // Whether any modifications are pending. Cleared by Save.

	hams, spams uint32 // Message count, stored in db under word "-".

	cache   map[string]word // Words read from database or during training.
	changed map[string]word // Words modified during training.

	dbPath, bloomPath string
	db                *bstore.DB // Always open on a filter.
	bloom             *Bloom     // Only opened when writing.
	isNew             bool       // Set for new filters until their first sync to disk. For faster writing.
}
|
||||
|
||||
func (f *Filter) ensureBloom() error {
|
||||
if f.bloom != nil {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
f.bloom, err = openBloom(f.bloomPath)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close first saves the filter if it has modifications, then closes the database
// connection and releases the bloom filter.
func (f *Filter) Close() error {
	if f.closed {
		return errClosed
	}
	var err error
	if f.modified {
		err = f.Save()
	}
	if err != nil {
		// Save failed: still close the db, but report the save error.
		f.db.Close()
	} else {
		err = f.db.Close()
	}
	// Reset to a mostly-zero value so later calls fail with errClosed.
	*f = Filter{log: f.log, closed: true}
	return err
}
|
||||
|
||||
// OpenFilter opens an existing filter with database and bloom filter files. If
// loadBloom is set the bloom filter file is read into memory; otherwise only
// its size is validated. The total ham/spam message counts are read from the
// database under the special word "-".
func OpenFilter(log *mlog.Log, params Params, dbPath, bloomPath string, loadBloom bool) (*Filter, error) {
	var bloom *Bloom
	if loadBloom {
		var err error
		bloom, err = openBloom(bloomPath)
		if err != nil {
			return nil, err
		}
	} else if fi, err := os.Stat(bloomPath); err == nil {
		if err := BloomValid(int(fi.Size()), bloomK); err != nil {
			return nil, fmt.Errorf("bloom: %s", err)
		}
	}

	db, err := openDB(dbPath)
	if err != nil {
		return nil, fmt.Errorf("open database: %s", err)
	}

	f := &Filter{
		Params:    params,
		log:       log,
		cache:     map[string]word{},
		changed:   map[string]word{},
		dbPath:    dbPath,
		bloomPath: bloomPath,
		db:        db,
		bloom:     bloom,
	}
	err = f.db.Read(func(tx *bstore.Tx) error {
		// The record with word "-" holds the total message counts.
		wc := wordscore{Word: "-"}
		err := tx.Get(&wc)
		f.hams = wc.Ham
		f.spams = wc.Spam
		return err
	})
	if err != nil {
		f.Close()
		return nil, fmt.Errorf("looking up ham/spam message count: %s", err)
	}
	return f, nil
}
|
||||
|
||||
// NewFilter creates a new filter with empty bloom filter and database files. The
|
||||
// filter is marked as new until the first save, will be done automatically if
|
||||
// TrainDirs is called. If the bloom and/or database files exist, an error is
|
||||
// returned.
|
||||
func NewFilter(log *mlog.Log, params Params, dbPath, bloomPath string) (*Filter, error) {
|
||||
var err error
|
||||
if _, err := os.Stat(bloomPath); err == nil {
|
||||
return nil, fmt.Errorf("bloom filter already exists on disk: %s", bloomPath)
|
||||
} else if _, err := os.Stat(dbPath); err == nil {
|
||||
return nil, fmt.Errorf("database file already exists on disk: %s", dbPath)
|
||||
}
|
||||
|
||||
bloomSizeBytes := 4 * 1024 * 1024
|
||||
if err := BloomValid(bloomSizeBytes, bloomK); err != nil {
|
||||
return nil, fmt.Errorf("bloom: %s", err)
|
||||
}
|
||||
bf, err := os.Create(bloomPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating bloom file: %w", err)
|
||||
}
|
||||
if err := bf.Truncate(4 * 1024 * 1024); err != nil {
|
||||
bf.Close()
|
||||
os.Remove(bloomPath)
|
||||
return nil, fmt.Errorf("making empty bloom filter: %s", err)
|
||||
}
|
||||
bf.Close()
|
||||
|
||||
db, err := newDB(dbPath)
|
||||
if err != nil {
|
||||
os.Remove(bloomPath)
|
||||
os.Remove(dbPath)
|
||||
return nil, fmt.Errorf("open database: %s", err)
|
||||
}
|
||||
|
||||
words := map[string]word{} // f.changed is set to new map after training
|
||||
f := &Filter{
|
||||
Params: params,
|
||||
log: log,
|
||||
modified: true, // Ensure ham/spam message count is added for new filter.
|
||||
cache: words,
|
||||
changed: words,
|
||||
dbPath: dbPath,
|
||||
bloomPath: bloomPath,
|
||||
db: db,
|
||||
isNew: true,
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// bloomK is the number of bits ("hashes") stored/checked per word in the bloom filter.
const bloomK = 10
|
||||
|
||||
func openBloom(path string) (*Bloom, error) {
|
||||
buf, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading bloom file: %w", err)
|
||||
}
|
||||
return NewBloom(buf, bloomK)
|
||||
}
|
||||
|
||||
// newDB creates a fresh word database at path, removing any pre-existing
// file. On error, any partially created file is removed again.
func newDB(path string) (db *bstore.DB, rerr error) {
	// Remove any existing files.
	os.Remove(path)

	// Clean up on failure.
	defer func() {
		if rerr != nil {
			if db != nil {
				db.Close()
			}
			db = nil
			os.Remove(path)
		}
	}()

	db, err := bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
	if err != nil {
		return nil, fmt.Errorf("open new database: %w", err)
	}
	return db, nil
}
|
||||
|
||||
func openDB(path string) (*bstore.DB, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
return nil, fmt.Errorf("stat db file: %w", err)
|
||||
}
|
||||
return bstore.Open(path, &bstore.Options{Timeout: 5 * time.Second, Perm: 0660}, wordscore{})
|
||||
}
|
||||
|
||||
// Save stores modifications, e.g. from training, to the database and bloom
// filter files.
func (f *Filter) Save() error {
	if f.closed {
		return errClosed
	}
	if !f.modified {
		return nil
	}

	if f.bloom != nil && f.bloom.Modified() {
		if err := f.bloom.Write(f.bloomPath); err != nil {
			return fmt.Errorf("writing bloom filter: %w", err)
		}
	}

	// We need to insert sequentially for reasonable performance.
	words := make([]string, len(f.changed))
	i := 0
	for w := range f.changed {
		words[i] = w
		i++
	}
	sort.Slice(words, func(i, j int) bool {
		return words[i] < words[j]
	})

	f.log.Info("inserting words in junkfilter db", mlog.Field("words", len(f.changed)))
	// start := time.Now()
	if f.isNew {
		// New filter: every word is an insert; hint the db for append-only writes.
		if err := f.db.HintAppend(true, wordscore{}); err != nil {
			f.log.Errorx("hint appendonly", err)
		} else {
			defer f.db.HintAppend(false, wordscore{})
		}
	}
	err := f.db.Write(func(tx *bstore.Tx) error {
		// update inserts a new wordscore, or adds the given counts to an
		// existing record.
		// NOTE(review): for existing filters, the counts passed in below come
		// from f.cache/f.hams, which were seeded with values loaded from the
		// db (see OpenFilter/loadCache); adding them to the stored values
		// looks like it could double-count — verify against bstore usage.
		update := func(w string, ham, spam uint32) error {
			if f.isNew {
				return tx.Insert(&wordscore{w, ham, spam})
			}

			wc := wordscore{w, 0, 0}
			err := tx.Get(&wc)
			if err == bstore.ErrAbsent {
				return tx.Insert(&wordscore{w, ham, spam})
			} else if err != nil {
				return err
			}
			return tx.Update(&wordscore{w, wc.Ham + ham, wc.Spam + spam})
		}
		// The special word "-" tracks total ham/spam message counts.
		if err := update("-", f.hams, f.spams); err != nil {
			return fmt.Errorf("storing total ham/spam message count: %s", err)
		}

		for _, w := range words {
			c := f.changed[w]
			if err := update(w, c.Ham, c.Spam); err != nil {
				return fmt.Errorf("updating ham/spam count: %s", err)
			}
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("updating database: %w", err)
	}

	f.changed = map[string]word{}
	f.modified = false
	f.isNew = false
	// f.log.Info("wrote filter to db", mlog.Field("duration", time.Since(start)))
	return nil
}
|
||||
|
||||
// loadWords fetches the ham/spam counts for the given words from the database
// into dst. Words absent from the database are left out of dst.
func loadWords(db *bstore.DB, l []string, dst map[string]word) error {
	// Sort for sequential database access.
	sort.Slice(l, func(i, j int) bool {
		return l[i] < l[j]
	})

	err := db.Read(func(tx *bstore.Tx) error {
		for _, w := range l {
			wc := wordscore{Word: w}
			// Lookup errors (e.g. absent words) are intentionally skipped.
			if err := tx.Get(&wc); err == nil {
				dst[w] = word{wc.Ham, wc.Spam}
			}
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("fetching words: %s", err)
	}
	return nil
}
|
||||
|
||||
// ClassifyWords returns the spam probability for the given words, and number of recognized ham and spam words.
func (f *Filter) ClassifyWords(words map[string]struct{}) (probability float64, nham, nspam int, rerr error) {
	if f.closed {
		return 0, 0, 0, errClosed
	}

	// xword pairs a word with its spaminess ratio.
	type xword struct {
		Word string
		R    float64
	}

	// hamHigh/spamLow track the weakest ratio currently kept in each top list,
	// used to short-circuit once the lists are full.
	var hamHigh float64 = 0
	var spamLow float64 = 1
	var topHam []xword
	var topSpam []xword

	// Find words that should be in the database.
	lookupWords := []string{}
	expect := map[string]struct{}{}
	unknowns := map[string]struct{}{}
	totalUnknown := 0
	for w := range words {
		// Not in the bloom filter means never seen during training; skip.
		if f.bloom != nil && !f.bloom.Has(w) {
			totalUnknown++
			if len(unknowns) < 50 {
				unknowns[w] = struct{}{}
			}
			continue
		}
		if _, ok := f.cache[w]; ok {
			continue
		}
		lookupWords = append(lookupWords, w)
		expect[w] = struct{}{}
	}
	if len(unknowns) > 0 {
		f.log.Debug("unknown words in bloom filter, showing max 50", mlog.Field("words", unknowns), mlog.Field("totalunknown", totalUnknown), mlog.Field("totalwords", len(words)))
	}

	// Fetch words from database.
	fetched := map[string]word{}
	if len(lookupWords) > 0 {
		if err := loadWords(f.db, lookupWords, fetched); err != nil {
			return 0, 0, 0, err
		}
		for w, c := range fetched {
			delete(expect, w)
			f.cache[w] = c
		}
		f.log.Debug("unknown words in db", mlog.Field("words", expect), mlog.Field("totalunknown", len(expect)), mlog.Field("totalwords", len(words)))
	}

	for w := range words {
		c, ok := f.cache[w]
		if !ok {
			continue
		}
		// Per-word spam ratio, relative to the total ham/spam message counts.
		var wS, wH float64
		if f.spams > 0 {
			wS = float64(c.Spam) / float64(f.spams)
		}
		if f.hams > 0 {
			wH = float64(c.Ham) / float64(f.hams)
		}
		r := wS / (wS + wH)

		// Clamp the power a single word can have.
		if r < f.MaxPower {
			r = f.MaxPower
		} else if r >= 1-f.MaxPower {
			r = 1 - f.MaxPower
		}

		if c.Ham+c.Spam <= uint32(f.RareWords) {
			// Reduce the power of rare words.
			r += float64(1+uint32(f.RareWords)-(c.Ham+c.Spam)) * (0.5 - r) / 10
		}
		// Words too close to neutral carry no signal.
		if math.Abs(0.5-r) < f.IgnoreWords {
			continue
		}
		if r < 0.5 {
			if len(topHam) >= f.TopWords && r > hamHigh {
				continue
			}
			topHam = append(topHam, xword{w, r})
			if r > hamHigh {
				hamHigh = r
			}
		} else if r > 0.5 {
			if len(topSpam) >= f.TopWords && r < spamLow {
				continue
			}
			topSpam = append(topSpam, xword{w, r})
			if r < spamLow {
				spamLow = r
			}
		}
	}

	// Strongest words first; ties broken by preferring longer words.
	sort.Slice(topHam, func(i, j int) bool {
		a, b := topHam[i], topHam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R < b.R
	})
	sort.Slice(topSpam, func(i, j int) bool {
		a, b := topSpam[i], topSpam[j]
		if a.R == b.R {
			return len(a.Word) > len(b.Word)
		}
		return a.R > b.R
	})

	// Keep at most TopWords from each list.
	nham = f.TopWords
	if nham > len(topHam) {
		nham = len(topHam)
	}
	nspam = f.TopWords
	if nspam > len(topSpam) {
		nspam = len(topSpam)
	}
	topHam = topHam[:nham]
	topSpam = topSpam[:nspam]

	// Combine the kept ratios into one probability via the logistic function
	// 1/(1+e^eta).
	var eta float64
	for _, x := range topHam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}
	for _, x := range topSpam {
		eta += math.Log(1-x.R) - math.Log(x.R)
	}

	f.log.Debug("top words", mlog.Field("hams", topHam), mlog.Field("spams", topSpam))

	prob := 1 / (1 + math.Pow(math.E, eta))
	return prob, len(topHam), len(topSpam), nil
}
|
||||
|
||||
// ClassifyMessagePath is a convenience wrapper for calling ClassifyMessage on a file.
func (f *Filter) ClassifyMessagePath(path string) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
	if f.closed {
		return 0, nil, 0, 0, errClosed
	}

	mf, err := os.Open(path)
	if err != nil {
		return 0, nil, 0, 0, err
	}
	defer mf.Close()
	fi, err := mf.Stat()
	if err != nil {
		return 0, nil, 0, 0, err
	}
	return f.ClassifyMessageReader(mf, fi.Size())
}
|
||||
|
||||
// ClassifyMessageReader parses the message from mf and classifies it. A
// message with an invalid content-type is returned as certain spam
// (probability 1) without further analysis.
func (f *Filter) ClassifyMessageReader(mf io.ReaderAt, size int64) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
	m, err := message.EnsurePart(mf, size)
	if err != nil && errors.Is(err, message.ErrBadContentType) {
		// Invalid content-type header is a sure sign of spam.
		//f.log.Infox("parsing content", err)
		return 1, nil, 0, 0, nil
	}
	return f.ClassifyMessage(m)
}
|
||||
|
||||
// ClassifyMessage parses the mail message in r and returns the spam probability
|
||||
// (between 0 and 1), along with the tokenized words found in the message, and the
|
||||
// number of recognized ham and spam words.
|
||||
func (f *Filter) ClassifyMessage(m message.Part) (probability float64, words map[string]struct{}, nham, nspam int, rerr error) {
|
||||
var err error
|
||||
words, err = f.ParseMessage(m)
|
||||
if err != nil {
|
||||
return 0, nil, 0, 0, err
|
||||
}
|
||||
|
||||
probability, nham, nspam, err = f.ClassifyWords(words)
|
||||
return probability, words, nham, nspam, err
|
||||
}
|
||||
|
||||
// Train adds the words of a single message to the filter.
func (f *Filter) Train(ham bool, words map[string]struct{}) error {
	if err := f.ensureBloom(); err != nil {
		return err
	}

	var lwords []string

	for w := range words {
		// A word's first occurrence is only recorded in the bloom filter and
		// skipped for the cache lookup below.
		// NOTE(review): unlike TrainDir, the counting loop below does not skip
		// these first-occurrence words — confirm which behavior is intended.
		if !f.bloom.Has(w) {
			f.bloom.Add(w)
			continue
		}
		if _, ok := f.cache[w]; !ok {
			lwords = append(lwords, w)
		}
	}

	if err := f.loadCache(lwords); err != nil {
		return err
	}

	// Account the message itself.
	f.modified = true
	if ham {
		f.hams++
	} else {
		f.spams++
	}

	// Account each word.
	for w := range words {
		c := f.cache[w]
		if ham {
			c.Ham++
		} else {
			c.Spam++
		}
		f.cache[w] = c
		f.changed[w] = c
	}
	return nil
}
|
||||
|
||||
func (f *Filter) TrainMessage(r io.ReaderAt, size int64, ham bool) error {
|
||||
p, _ := message.EnsurePart(r, size)
|
||||
words, err := f.ParseMessage(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mail contents: %v", err)
|
||||
}
|
||||
return f.Train(ham, words)
|
||||
}
|
||||
|
||||
func (f *Filter) UntrainMessage(r io.ReaderAt, size int64, ham bool) error {
|
||||
p, _ := message.EnsurePart(r, size)
|
||||
words, err := f.ParseMessage(p)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing mail contents: %v", err)
|
||||
}
|
||||
return f.Untrain(ham, words)
|
||||
}
|
||||
|
||||
func (f *Filter) loadCache(lwords []string) error {
|
||||
if len(lwords) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return loadWords(f.db, lwords, f.cache)
|
||||
}
|
||||
|
||||
// Untrain adjusts the filter to undo a previous training of the words.
func (f *Filter) Untrain(ham bool, words map[string]struct{}) error {
	if err := f.ensureBloom(); err != nil {
		return err
	}

	// Lookup any words from the db that aren't in the cache and put them in the cache for modification.
	var lwords []string
	for w := range words {
		if _, ok := f.cache[w]; !ok {
			lwords = append(lwords, w)
		}
	}
	if err := f.loadCache(lwords); err != nil {
		return err
	}

	// Modify the message count.
	f.modified = true
	if ham {
		f.hams--
	} else {
		f.spams--
	}

	// Decrease the word counts.
	for w := range words {
		c, ok := f.cache[w]
		if !ok {
			continue
		}
		if ham {
			c.Ham--
		} else {
			c.Spam--
		}
		f.cache[w] = c
		f.changed[w] = c
	}
	return nil
}
|
||||
|
||||
// TrainDir parses mail messages from files and trains the filter. It returns
// the number of messages trained and the number of files that could not be
// tokenized.
func (f *Filter) TrainDir(dir string, files []string, ham bool) (n, malformed uint32, rerr error) {
	if f.closed {
		return 0, 0, errClosed
	}
	if err := f.ensureBloom(); err != nil {
		return 0, 0, err
	}

	for _, name := range files {
		p := fmt.Sprintf("%s/%s", dir, name)
		valid, words, err := f.tokenizeMail(p)
		if err != nil {
			// f.log.Infox("tokenizing mail", err, mlog.Field("path", p))
			malformed++
			continue
		}
		if !valid {
			continue
		}
		n++
		for w := range words {
			// A word's first occurrence only goes into the bloom filter; it is
			// counted in the database from its second occurrence on.
			if !f.bloom.Has(w) {
				f.bloom.Add(w)
				continue
			}
			c := f.cache[w]
			f.modified = true
			if ham {
				c.Ham++
			} else {
				c.Spam++
			}
			f.cache[w] = c
			f.changed[w] = c
		}
	}
	return
}
|
||||
|
||||
// TrainDirs trains and saves a filter with mail messages from different types
// of directories.
func (f *Filter) TrainDirs(hamDir, sentDir, spamDir string, hamFiles, sentFiles, spamFiles []string) error {
	if f.closed {
		return errClosed
	}

	var err error

	var start time.Time
	var hamMalformed, sentMalformed, spamMalformed uint32

	start = time.Now()
	// NOTE(review): the message counts are assigned, not added — any counts
	// loaded from an existing database are overwritten. This matches use on a
	// fresh filter from NewFilter; confirm for reopened filters.
	f.hams, hamMalformed, err = f.TrainDir(hamDir, hamFiles, true)
	if err != nil {
		return err
	}
	tham := time.Since(start)

	var sent uint32
	start = time.Now()
	if sentDir != "" {
		// Sent messages are trained as ham.
		sent, sentMalformed, err = f.TrainDir(sentDir, sentFiles, true)
		if err != nil {
			return err
		}
	}
	tsent := time.Since(start)

	start = time.Now()
	f.spams, spamMalformed, err = f.TrainDir(spamDir, spamFiles, false)
	if err != nil {
		return err
	}
	tspam := time.Since(start)

	hams := f.hams
	f.hams += sent
	if err := f.Save(); err != nil {
		return fmt.Errorf("saving filter: %s", err)
	}

	dbSize := f.fileSize(f.dbPath)
	bloomSize := f.fileSize(f.bloomPath)

	// Log a summary of the training run.
	fields := []mlog.Pair{
		mlog.Field("hams", hams),
		mlog.Field("hamTime", tham),
		mlog.Field("hamMalformed", hamMalformed),
		mlog.Field("sent", sent),
		mlog.Field("sentTime", tsent),
		mlog.Field("sentMalformed", sentMalformed),
		mlog.Field("spams", f.spams),
		mlog.Field("spamTime", tspam),
		mlog.Field("spamMalformed", spamMalformed),
		mlog.Field("dbsize", fmt.Sprintf("%.1fmb", float64(dbSize)/(1024*1024))),
		mlog.Field("bloomsize", fmt.Sprintf("%.1fmb", float64(bloomSize)/(1024*1024))),
		mlog.Field("bloom1ratio", fmt.Sprintf("%.4f", float64(f.bloom.Ones())/float64(len(f.bloom.Bytes())*8))),
	}
	xlog.Print("training done", fields...)

	return nil
}
|
||||
|
||||
func (f *Filter) fileSize(p string) int {
|
||||
fi, err := os.Stat(p)
|
||||
if err != nil {
|
||||
f.log.Infox("stat", err, mlog.Field("path", p))
|
||||
return 0
|
||||
}
|
||||
return int(fi.Size())
|
||||
}
|
201
junk/filter_test.go
Normal file
201
junk/filter_test.go
Normal file
@ -0,0 +1,201 @@
|
||||
package junk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/mjl-/mox/mlog"
|
||||
)
|
||||
|
||||
// tcheck fails the test with msg when err is non-nil.
func tcheck(t *testing.T, err error, msg string) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("%s: %s", msg, err)
}
|
||||
|
||||
func tlistdir(t *testing.T, name string) []string {
|
||||
t.Helper()
|
||||
l, err := os.ReadDir(name)
|
||||
tcheck(t, err, "readdir")
|
||||
names := make([]string, len(l))
|
||||
for i, e := range l {
|
||||
names[i] = e.Name()
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
log := mlog.New("junk")
|
||||
params := Params{
|
||||
Onegrams: true,
|
||||
Twograms: true,
|
||||
Threegrams: false,
|
||||
MaxPower: 0.1,
|
||||
TopWords: 10,
|
||||
IgnoreWords: 0.1,
|
||||
RareWords: 1,
|
||||
}
|
||||
dbPath := "../testdata/junk/filter.db"
|
||||
bloomPath := "../testdata/junk/filter.bloom"
|
||||
os.Remove(dbPath)
|
||||
os.Remove(bloomPath)
|
||||
f, err := NewFilter(log, params, dbPath, bloomPath)
|
||||
tcheck(t, err, "new filter")
|
||||
err = f.Close()
|
||||
tcheck(t, err, "close filter")
|
||||
|
||||
f, err = OpenFilter(log, params, dbPath, bloomPath, true)
|
||||
tcheck(t, err, "open filter")
|
||||
|
||||
// Ensure these dirs exist. Developers should bring their own ham/spam example
|
||||
// emails.
|
||||
os.MkdirAll("../testdata/train/ham", 0770)
|
||||
os.MkdirAll("../testdata/train/spam", 0770)
|
||||
|
||||
hamdir := "../testdata/train/ham"
|
||||
spamdir := "../testdata/train/spam"
|
||||
hamfiles := tlistdir(t, hamdir)
|
||||
if len(hamfiles) > 100 {
|
||||
hamfiles = hamfiles[:100]
|
||||
}
|
||||
spamfiles := tlistdir(t, spamdir)
|
||||
if len(spamfiles) > 100 {
|
||||
spamfiles = spamfiles[:100]
|
||||
}
|
||||
|
||||
err = f.TrainDirs(hamdir, "", spamdir, hamfiles, nil, spamfiles)
|
||||
tcheck(t, err, "train dirs")
|
||||
|
||||
if len(hamfiles) == 0 || len(spamfiles) == 0 {
|
||||
fmt.Println("not training, no ham and/or spam messages, add them to testdata/train/ham and testdata/train/spam")
|
||||
return
|
||||
}
|
||||
|
||||
prob, _, _, _, err := f.ClassifyMessagePath(filepath.Join(hamdir, hamfiles[0]))
|
||||
tcheck(t, err, "classify ham message")
|
||||
if prob > 0.1 {
|
||||
t.Fatalf("trained ham file has prob %v, expected <= 0.1", prob)
|
||||
}
|
||||
|
||||
prob, _, _, _, err = f.ClassifyMessagePath(filepath.Join(spamdir, spamfiles[0]))
|
||||
tcheck(t, err, "classify spam message")
|
||||
if prob < 0.9 {
|
||||
t.Fatalf("trained spam file has prob %v, expected > 0.9", prob)
|
||||
}
|
||||
|
||||
err = f.Close()
|
||||
tcheck(t, err, "close filter")
|
||||
|
||||
// Start again with empty filter. We'll train a few messages and check they are
|
||||
// classified as ham/spam. Then we untrain to see they are no longer classified.
|
||||
os.Remove(dbPath)
|
||||
os.Remove(bloomPath)
|
||||
f, err = NewFilter(log, params, dbPath, bloomPath)
|
||||
tcheck(t, err, "open filter")
|
||||
|
||||
hamf, err := os.Open(filepath.Join(hamdir, hamfiles[0]))
|
||||
tcheck(t, err, "open hamfile")
|
||||
defer hamf.Close()
|
||||
hamstat, err := hamf.Stat()
|
||||
tcheck(t, err, "stat hamfile")
|
||||
hamsize := hamstat.Size()
|
||||
|
||||
spamf, err := os.Open(filepath.Join(spamdir, spamfiles[0]))
|
||||
tcheck(t, err, "open spamfile")
|
||||
defer spamf.Close()
|
||||
spamstat, err := spamf.Stat()
|
||||
tcheck(t, err, "stat spamfile")
|
||||
spamsize := spamstat.Size()
|
||||
|
||||
// Train each message twice, to prevent single occurrences from being ignored.
|
||||
err = f.TrainMessage(hamf, hamsize, true)
|
||||
tcheck(t, err, "train ham message")
|
||||
_, err = hamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek ham message")
|
||||
err = f.TrainMessage(hamf, hamsize, true)
|
||||
tcheck(t, err, "train ham message")
|
||||
|
||||
err = f.TrainMessage(spamf, spamsize, false)
|
||||
tcheck(t, err, "train spam message")
|
||||
_, err = spamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek spam message")
|
||||
err = f.TrainMessage(spamf, spamsize, true)
|
||||
tcheck(t, err, "train spam message")
|
||||
|
||||
if !f.modified {
|
||||
t.Fatalf("filter not modified after training")
|
||||
}
|
||||
if !f.bloom.Modified() {
|
||||
t.Fatalf("bloom filter not modified after training")
|
||||
}
|
||||
|
||||
err = f.Save()
|
||||
tcheck(t, err, "save filter")
|
||||
if f.modified || f.bloom.Modified() {
|
||||
t.Fatalf("filter or bloom filter still modified after save")
|
||||
}
|
||||
|
||||
// Classify and verify.
|
||||
_, err = hamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek ham message")
|
||||
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
||||
tcheck(t, err, "classify ham")
|
||||
if prob > 0.1 {
|
||||
t.Fatalf("got prob %v, expected <= 0.1", prob)
|
||||
}
|
||||
|
||||
_, err = spamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek spam message")
|
||||
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
||||
tcheck(t, err, "classify spam")
|
||||
if prob < 0.9 {
|
||||
t.Fatalf("got prob %v, expected >= 0.9", prob)
|
||||
}
|
||||
|
||||
// Untrain ham & spam.
|
||||
_, err = hamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek ham message")
|
||||
err = f.UntrainMessage(hamf, hamsize, true)
|
||||
tcheck(t, err, "untrain ham message")
|
||||
_, err = hamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek ham message")
|
||||
err = f.UntrainMessage(hamf, spamsize, true)
|
||||
tcheck(t, err, "untrain ham message")
|
||||
|
||||
_, err = spamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek spam message")
|
||||
err = f.UntrainMessage(spamf, spamsize, true)
|
||||
tcheck(t, err, "untrain spam message")
|
||||
_, err = spamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek spam message")
|
||||
err = f.UntrainMessage(spamf, spamsize, true)
|
||||
tcheck(t, err, "untrain spam message")
|
||||
|
||||
if !f.modified {
|
||||
t.Fatalf("filter not modified after untraining")
|
||||
}
|
||||
|
||||
// Classify again, should be unknown.
|
||||
_, err = hamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek ham message")
|
||||
prob, _, _, _, err = f.ClassifyMessageReader(hamf, hamsize)
|
||||
tcheck(t, err, "classify ham")
|
||||
if math.Abs(prob-0.5) > 0.1 {
|
||||
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||||
}
|
||||
|
||||
_, err = spamf.Seek(0, 0)
|
||||
tcheck(t, err, "seek spam message")
|
||||
prob, _, _, _, err = f.ClassifyMessageReader(spamf, spamsize)
|
||||
tcheck(t, err, "classify spam")
|
||||
if math.Abs(prob-0.5) > 0.1 {
|
||||
t.Fatalf("got prob %v, expected 0.5 +-0.1", prob)
|
||||
}
|
||||
|
||||
err = f.Close()
|
||||
tcheck(t, err, "close filter")
|
||||
}
|
323
junk/parse.go
Normal file
323
junk/parse.go
Normal file
@ -0,0 +1,323 @@
|
||||
package junk
|
||||
|
||||
// see https://en.wikipedia.org/wiki/Naive_Bayes_spam_filtering
|
||||
// - todo: better html parsing?
|
||||
// - todo: try reading text in pdf?
|
||||
// - todo: try to detect language, have words per language? can be in the same dictionary. currently my dictionary is biased towards treating english as spam.
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
|
||||
"github.com/mjl-/mox/message"
|
||||
)
|
||||
|
||||
func (f *Filter) tokenizeMail(path string) (bool, map[string]struct{}, error) {
|
||||
mf, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
defer mf.Close()
|
||||
fi, err := mf.Stat()
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
p, _ := message.EnsurePart(mf, fi.Size())
|
||||
words, err := f.ParseMessage(p)
|
||||
return true, words, err
|
||||
}
|
||||
|
||||
// ParseMessage reads a mail and returns a map with words.
|
||||
func (f *Filter) ParseMessage(p message.Part) (map[string]struct{}, error) {
|
||||
metaWords := map[string]struct{}{}
|
||||
textWords := map[string]struct{}{}
|
||||
htmlWords := map[string]struct{}{}
|
||||
|
||||
hdrs, err := p.Header()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing headers: %v", err)
|
||||
}
|
||||
|
||||
// Add words from the header, annotated with <field>+":".
|
||||
// todo: add whether header is dkim-verified?
|
||||
for k, l := range hdrs {
|
||||
for _, h := range l {
|
||||
switch k {
|
||||
case "From", "To", "Cc", "Bcc", "Reply-To", "Subject", "Sender", "Return-Path":
|
||||
// case "Subject", "To":
|
||||
default:
|
||||
continue
|
||||
}
|
||||
words := map[string]struct{}{}
|
||||
f.tokenizeText(strings.NewReader(h), words)
|
||||
for w := range words {
|
||||
if len(w) <= 3 {
|
||||
continue
|
||||
}
|
||||
metaWords[k+":"+w] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := f.mailParse(p, metaWords, textWords, htmlWords); err != nil {
|
||||
return nil, fmt.Errorf("parsing message: %w", err)
|
||||
}
|
||||
|
||||
for w := range metaWords {
|
||||
textWords[w] = struct{}{}
|
||||
}
|
||||
for w := range htmlWords {
|
||||
textWords[w] = struct{}{}
|
||||
}
|
||||
|
||||
return textWords, nil
|
||||
}
|
||||
|
||||
// mailParse looks through the mail for the first text and html parts, and tokenizes their words.
|
||||
func (f *Filter) mailParse(p message.Part, metaWords, textWords, htmlWords map[string]struct{}) error {
|
||||
ct := p.MediaType + "/" + p.MediaSubType
|
||||
|
||||
if ct == "TEXT/HTML" {
|
||||
err := f.tokenizeHTML(p.Reader(), metaWords, htmlWords)
|
||||
// log.Printf("html parsed, words %v", htmlWords)
|
||||
return err
|
||||
}
|
||||
if ct == "" || strings.HasPrefix(ct, "TEXT/") {
|
||||
err := f.tokenizeText(p.Reader(), textWords)
|
||||
// log.Printf("text parsed, words %v", textWords)
|
||||
return err
|
||||
}
|
||||
if p.Message != nil {
|
||||
// Nested message, happens for forwarding.
|
||||
if err := p.SetMessageReaderAt(); err != nil {
|
||||
return fmt.Errorf("setting reader on nested message: %w", err)
|
||||
}
|
||||
return f.mailParse(*p.Message, metaWords, textWords, htmlWords)
|
||||
}
|
||||
for _, sp := range p.Parts {
|
||||
if err := f.mailParse(sp, metaWords, textWords, htmlWords); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// looksRandom reports whether s looks like a random identifier such as
// "2fvu9stm9yxhnlu": ASCII only, with a run of at least six consecutive
// consonant-like characters. Any non-ASCII rune makes the result false.
func looksRandom(s string) bool {
	const consonants = "bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ23456789" // 0 and 1 may be used as o and l/i
	run := 0
	maxRun := 0
	for _, c := range s {
		if c >= 0x80 {
			// Non-ASCII: never considered random, regardless of earlier runs.
			return false
		}
		if strings.ContainsRune(consonants, c) {
			run++
			if run > maxRun {
				maxRun = run
			}
		} else {
			run = 0
		}
	}
	return maxRun >= 6
}
|
||||
|
||||
// looksNumeric reports whether s looks like a numeric token (possibly
// hexadecimal, an optional "0x" prefix is stripped): either it contains a
// run of at least four consecutive digits anywhere, or it consists solely
// of digits/hex characters with a digit run of at least three.
func looksNumeric(s string) bool {
	s = strings.TrimPrefix(s, "0x") // Hexadecimal.
	var digits, hex, other, digitstretch, maxdigitstretch int
	for _, c := range s {
		if c >= '0' && c <= '9' {
			digits++
			digitstretch++
			continue
		} else if c >= 'a' && c <= 'f' || c >= 'A' && c <= 'F' {
			hex++
		} else {
			other++
		}
		if digitstretch > maxdigitstretch {
			maxdigitstretch = digitstretch
		}
		// Fix: reset the current run on a non-digit; without this,
		// digitstretch was a plain digit count and scattered digits
		// (e.g. "12ab34") were counted as one long stretch.
		digitstretch = 0
	}
	if digitstretch > maxdigitstretch {
		maxdigitstretch = digitstretch
	}
	return maxdigitstretch >= 4 || other == 0 && maxdigitstretch >= 3
}
|
||||
|
||||
// tokenizeText splits the text from r into lowercase words and records them
// in words. Depending on the filter's Params, single words (Onegrams) and
// space-joined pairs (Twograms) and triples (Threegrams) of consecutive
// words are added. Words of 2 characters or less (after trimming
// apostrophes), all-digit words, and words that look random or numeric are
// skipped.
func (f *Filter) tokenizeText(r io.Reader, words map[string]struct{}) error {
	b := &strings.Builder{}
	// prev and prev2 hold the previous one and two accepted words, for
	// building the 2- and 3-grams.
	var prev string
	var prev2 string

	// add finishes the word currently in b: filter it, then record the
	// configured n-grams. The builder is always reset, even on early return.
	add := func() {
		defer b.Reset()
		if b.Len() <= 2 {
			return
		}

		s := b.String()
		// Leading/trailing apostrophes (which are written into the buffer
		// below) are stripped here.
		s = strings.Trim(s, "'")
		// Reject words consisting solely of digits.
		var nondigit bool
		for _, c := range s {
			if !unicode.IsDigit(c) {
				nondigit = true
				break
			}
		}

		// Re-check length: trimming apostrophes may have shortened s.
		if !(nondigit && len(s) > 2) {
			return
		}

		// Random-looking and numeric-looking tokens carry little signal.
		if looksRandom(s) {
			return
		}
		if looksNumeric(s) {
			return
		}

		// todo: do something for URLs, parse them? keep their domain only?

		if f.Threegrams && prev2 != "" && prev != "" {
			words[prev2+" "+prev+" "+s] = struct{}{}
		}
		if f.Twograms && prev != "" {
			words[prev+" "+s] = struct{}{}
		}
		if f.Onegrams {
			words[s] = struct{}{}
		}
		prev2 = prev
		prev = s
	}

	br := bufio.NewReader(r)

	// peekLetter reports whether the next rune is a letter, without
	// consuming it.
	peekLetter := func() bool {
		c, _, err := br.ReadRune()
		br.UnreadRune()
		return err == nil && unicode.IsLetter(c)
	}

	for {
		c, _, err := br.ReadRune()
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		// Letters and digits extend the current word; other runes end it.
		// An apostrophe is buffered (and trimmed later) when it starts a
		// word or is not followed by a letter.
		// NOTE(review): an apostrophe with buffered text before it and a
		// letter after it (b.Len() > 0 && peekLetter()) ends the word, so
		// "don't" tokenizes as "don" + "t" — confirm this is intended.
		if !unicode.IsLetter(c) && !unicode.IsDigit(c) && (c != '\'' || b.Len() > 0 && peekLetter()) {
			add()
		} else {
			b.WriteRune(unicode.ToLower(c))
		}
	}
	// Flush the final word at EOF.
	add()
	return nil
}
|
||||
|
||||
// tokenizeHTML parses html, and tokenizes its text into words.
|
||||
func (f *Filter) tokenizeHTML(r io.Reader, meta, words map[string]struct{}) error {
|
||||
htmlReader := &htmlTextReader{
|
||||
t: html.NewTokenizer(r),
|
||||
meta: map[string]struct{}{},
|
||||
}
|
||||
return f.tokenizeText(htmlReader, words)
|
||||
}
|
||||
|
||||
// htmlTextReader is an io.Reader that yields only the text content of an
// html document as it is tokenized, skipping markup and the contents of
// script/style/svg elements.
type htmlTextReader struct {
	// t produces the html tokens we extract text from.
	t *html.Tokenizer
	// meta collects words from the reader; currently nothing reads it back
	// (see tokenizeHTML).
	meta map[string]struct{}
	// tagStack tracks currently open elements, used to suppress text inside
	// script/style/svg.
	tagStack []string
	// buf holds leftover text from a token that did not fit the caller's
	// buffer on the previous Read.
	buf []byte
	// err is the sticky tokenizer error, returned once buffered text is
	// drained.
	err error
}
|
||||
|
||||
// Read returns the next chunk of text content from the html stream,
// implementing io.Reader. Text inside script/style/svg is skipped, the alt
// attribute of img tags is included, and leftover text that does not fit
// buf is carried over to the next call in r.buf.
func (r *htmlTextReader) Read(buf []byte) (n int, err error) {
	// todo: deal with invalid html better. the tokenizer is just tokenizing, we need to fix up the nesting etc. eg, rules say some elements close certain open elements.
	// todo: deal with inline elements? they shouldn't cause a word break.

	// give copies up to len(buf) bytes from nbuf to the caller and stashes
	// the remainder in r.buf for the next Read.
	give := func(nbuf []byte) (int, error) {
		n := len(buf)
		if n > len(nbuf) {
			n = len(nbuf)
		}
		copy(buf, nbuf[:n])
		nbuf = nbuf[n:]
		// Reuse r.buf's backing array when the leftover fits, otherwise
		// allocate with 50% headroom.
		if len(nbuf) < cap(r.buf) {
			r.buf = r.buf[:len(nbuf)]
		} else {
			r.buf = make([]byte, len(nbuf), 3*len(nbuf)/2)
		}
		copy(r.buf, nbuf)
		return n, nil
	}

	// Drain carried-over text before touching the tokenizer again.
	if len(r.buf) > 0 {
		return give(r.buf)
	}
	// Sticky error from an earlier tokenization round.
	if r.err != nil {
		return 0, r.err
	}

	for {
		switch r.t.Next() {
		case html.ErrorToken:
			// Includes io.EOF at end of input; remember it for later calls.
			r.err = r.t.Err()
			return 0, r.err
		case html.TextToken:
			// Suppress text inside elements whose content is not prose.
			if len(r.tagStack) > 0 {
				switch r.tagStack[len(r.tagStack)-1] {
				case "script", "style", "svg":
					continue
				}
			}
			buf := r.t.Text()
			if len(buf) > 0 {
				return give(buf)
			}
		case html.StartTagToken:
			tagBuf, moreAttr := r.t.TagName()
			tag := string(tagBuf)
			//log.Printf("tag %q %v", tag, r.tagStack)

			// Image alt text is the only attribute text we surface.
			if tag == "img" && moreAttr {
				var key, val []byte
				for moreAttr {
					key, val, moreAttr = r.t.TagAttr()
					if string(key) == "alt" && len(val) > 0 {
						return give(val)
					}
				}
			}

			// Empty elements, https://developer.mozilla.org/en-US/docs/Glossary/Empty_element
			// These never get a matching end tag, so keep them off the stack.
			switch tag {
			case "area", "base", "br", "col", "embed", "hr", "img", "input", "link", "meta", "param", "source", "track", "wbr":
				continue
			}

			r.tagStack = append(r.tagStack, tag)
		case html.EndTagToken:
			// log.Printf("tag pop %v", r.tagStack)
			// Guard against stray end tags with an empty stack.
			if len(r.tagStack) > 0 {
				r.tagStack = r.tagStack[:len(r.tagStack)-1]
			}
		case html.SelfClosingTagToken:
		case html.CommentToken:
		case html.DoctypeToken:
		}
	}
}
|
33
junk/parse_test.go
Normal file
33
junk/parse_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package junk
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func FuzzParseMessage(f *testing.F) {
|
||||
f.Add("")
|
||||
add := func(p string) {
|
||||
buf, err := os.ReadFile(p)
|
||||
if err != nil {
|
||||
f.Fatalf("reading file %q: %v", p, err)
|
||||
}
|
||||
f.Add(string(buf))
|
||||
}
|
||||
add("../testdata/junk/parse.eml")
|
||||
add("../testdata/junk/parse2.eml")
|
||||
add("../testdata/junk/parse3.eml")
|
||||
|
||||
dbPath := "../testdata/junk/parse.db"
|
||||
bloomPath := "../testdata/junk/parse.bloom"
|
||||
os.Remove(dbPath)
|
||||
os.Remove(bloomPath)
|
||||
params := Params{Twograms: true}
|
||||
jf, err := NewFilter(xlog, params, dbPath, bloomPath)
|
||||
if err != nil {
|
||||
f.Fatalf("new filter: %v", err)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
jf.tokenizeMail(s)
|
||||
})
|
||||
}
|
Reference in New Issue
Block a user