mirror of
https://github.com/mjl-/mox.git
synced 2025-07-12 17:44:35 +03:00
mox!
This commit is contained in:
266
scram/parse.go
Normal file
266
scram/parse.go
Normal file
@ -0,0 +1,266 @@
|
||||
package scram
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type parser struct {
|
||||
s string // Original casing.
|
||||
lower string // Lower casing, for case-insensitive token consumption.
|
||||
o int // Offset in s/lower.
|
||||
}
|
||||
|
||||
type parseError struct{ err error }
|
||||
|
||||
func (e parseError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e parseError) Unwrap() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
// toLower lower cases bytes that are A-Z. strings.ToLower does too much. and
|
||||
// would replace invalid bytes with unicode replacement characters, which would
|
||||
// break our requirement that offsets into the original and upper case strings
|
||||
// point to the same character.
|
||||
func toLower(s string) string {
|
||||
r := []byte(s)
|
||||
for i, c := range r {
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
r[i] = c + 0x20
|
||||
}
|
||||
}
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func newParser(buf []byte) *parser {
|
||||
s := string(buf)
|
||||
return &parser{s, toLower(s), 0}
|
||||
}
|
||||
|
||||
// Turn panics of parseError into a descriptive ErrInvalidEncoding. Called with
|
||||
// defer by functions that parse.
|
||||
func (p *parser) recover(rerr *error) {
|
||||
x := recover()
|
||||
if x == nil {
|
||||
return
|
||||
}
|
||||
err, ok := x.(error)
|
||||
if !ok {
|
||||
panic(x)
|
||||
}
|
||||
var xerr Error
|
||||
if errors.As(err, &xerr) {
|
||||
*rerr = err
|
||||
return
|
||||
}
|
||||
*rerr = fmt.Errorf("%w: %s", ErrInvalidEncoding, err)
|
||||
}
|
||||
|
||||
func (p *parser) xerrorf(format string, args ...any) {
|
||||
panic(parseError{fmt.Errorf(format, args...)})
|
||||
}
|
||||
|
||||
func (p *parser) xcheckf(err error, format string, args ...any) {
|
||||
if err != nil {
|
||||
panic(parseError{fmt.Errorf("%s: %w", fmt.Sprintf(format, args...), err)})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xempty() {
|
||||
if p.o != len(p.s) {
|
||||
p.xerrorf("leftover data")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xnonempty() {
|
||||
if p.o >= len(p.s) {
|
||||
p.xerrorf("unexpected end")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xbyte() byte {
|
||||
p.xnonempty()
|
||||
c := p.lower[p.o]
|
||||
p.o++
|
||||
return c
|
||||
}
|
||||
|
||||
func (p *parser) peek(s string) bool {
|
||||
return strings.HasPrefix(p.lower[p.o:], s)
|
||||
}
|
||||
|
||||
func (p *parser) take(s string) bool {
|
||||
if p.peek(s) {
|
||||
p.o += len(s)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *parser) xtake(s string) {
|
||||
if !p.take(s) {
|
||||
p.xerrorf("expected %q", s)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *parser) xauthzid() string {
|
||||
p.xtake("a=")
|
||||
return p.xsaslname()
|
||||
}
|
||||
|
||||
func (p *parser) xusername() string {
|
||||
p.xtake("n=")
|
||||
return p.xsaslname()
|
||||
}
|
||||
|
||||
func (p *parser) xnonce() string {
|
||||
p.xtake("r=")
|
||||
o := p.o
|
||||
for ; o < len(p.s); o++ {
|
||||
c := p.s[o]
|
||||
if c <= ' ' || c >= 0x7f || c == ',' {
|
||||
break
|
||||
}
|
||||
}
|
||||
if o == p.o {
|
||||
p.xerrorf("empty nonce")
|
||||
}
|
||||
r := p.s[p.o:o]
|
||||
p.o = o
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xattrval() {
|
||||
c := p.xbyte()
|
||||
if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z') {
|
||||
p.xerrorf("expected alpha for attr-val")
|
||||
}
|
||||
p.xtake("=")
|
||||
p.xvalue()
|
||||
}
|
||||
|
||||
func (p *parser) xvalue() string {
|
||||
for o, c := range p.s[p.o:] {
|
||||
if c == 0 || c == ',' {
|
||||
if o == 0 {
|
||||
p.xerrorf("invalid empty value")
|
||||
}
|
||||
r := p.s[p.o : p.o+o]
|
||||
p.o = o
|
||||
return r
|
||||
}
|
||||
}
|
||||
p.xnonempty()
|
||||
r := p.s[p.o:]
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xbase64() []byte {
|
||||
o := p.o
|
||||
for ; o < len(p.s); o++ {
|
||||
c := p.s[o]
|
||||
if !(c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '/' || c == '+' || c == '=') {
|
||||
break
|
||||
}
|
||||
}
|
||||
buf, err := base64.StdEncoding.DecodeString(p.s[p.o:o])
|
||||
p.xcheckf(err, "decoding base64")
|
||||
p.o = o
|
||||
return buf
|
||||
}
|
||||
|
||||
func (p *parser) xsaslname() string {
|
||||
var esc string
|
||||
var is bool
|
||||
var r string
|
||||
for o, c := range p.s[p.o:] {
|
||||
if c == 0 || c == ',' {
|
||||
if is {
|
||||
p.xerrorf("saslname unexpected end")
|
||||
}
|
||||
if o == 0 {
|
||||
p.xerrorf("saslname cannot be empty")
|
||||
}
|
||||
p.o += o
|
||||
return r
|
||||
}
|
||||
if is {
|
||||
esc += string(c)
|
||||
if len(esc) < 2 {
|
||||
continue
|
||||
}
|
||||
switch esc {
|
||||
case "2c", "2C":
|
||||
r += ","
|
||||
case "3d", "3D":
|
||||
r += "="
|
||||
default:
|
||||
p.xerrorf("bad escape %q in saslanem", esc)
|
||||
}
|
||||
is = false
|
||||
esc = ""
|
||||
continue
|
||||
} else if c == '=' {
|
||||
is = true
|
||||
continue
|
||||
}
|
||||
r += string(c)
|
||||
}
|
||||
if is {
|
||||
p.xerrorf("saslname unexpected end")
|
||||
}
|
||||
if r == "" {
|
||||
p.xerrorf("saslname cannot be empty")
|
||||
}
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xchannelBinding() string {
|
||||
p.xtake("c=")
|
||||
return string(p.xbase64())
|
||||
}
|
||||
|
||||
func (p *parser) xproof() []byte {
|
||||
p.xtake("p=")
|
||||
return p.xbase64()
|
||||
}
|
||||
|
||||
func (p *parser) xsalt() []byte {
|
||||
p.xtake("s=")
|
||||
return p.xbase64()
|
||||
}
|
||||
|
||||
func (p *parser) xtakefn1(fn func(rune, int) bool) string {
|
||||
for o, c := range p.s[p.o:] {
|
||||
if !fn(c, o) {
|
||||
if o == 0 {
|
||||
p.xerrorf("non-empty match required")
|
||||
}
|
||||
r := p.s[p.o : p.o+o]
|
||||
p.o += o
|
||||
return r
|
||||
}
|
||||
}
|
||||
p.xnonempty()
|
||||
r := p.s[p.o:]
|
||||
p.o = len(p.s)
|
||||
return r
|
||||
}
|
||||
|
||||
func (p *parser) xiterations() int {
|
||||
p.xtake("i=")
|
||||
digits := p.xtakefn1(func(c rune, i int) bool {
|
||||
return c >= '1' && c <= '9' || i > 0 && c == '0'
|
||||
})
|
||||
v, err := strconv.ParseInt(digits, 10, 32)
|
||||
p.xcheckf(err, "parsing int")
|
||||
return int(v)
|
||||
}
|
368
scram/scram.go
Normal file
368
scram/scram.go
Normal file
@ -0,0 +1,368 @@
|
||||
// Package scram implements the SCRAM-SHA256 SASL authentication mechanism, RFC 7677.
|
||||
//
|
||||
// SCRAM-SHA256 allows a client to authenticate to a server using a password
|
||||
// without handing plaintext password over to the server. The client also
|
||||
// verifies the server knows (a derivative of) the password.
|
||||
package scram
|
||||
|
||||
// todo: test with messages that contains extensions
|
||||
// todo: some tests for the parser
|
||||
// todo: figure out how invalid parameters etc should be handled. just abort? perhaps mostly a problem for imap.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
)
|
||||
|
||||
// Errors at scram protocol level. Can be exchanged between client and server.
|
||||
var (
|
||||
ErrInvalidEncoding Error = "invalid-encoding"
|
||||
ErrExtensionsNotSupported Error = "extensions-not-supported"
|
||||
ErrInvalidProof Error = "invalid-proof"
|
||||
ErrChannelBindingsDontMatch Error = "channel-bindings-dont-match"
|
||||
ErrServerDoesSupportChannelBinding Error = "server-does-support-channel-binding"
|
||||
ErrChannelBindingNotSupported Error = "channel-binding-not-supported"
|
||||
ErrUnsupportedChannelBindingType Error = "unsupported-channel-binding-type"
|
||||
ErrUnknownUser Error = "unknown-user"
|
||||
ErrNoResources Error = "no-resources"
|
||||
ErrOtherError Error = "other-error"
|
||||
)
|
||||
|
||||
var scramErrors = makeErrors()
|
||||
|
||||
func makeErrors() map[string]Error {
|
||||
l := []Error{
|
||||
ErrInvalidEncoding,
|
||||
ErrExtensionsNotSupported,
|
||||
ErrInvalidProof,
|
||||
ErrChannelBindingsDontMatch,
|
||||
ErrServerDoesSupportChannelBinding,
|
||||
ErrChannelBindingNotSupported,
|
||||
ErrUnsupportedChannelBindingType,
|
||||
ErrUnknownUser,
|
||||
ErrNoResources,
|
||||
ErrOtherError,
|
||||
}
|
||||
m := map[string]Error{}
|
||||
for _, e := range l {
|
||||
m[string(e)] = e
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
var (
|
||||
ErrNorm = errors.New("parameter not unicode normalized") // E.g. if client sends non-normalized username or authzid.
|
||||
ErrUnsafe = errors.New("unsafe parameter") // E.g. salt, nonce too short, or too few iterations.
|
||||
ErrProtocol = errors.New("protocol error") // E.g. server responded with a nonce not prefixed by the client nonce.
|
||||
)
|
||||
|
||||
type Error string
|
||||
|
||||
func (e Error) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// MakeRandom returns a cryptographically random buffer for use as salt or as
|
||||
// nonce.
|
||||
func MakeRandom() []byte {
|
||||
buf := make([]byte, 12)
|
||||
_, err := cryptorand.Read(buf)
|
||||
if err != nil {
|
||||
panic("generate random")
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
// SaltPassword returns a salted password.
|
||||
func SaltPassword(password string, salt []byte, iterations int) []byte {
|
||||
password = norm.NFC.String(password)
|
||||
return pbkdf2.Key([]byte(password), salt, iterations, sha256.Size, sha256.New)
|
||||
}
|
||||
|
||||
// HMAC returns the hmac with key over msg.
|
||||
func HMAC(key []byte, msg string) []byte {
|
||||
mac := hmac.New(sha256.New, key)
|
||||
mac.Write([]byte(msg))
|
||||
return mac.Sum(nil)
|
||||
}
|
||||
|
||||
func xor(a, b []byte) {
|
||||
for i := range a {
|
||||
a[i] ^= b[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the server-side of a SCRAM-SHA-256 authentication.
|
||||
type Server struct {
|
||||
Authentication string // Username for authentication, "authc". Always set and non-empty.
|
||||
Authorization string // If set, role of user to assume after authentication, "authz".
|
||||
|
||||
// Messages used in hash calculations.
|
||||
clientFirstBare string
|
||||
serverFirst string
|
||||
clientFinalWithoutProof string
|
||||
|
||||
gs2header string
|
||||
clientNonce string // Client-part of the nonce.
|
||||
serverNonceOverride string // If set, server does not generate random nonce, but uses this. For tests with the test vector.
|
||||
nonce string // Full client + server nonce.
|
||||
}
|
||||
|
||||
// NewServer returns a server given the first SCRAM message from a client.
|
||||
//
|
||||
// The sequence for data and calls on a server:
|
||||
//
|
||||
// - Read initial data from client, call NewServer (this call), then ServerFirst and write to the client.
|
||||
// - Read response from client, call Finish or FinishFinal and write the resulting string.
|
||||
func NewServer(clientFirst []byte) (server *Server, rerr error) {
|
||||
p := newParser(clientFirst)
|
||||
defer p.recover(&rerr)
|
||||
|
||||
server = &Server{}
|
||||
|
||||
// ../rfc/5802:949 ../rfc/5802:910
|
||||
gs2cbindFlag := p.xbyte()
|
||||
switch gs2cbindFlag {
|
||||
case 'n', 'y':
|
||||
case 'p':
|
||||
p.xerrorf("gs2 header with p: %w", ErrChannelBindingNotSupported)
|
||||
}
|
||||
p.xtake(",")
|
||||
if !p.take(",") {
|
||||
server.Authorization = p.xauthzid()
|
||||
if norm.NFC.String(server.Authorization) != server.Authorization {
|
||||
return nil, fmt.Errorf("%w: authzid", ErrNorm)
|
||||
}
|
||||
p.xtake(",")
|
||||
}
|
||||
server.gs2header = p.s[:p.o]
|
||||
server.clientFirstBare = p.s[p.o:]
|
||||
|
||||
// ../rfc/5802:945
|
||||
if p.take("m=") {
|
||||
p.xerrorf("unexpected mandatory extension: %w", ErrExtensionsNotSupported)
|
||||
}
|
||||
server.Authentication = p.xusername()
|
||||
if norm.NFC.String(server.Authentication) != server.Authentication {
|
||||
return nil, fmt.Errorf("%w: username", ErrNorm)
|
||||
}
|
||||
p.xtake(",")
|
||||
server.clientNonce = p.xnonce()
|
||||
if len(server.clientNonce) < 8 {
|
||||
return nil, fmt.Errorf("%w: client nonce too short", ErrUnsafe)
|
||||
}
|
||||
// Extensions, we don't recognize them.
|
||||
for p.take(",") {
|
||||
p.xattrval()
|
||||
}
|
||||
p.xempty()
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ServerFirst returns the string to send back to the client. To be called after NewServer.
|
||||
func (s *Server) ServerFirst(iterations int, salt []byte) (string, error) {
|
||||
// ../rfc/5802:959
|
||||
serverNonce := s.serverNonceOverride
|
||||
if serverNonce == "" {
|
||||
serverNonce = base64.StdEncoding.EncodeToString(MakeRandom())
|
||||
}
|
||||
s.nonce = s.clientNonce + serverNonce
|
||||
s.serverFirst = fmt.Sprintf("r=%s,s=%s,i=%d", s.nonce, base64.StdEncoding.EncodeToString(salt), iterations)
|
||||
return s.serverFirst, nil
|
||||
}
|
||||
|
||||
// Finish takes the final client message, and the salted password (probably
|
||||
// from server storage), verifies the client, and returns a message to return
|
||||
// to the client. If err is nil, authentication was successful. If the
|
||||
// authorization requested is not acceptable, the server should call
|
||||
// FinishError instead.
|
||||
func (s *Server) Finish(clientFinal []byte, saltedPassword []byte) (serverFinal string, rerr error) {
|
||||
p := newParser(clientFinal)
|
||||
defer p.recover(&rerr)
|
||||
|
||||
cbind := p.xchannelBinding()
|
||||
if cbind != s.gs2header {
|
||||
return "e=" + string(ErrChannelBindingsDontMatch), ErrChannelBindingsDontMatch
|
||||
}
|
||||
p.xtake(",")
|
||||
nonce := p.xnonce()
|
||||
if nonce != s.nonce {
|
||||
return "e=" + string(ErrInvalidProof), ErrInvalidProof
|
||||
}
|
||||
for !p.peek(",p=") {
|
||||
p.xtake(",")
|
||||
p.xattrval() // Ignored.
|
||||
}
|
||||
s.clientFinalWithoutProof = p.s[:p.o]
|
||||
p.xtake(",")
|
||||
proof := p.xproof()
|
||||
p.xempty()
|
||||
|
||||
msg := s.clientFirstBare + "," + s.serverFirst + "," + s.clientFinalWithoutProof
|
||||
|
||||
clientKey := HMAC(saltedPassword, "Client Key")
|
||||
storedKey0 := sha256.Sum256(clientKey)
|
||||
storedKey := storedKey0[:]
|
||||
|
||||
clientSig := HMAC(storedKey, msg)
|
||||
xor(clientSig, clientKey) // Now clientProof.
|
||||
if !bytes.Equal(clientSig, proof) {
|
||||
return "e=" + string(ErrInvalidProof), ErrInvalidProof
|
||||
}
|
||||
|
||||
serverKey := HMAC(saltedPassword, "Server Key")
|
||||
serverSig := HMAC(serverKey, msg)
|
||||
return fmt.Sprintf("v=%s", base64.StdEncoding.EncodeToString(serverSig)), nil
|
||||
}
|
||||
|
||||
// FinishError returns an error message to write to the client for the final
|
||||
// server message.
|
||||
func (s *Server) FinishError(err Error) string {
|
||||
return "e=" + string(err)
|
||||
}
|
||||
|
||||
// Client represents the client-side of a SCRAM-SHA-256 authentication.
|
||||
type Client struct {
|
||||
authc string
|
||||
authz string
|
||||
|
||||
// Messages used in hash calculations.
|
||||
clientFirstBare string
|
||||
serverFirst string
|
||||
clientFinalWithoutProof string
|
||||
authMessage string
|
||||
|
||||
gs2header string
|
||||
clientNonce string
|
||||
nonce string // Full client + server nonce.
|
||||
saltedPassword []byte
|
||||
}
|
||||
|
||||
// NewClient returns a client for authentication authc, optionally for
|
||||
// authorization with role authz.
|
||||
//
|
||||
// The sequence for data and calls on a client:
|
||||
//
|
||||
// - ClientFirst, write result to server.
|
||||
// - Read response from server, feed to ServerFirst, write response to server.
|
||||
// - Read response from server, feed to ServerFinal.
|
||||
func NewClient(authc, authz string) *Client {
|
||||
authc = norm.NFC.String(authc)
|
||||
authz = norm.NFC.String(authz)
|
||||
return &Client{authc: authc, authz: authz}
|
||||
}
|
||||
|
||||
// ClientFirst returns the first client message to write to the server.
|
||||
// No channel binding is done/supported.
|
||||
// A random nonce is generated.
|
||||
func (c *Client) ClientFirst() (clientFirst string, rerr error) {
|
||||
c.gs2header = fmt.Sprintf("n,%s,", saslname(c.authz))
|
||||
if c.clientNonce == "" {
|
||||
c.clientNonce = base64.StdEncoding.EncodeToString(MakeRandom())
|
||||
}
|
||||
c.clientFirstBare = fmt.Sprintf("n=%s,r=%s", saslname(c.authc), c.clientNonce)
|
||||
return c.gs2header + c.clientFirstBare, nil
|
||||
}
|
||||
|
||||
// ServerFirst processes the first response message from the server. The
|
||||
// provided nonce, salt and iterations are checked. If valid, a final client
|
||||
// message is calculated and returned. This message must be written to the
|
||||
// server. It includes proof that the client knows the password.
|
||||
func (c *Client) ServerFirst(serverFirst []byte, password string) (clientFinal string, rerr error) {
|
||||
c.serverFirst = string(serverFirst)
|
||||
p := newParser(serverFirst)
|
||||
defer p.recover(&rerr)
|
||||
|
||||
// ../rfc/5802:959
|
||||
if p.take("m=") {
|
||||
p.xerrorf("unsupported mandatory extension: %w", ErrExtensionsNotSupported)
|
||||
}
|
||||
|
||||
c.nonce = p.xnonce()
|
||||
p.xtake(",")
|
||||
salt := p.xsalt()
|
||||
p.xtake(",")
|
||||
iterations := p.xiterations()
|
||||
// We ignore extensions that we don't know about.
|
||||
for p.take(",") {
|
||||
p.xattrval()
|
||||
}
|
||||
p.xempty()
|
||||
|
||||
if !strings.HasPrefix(c.nonce, c.clientNonce) {
|
||||
return "", fmt.Errorf("%w: server dropped our nonce", ErrProtocol)
|
||||
}
|
||||
if len(c.nonce)-len(c.clientNonce) < 8 {
|
||||
return "", fmt.Errorf("%w: server nonce too short", ErrUnsafe)
|
||||
}
|
||||
if len(salt) < 8 {
|
||||
return "", fmt.Errorf("%w: salt too short", ErrUnsafe)
|
||||
}
|
||||
if iterations < 2048 {
|
||||
return "", fmt.Errorf("%w: too few iterations", ErrUnsafe)
|
||||
}
|
||||
|
||||
c.clientFinalWithoutProof = fmt.Sprintf("c=%s,r=%s", base64.StdEncoding.EncodeToString([]byte(c.gs2header)), c.nonce)
|
||||
|
||||
c.authMessage = c.clientFirstBare + "," + c.serverFirst + "," + c.clientFinalWithoutProof
|
||||
|
||||
c.saltedPassword = SaltPassword(password, salt, iterations)
|
||||
clientKey := HMAC(c.saltedPassword, "Client Key")
|
||||
storedKey0 := sha256.Sum256(clientKey)
|
||||
storedKey := storedKey0[:]
|
||||
clientSig := HMAC(storedKey, c.authMessage)
|
||||
xor(clientSig, clientKey) // Now clientProof.
|
||||
clientProof := clientSig
|
||||
|
||||
r := c.clientFinalWithoutProof + ",p=" + base64.StdEncoding.EncodeToString(clientProof)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// ServerFinal processes the final message from the server, verifying that the
|
||||
// server knows the password.
|
||||
func (c *Client) ServerFinal(serverFinal []byte) (rerr error) {
|
||||
p := newParser(serverFinal)
|
||||
defer p.recover(&rerr)
|
||||
|
||||
if p.take("e=") {
|
||||
errstr := p.xvalue()
|
||||
var err error = scramErrors[errstr]
|
||||
if err == Error("") {
|
||||
err = errors.New(errstr)
|
||||
}
|
||||
return fmt.Errorf("error from server: %w", err)
|
||||
}
|
||||
p.xtake("v=")
|
||||
verifier := p.xbase64()
|
||||
|
||||
serverKey := HMAC(c.saltedPassword, "Server Key")
|
||||
serverSig := HMAC(serverKey, c.authMessage)
|
||||
if !bytes.Equal(verifier, serverSig) {
|
||||
return fmt.Errorf("incorrect server signature")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert "," to =2C and "=" to =3D.
|
||||
func saslname(s string) string {
|
||||
var r string
|
||||
for _, c := range s {
|
||||
if c == ',' {
|
||||
r += "=2C"
|
||||
} else if c == '=' {
|
||||
r += "=3D"
|
||||
} else {
|
||||
r += string(c)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
169
scram/scram_test.go
Normal file
169
scram/scram_test.go
Normal file
@ -0,0 +1,169 @@
|
||||
package scram
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func base64Decode(s string) []byte {
|
||||
buf, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
panic("bad base64")
|
||||
}
|
||||
return buf
|
||||
}
|
||||
|
||||
func tcheck(t *testing.T, err error, msg string) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatalf("%s: %s", msg, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScramServer(t *testing.T) {
|
||||
// Test vector from ../rfc/7677:122
|
||||
salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
|
||||
saltedPassword := SaltPassword("pencil", salt, 4096)
|
||||
|
||||
server, err := NewServer([]byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"))
|
||||
server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
|
||||
tcheck(t, err, "newserver")
|
||||
resp, err := server.ServerFirst(4096, salt)
|
||||
tcheck(t, err, "server first")
|
||||
if resp != "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096" {
|
||||
t.Fatalf("bad server first")
|
||||
}
|
||||
serverFinal, err := server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
|
||||
tcheck(t, err, "finish")
|
||||
if serverFinal != "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4=" {
|
||||
t.Fatalf("bad server final")
|
||||
}
|
||||
}
|
||||
|
||||
// Bad attempt with wrong password.
|
||||
func TestScramServerBadPassword(t *testing.T) {
|
||||
salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
|
||||
saltedPassword := SaltPassword("marker", salt, 4096)
|
||||
|
||||
server, err := NewServer([]byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"))
|
||||
server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
|
||||
tcheck(t, err, "newserver")
|
||||
_, err = server.ServerFirst(4096, salt)
|
||||
tcheck(t, err, "server first")
|
||||
_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
|
||||
if !errors.Is(err, ErrInvalidProof) {
|
||||
t.Fatalf("got %v, expected ErrInvalidProof", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Bad attempt with different number of rounds.
|
||||
func TestScramServerBadIterations(t *testing.T) {
|
||||
salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
|
||||
saltedPassword := SaltPassword("pencil", salt, 2048)
|
||||
|
||||
server, err := NewServer([]byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"))
|
||||
server.serverNonceOverride = "%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0"
|
||||
tcheck(t, err, "newserver")
|
||||
_, err = server.ServerFirst(4096, salt)
|
||||
tcheck(t, err, "server first")
|
||||
_, err = server.Finish([]byte("c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
|
||||
if !errors.Is(err, ErrInvalidProof) {
|
||||
t.Fatalf("got %v, expected ErrInvalidProof", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Another attempt but with a randomly different nonce.
|
||||
func TestScramServerBad(t *testing.T) {
|
||||
salt := base64Decode("W22ZaJ0SNY7soEsUEjb6gQ==")
|
||||
saltedPassword := SaltPassword("pencil", salt, 4096)
|
||||
|
||||
server, err := NewServer([]byte("n,,n=user,r=rOprNGfwEbeRWgbNEkqO"))
|
||||
tcheck(t, err, "newserver")
|
||||
_, err = server.ServerFirst(4096, salt)
|
||||
tcheck(t, err, "server first")
|
||||
_, err = server.Finish([]byte("c=biws,r="+server.nonce+",p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="), saltedPassword)
|
||||
if !errors.Is(err, ErrInvalidProof) {
|
||||
t.Fatalf("got %v, expected ErrInvalidProof", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScramClient(t *testing.T) {
|
||||
c := NewClient("user", "")
|
||||
c.clientNonce = "rOprNGfwEbeRWgbNEkqO"
|
||||
clientFirst, err := c.ClientFirst()
|
||||
tcheck(t, err, "ClientFirst")
|
||||
if clientFirst != "n,,n=user,r=rOprNGfwEbeRWgbNEkqO" {
|
||||
t.Fatalf("bad clientFirst")
|
||||
}
|
||||
clientFinal, err := c.ServerFirst([]byte("r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"), "pencil")
|
||||
tcheck(t, err, "ServerFirst")
|
||||
if clientFinal != "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ=" {
|
||||
t.Fatalf("bad clientFinal")
|
||||
}
|
||||
err = c.ServerFinal([]byte("v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="))
|
||||
tcheck(t, err, "ServerFinal")
|
||||
}
|
||||
|
||||
func TestScram(t *testing.T) {
|
||||
run := func(expErr error, username, authzid, password string, iterations int, clientNonce, serverNonce string) {
|
||||
t.Helper()
|
||||
|
||||
defer func() {
|
||||
x := recover()
|
||||
if x == nil || x == "" {
|
||||
return
|
||||
}
|
||||
panic(x)
|
||||
}()
|
||||
|
||||
// check err is either nil or the expected error. if the expected error, panic to abort the authentication session.
|
||||
xerr := func(err error, msg string) {
|
||||
t.Helper()
|
||||
if err != nil && !errors.Is(err, expErr) {
|
||||
t.Fatalf("%s: got %v, expected %v", msg, err, expErr)
|
||||
}
|
||||
if err != nil {
|
||||
panic("") // Abort test.
|
||||
}
|
||||
}
|
||||
|
||||
salt := MakeRandom()
|
||||
saltedPassword := SaltPassword(password, salt, iterations)
|
||||
|
||||
client := NewClient(username, "")
|
||||
client.clientNonce = clientNonce
|
||||
clientFirst, err := client.ClientFirst()
|
||||
xerr(err, "client.ClientFirst")
|
||||
|
||||
server, err := NewServer([]byte(clientFirst))
|
||||
xerr(err, "NewServer")
|
||||
server.serverNonceOverride = serverNonce
|
||||
|
||||
serverFirst, err := server.ServerFirst(iterations, salt)
|
||||
xerr(err, "server.ServerFirst")
|
||||
|
||||
clientFinal, err := client.ServerFirst([]byte(serverFirst), password)
|
||||
xerr(err, "client.ServerFirst")
|
||||
|
||||
serverFinal, err := server.Finish([]byte(clientFinal), saltedPassword)
|
||||
xerr(err, "server.Finish")
|
||||
|
||||
err = client.ServerFinal([]byte(serverFinal))
|
||||
xerr(err, "client.ServerFinal")
|
||||
|
||||
if expErr != nil {
|
||||
t.Fatalf("got no error, expected %v", expErr)
|
||||
}
|
||||
}
|
||||
|
||||
run(nil, "user", "", "pencil", 4096, "", "")
|
||||
run(nil, "mjl@mox.example", "", "testtest", 4096, "", "")
|
||||
run(nil, "mjl@mox.example", "", "short", 4096, "", "")
|
||||
run(nil, "mjl@mox.example", "", "short", 2048, "", "")
|
||||
run(nil, "mjl@mox.example", "mjl@mox.example", "testtest", 4096, "", "")
|
||||
run(nil, "mjl@mox.example", "other@mox.example", "testtest", 4096, "", "")
|
||||
run(ErrUnsafe, "user", "", "pencil", 1, "", "") // Few iterations.
|
||||
run(ErrUnsafe, "user", "", "pencil", 2048, "short", "") // Short client nonce.
|
||||
run(ErrUnsafe, "user", "", "pencil", 2048, "test1234", "test") // Server added too few random data.
|
||||
}
|
Reference in New Issue
Block a user