add basic rate limiters

limiting is done based on remote ip's, with 3 ip mask variants to limit networks
of machines. often with two windows, enabling short bursts of activity, but not
sustained high activity. currently only for imap and smtp, not yet http.

limits are currently based on:
- number of open connections
- connection rate
- limits after authentication failures. too many failures, and new connections will be dropped.
- rate of delivery in total number of messages
- rate of delivery in total size of messages

the limits on connections and authentication failures are in-memory. the limits
on delivery of messages are based on stored messages.

the limits themselves are not yet configurable, let's use this first.

in the future, we may also want to have stricter limits for senders without any
reputation.
This commit is contained in:
Mechiel Lukkien
2023-02-07 22:56:03 +01:00
parent 1617b7c0d6
commit 2154392bd8
7 changed files with 584 additions and 6 deletions

146
ratelimit/ratelimit.go Normal file
View File

@ -0,0 +1,146 @@
// Package ratelimit provides a simple window-based rate limiter.
package ratelimit
import (
"net"
"sync"
"time"
)
// Limiter is a simple rate limiter with one or more fixed windows, e.g. the
// last minute/hour/day/week, working on three classes/subnets of an IP.
type Limiter struct {
sync.Mutex
WindowLimits []WindowLimit
ipmasked [3][16]byte
}
// WindowLimit holds counters for one window, with limits for each IP class/subnet.
type WindowLimit struct {
Window time.Duration
Limits [3]int64 // For "ipmasked1" through "ipmasked3".
Time uint32 // Time/Window.
Counts map[struct {
Index uint8
IPMasked [16]byte
}]int64
}
// Add attempts to consume "n" items from the rate limiter. If the total for this
// key and this interval would exceed limit, "n" is not counted and false is
// returned. If now represents a different time interval, all counts are reset.
func (l *Limiter) Add(ip net.IP, tm time.Time, n int64) bool {
return l.checkAdd(true, ip, tm, n)
}
// CanAdd returns if n could be added to the limiter.
func (l *Limiter) CanAdd(ip net.IP, tm time.Time, n int64) bool {
return l.checkAdd(false, ip, tm, n)
}
func (l *Limiter) checkAdd(add bool, ip net.IP, tm time.Time, n int64) bool {
l.Lock()
defer l.Unlock()
// First check.
for i, pl := range l.WindowLimits {
t := uint32(tm.UnixNano() / int64(pl.Window))
if t > pl.Time || pl.Counts == nil {
l.WindowLimits[i].Time = t
pl.Counts = map[struct {
Index uint8
IPMasked [16]byte
}]int64{} // Used below.
l.WindowLimits[i].Counts = pl.Counts
}
for j := 0; j < 3; j++ {
if i == 0 {
l.ipmasked[j] = l.maskIP(j, ip)
}
v := pl.Counts[struct {
Index uint8
IPMasked [16]byte
}{uint8(j), l.ipmasked[j]}]
if v+n > pl.Limits[j] {
return false
}
}
}
if !add {
return true
}
// Finally record.
for _, pl := range l.WindowLimits {
for j := 0; j < 3; j++ {
pl.Counts[struct {
Index uint8
IPMasked [16]byte
}{uint8(j), l.ipmasked[j]}] += n
}
}
return true
}
// Reset sets the counter to 0 for key and ip, and substracts from the ipmasked counts.
func (l *Limiter) Reset(ip net.IP, tm time.Time) {
l.Lock()
defer l.Unlock()
// Prepare masked ip's.
for i := 0; i < 3; i++ {
l.ipmasked[i] = l.maskIP(i, ip)
}
for _, pl := range l.WindowLimits {
t := uint32(tm.UnixNano() / int64(pl.Window))
if t != pl.Time || pl.Counts == nil {
continue
}
var n int64
for j := 0; j < 3; j++ {
k := struct {
Index uint8
IPMasked [16]byte
}{uint8(j), l.ipmasked[j]}
if j == 0 {
n = pl.Counts[k]
}
if pl.Counts != nil {
pl.Counts[k] -= n
}
}
}
}
func (l *Limiter) maskIP(i int, ip net.IP) [16]byte {
isv4 := ip.To4() != nil
var ipmasked net.IP
if isv4 {
switch i {
case 0:
ipmasked = ip
case 1:
ipmasked = ip.Mask(net.CIDRMask(26, 32))
case 2:
ipmasked = ip.Mask(net.CIDRMask(21, 32))
default:
panic("missing case for maskip ipv4")
}
} else {
switch i {
case 0:
ipmasked = ip.Mask(net.CIDRMask(64, 128))
case 1:
ipmasked = ip.Mask(net.CIDRMask(48, 128))
case 2:
ipmasked = ip.Mask(net.CIDRMask(32, 128))
default:
panic("missing case for masking ipv6")
}
}
return *(*[16]byte)(ipmasked.To16())
}

View File

@ -0,0 +1,72 @@
package ratelimit
import (
"net"
"testing"
"time"
)
func TestLimiter(t *testing.T) {
l := &Limiter{
WindowLimits: []WindowLimit{
{
Window: time.Minute,
Limits: [...]int64{2, 4, 6},
},
},
}
now := time.Now()
check := func(exp bool, ip net.IP, tm time.Time, n int64) {
t.Helper()
ok := l.CanAdd(ip, tm, n)
if ok != exp {
t.Fatalf("canadd, got %v, expected %v", ok, exp)
}
ok = l.Add(ip, tm, n)
if ok != exp {
t.Fatalf("add, got %v, expected %v", ok, exp)
}
}
check(false, net.ParseIP("10.0.0.1"), now, 3) // past limit
check(true, net.ParseIP("10.0.0.1"), now, 1)
check(false, net.ParseIP("10.0.0.1"), now, 2) // now past limit
check(true, net.ParseIP("10.0.0.1"), now, 1)
check(false, net.ParseIP("10.0.0.1"), now, 1) // now past limit
next := now.Add(time.Minute)
check(true, net.ParseIP("10.0.0.1"), next, 2) // next minute, should have reset
check(true, net.ParseIP("10.0.0.2"), next, 2) // other ip
check(false, net.ParseIP("10.0.0.3"), next, 2) // yet another ip, ipmasked2 was consumed
check(true, net.ParseIP("10.0.1.4"), next, 2) // using ipmasked3
check(false, net.ParseIP("10.0.2.4"), next, 2) // ipmasked3 consumed
l.Reset(net.ParseIP("10.0.1.4"), next)
if !l.CanAdd(net.ParseIP("10.0.1.4"), next, 2) {
t.Fatalf("reset did not free up count for ip")
}
check(true, net.ParseIP("10.0.2.4"), next, 2) // ipmasked3 available again
l = &Limiter{
WindowLimits: []WindowLimit{
{
Window: time.Minute,
Limits: [...]int64{1, 2, 3},
},
{
Window: time.Hour,
Limits: [...]int64{2, 3, 4},
},
},
}
min1 := time.UnixMilli((time.Now().UnixNano() / int64(time.Hour)) * int64(time.Hour) / int64(time.Millisecond))
min2 := min1.Add(time.Minute)
min3 := min1.Add(2 * time.Minute)
check(true, net.ParseIP("10.0.0.1"), min1, 1)
check(true, net.ParseIP("10.0.0.1"), min2, 1)
check(false, net.ParseIP("10.0.0.1"), min3, 1)
check(true, net.ParseIP("10.0.0.255"), min3, 1) // ipmasked2 still ok
check(false, net.ParseIP("10.0.0.255"), min3, 1) // ipmasked2 also full
check(true, net.ParseIP("10.0.1.1"), min3, 1) // ipmasked3 still ok
check(false, net.ParseIP("10.0.1.255"), min3, 1) // ipmasked3 also full
}