mirror of
https://github.com/mjl-/mox.git
synced 2025-07-10 07:14:40 +03:00
mox!
This commit is contained in:
266
scram/parse.go
Normal file
266
scram/parse.go
Normal file
@ -0,0 +1,266 @@
|
||||
package scram
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type parser struct {
|
||||
s string // Original casing.
|
||||
lower string // Lower casing, for case-insensitive token consumption.
|
||||
o int // Offset in s/lower.
|
||||
}
|
||||
|
||||
type parseError struct{ err error }
|
||||
|
||||
func (e parseError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e parseError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// toLower lower cases bytes that are A-Z. strings.ToLower does too much. and
|
||||
// would replace invalid bytes with unicode replacement characters, which would
|
||||
// break our requirement that offsets into the original and upper case strings
|
||||
// point to the same character.
|
||||
func toLower(s string) string {
|
||||
r := []byte(s)
|
||||
for i, c := range r {
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
r[i] = c + 0x20
|
||||
}
|
||||
}
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func newParser(buf []byte) *parser {
|
||||
s := string(buf)
|
||||
return &parser{s, toLower(s), 0}
|
||||
}
|
||||
|
||||
// Turn panics of parseError into a descriptive ErrInvalidEncoding. Called with
|
||||
// defer by functions that parse.
|
||||
func (p *parser) recover(rerr *error) {
|
||||
x := recover()
|
||||
if x == nil {
|
||||
return
|
||||
}
|
||||
err, ok := x.(error)
|
||||
if !ok {
|
||||
panic(x)
|
||||
}
|
||||
var xerr Error
|
||||
if errors.As(err, &xerr) {
|
||||
*rerr = err
|
||||
return
|
||||
}
|
||||
*rerr = fmt.Errorf("%w: %s", ErrInvalidEncoding, err)
|
||||
}
|
||||
|
||||
func (p *parser) xerrorf(format string, args ...any) {
|
||||
panic(parseError{fmt.Errorf(format, args...)})
|
||||
}
|
||||
|
||||
func (p *parser) xcheckf(err error, format string, args ...any) {
|
||||
if err != nil {
|
||||
panic(parseError{fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), err)})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xempty() {
|
||||
if p.o != len(p.s) {
|
||||
p.xerrorf("leftover data")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xnonempty() {
|
||||
if p.o >= len(p.s) {
|
||||
p.xerrorf("unexpected end")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xbyte() byte {
|
||||
p.xnonempty()
|
||||
c := p.lower[p.o]
|
||||
p.o++
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *parser) peek(s string) bool {
|
||||
return strings.HasPrefix(p.lower[p.o:], s)
|
||||
}
|
||||
|
||||
func (p *parser) take(s string) bool {
|
||||
if p.peek(s) {
|
||||
p.o += len(s)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *parser) xtake(s string) {
|
||||
if !p.take(s) {
|
||||
p.xerrorf("expected %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xauthzid() string {
|
||||
p.xtake("a=")
|
||||
return p.xsaslname()
|
||||
}
|
||||
|
||||
func (p *parser) xusername() string {
|
||||
p.xtake("n=")
|
||||
return p.xsaslname()
|
||||
}
|
||||
|
||||
func (p *parser) xnonce() string {
|
||||
p.xtake("r=")
|
||||
o := p.o
|
||||
for ; o < len(p.s); o++ {
|
||||
c := p.s[o]
|
||||
if c <= ' ' || c >= 0x7f || c == ',' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if o == p.o {
|
||||
p.xerrorf("empty nonce")
|
||||
}
|
||||
r := p.s[p.o:o]
|
||||
p.o = o
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xattrval() {
|
||||
c := p.xbyte()
|
||||
if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z') {
|
||||
p.xerrorf("expected alpha for attr-val")
|
||||
}
|
||||
p.xtake("=")
|
||||
p.xvalue()
|
||||
}
|
||||
|
||||
func (p *parser) xvalue() string {
|
||||
for o, c := range p.s[p.o:] {
|
||||
if c == 0 || c == ',' {
|
||||
if o == 0 {
|
||||
p.xerrorf("invalid empty value")
|
||||
}
|
||||
r := p.s[p.o : p.o+o]
|
||||
p.o = o
|
||||
return r
|
||||
}
|
||||
}
|
||||
p.xnonempty()
|
||||
r := p.s[p.o:]
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xbase64() []byte {
|
||||
o := p.o
|
||||
for ; o < len(p.s); o++ {
|
||||
c := p.s[o]
|
||||
if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '/' || c == '+' || c == '=') {
|
||||
break
|
||||
}
|
||||
}
|
||||
buf, err := base64.StdEncoding.DecodeString(p.s[p.o:o])
|
||||
p.xcheckf(err, "decoding base64")
|
||||
p.o = o
|
||||
return buf
|
||||
}
|
||||
|
||||
func (p *parser) xsaslname() string {
|
||||
var esc string
|
||||
var is bool
|
||||
var r string
|
||||
for o, c := range p.s[p.o:] {
|
||||
if c == 0 || c == ',' {
|
||||
if is {
|
||||
p.xerrorf("saslname unexpected end")
|
||||
}
|
||||
if o == 0 {
|
||||
p.xerrorf("saslname cannot be empty")
|
||||
}
|
||||
p.o += o
|
||||
return r
|
||||
}
|
||||
if is {
|
||||
esc += string(c)
|
||||
if len(esc) < 2 {
|
||||
continue
|
||||
}
|
||||
switch esc {
|
||||
case "2c", "2C":
|
||||
r += ","
|
||||
case "3d", "3D":
|
||||
r += "="
|
||||
default:
|
||||
p.xerrorf("bad escape %q in saslanem", esc)
|
||||
}
|
||||
is = false
|
||||
esc = ""
|
||||
continue
|
||||
} else if c == '=' {
|
||||
is = true
|
||||
continue
|
||||
}
|
||||
r += string(c)
|
||||
}
|
||||
if is {
|
||||
p.xerrorf("saslname unexpected end")
|
||||
}
|
||||
if r == "" {
|
||||
p.xerrorf("saslname cannot be empty")
|
||||
}
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xchannelBinding() string {
|
||||
p.xtake("c=")
|
||||
return string(p.xbase64())
|
||||
}
|
||||
|
||||
func (p *parser) xproof() []byte {
|
||||
p.xtake("p=")
|
||||
return p.xbase64()
|
||||
}
|
||||
|
||||
func (p *parser) xsalt() []byte {
|
||||
p.xtake("s=")
|
||||
return p.xbase64()
|
||||
}
|
||||
|
||||
func (p *parser) xtakefn1(fn func(rune, int) bool) string {
|
||||
for o, c := range p.s[p.o:] {
|
||||
if !fn(c, o) {
|
||||
if o == 0 {
|
||||
p.xerrorf("non-empty match required")
|
||||
}
|
||||
r := p.s[p.o : p.o+o]
|
||||
p.o += o
|
||||
return r
|
||||
}
|
||||
}
|
||||
p.xnonempty()
|
||||
r := p.s[p.o:]
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xiterations() int {
|
||||
p.xtake("i=")
|
||||
digits := p.xtakefn1(func(c rune, i int) bool {
|
||||
return c >= '1' && c <= '9' || i > 0 && c == '0'
|
||||
})
|
||||
v, err := strconv.ParseInt(digits, 10, 32)
|
||||
p.xcheckf(err, "parsing int")
|
||||
return int(v)
|
||||
}
|
Reference in New Issue
Block a user