imapserver: Prevent spurious unhandled panics for connections with compress=deflate that break

Writing to a connection goes through the flate library to compress. That writes
the compressed bytes to the underlying connection. But that underlying
connection is wrapped to raise a panic with an i/o error instead of returning a
normal error. Jumping out of flate leaves the internal state of the compressor
in an undefined state. So far so good. But as part of cleaning up the connection,
we could try to flush output again. Specifically: If we were writing user data,
we had switched from tracing of protocol data to tracing of user data, and we
registered a defer that restored the tracing kind and flushed (to ensure data
was traced at the right level). That flush would cause a write into the
compressor again, which could panic with an out of bounds slice access due to
its inconsistent internal state.

This fix prevents that compressor panic in two ways:

1. We wrap the flate.Writer with a moxio.FlateWriter that keeps track of
   whether a panic came out of an operation on it. If so, any further operation
   raises the same panic. This prevents access to the inconsistent internal flate
   state entirely.
2. Once we raise an i/o error, we mark the connection as broken and that makes
   flushes a no-op.
This commit is contained in:
Mechiel Lukkien 2025-02-26 10:50:04 +01:00
parent ea55c85938
commit 17de90e29d
No known key found for this signature in database
7 changed files with 157 additions and 26 deletions

View File

@ -22,8 +22,6 @@ import (
"reflect" "reflect"
"strings" "strings"
"github.com/mjl-/flate"
"github.com/mjl-/mox/mlog" "github.com/mjl-/mox/mlog"
"github.com/mjl-/mox/moxio" "github.com/mjl-/mox/moxio"
) )
@ -34,10 +32,11 @@ type Conn struct {
// writes through c.bw. It wraps a tracing reading/writer and may wrap flate // writes through c.bw. It wraps a tracing reading/writer and may wrap flate
// compression. // compression.
conn net.Conn conn net.Conn
connBroken bool // If connection is broken, we won't flush (and write) again.
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
compress bool // If compression is enabled, we must flush flateWriter and its target original bufio writer. compress bool // If compression is enabled, we must flush flateWriter and its target original bufio writer.
flateWriter *flate.Writer flateWriter *moxio.FlateWriter
flateBW *bufio.Writer flateBW *bufio.Writer
log mlog.Log log mlog.Log
@ -146,11 +145,19 @@ func (c *Conn) Write(buf []byte) (n int, rerr error) {
defer c.recover(&rerr) defer c.recover(&rerr)
n, rerr = c.conn.Write(buf) n, rerr = c.conn.Write(buf)
if rerr != nil {
c.connBroken = true
}
c.xcheckf(rerr, "write") c.xcheckf(rerr, "write")
return n, nil return n, nil
} }
func (c *Conn) xflush() { func (c *Conn) xflush() {
// Not writing any more when connection is broken.
if c.connBroken {
return
}
err := c.bw.Flush() err := c.bw.Flush()
c.xcheckf(err, "flush") c.xcheckf(err, "flush")
@ -173,7 +180,7 @@ func (c *Conn) Close() (rerr error) {
if c.conn == nil { if c.conn == nil {
return nil return nil
} }
if c.flateWriter != nil { if !c.connBroken && c.flateWriter != nil {
err := c.flateWriter.Close() err := c.flateWriter.Close()
c.xcheckf(err, "close deflate writer") c.xcheckf(err, "close deflate writer")
err = c.flateBW.Flush() err = c.flateBW.Flush()

View File

@ -140,8 +140,9 @@ func (c *Conn) CompressDeflate() (untagged []Untagged, result Result, rerr error
c.xcheck(rerr) c.xcheck(rerr)
c.flateBW = bufio.NewWriter(c) c.flateBW = bufio.NewWriter(c)
fw, err := flate.NewWriter(c.flateBW, flate.DefaultCompression) fw0, err := flate.NewWriter(c.flateBW, flate.DefaultCompression)
c.xcheckf(err, "deflate") // Cannot happen. c.xcheckf(err, "deflate") // Cannot happen.
fw := moxio.NewFlateWriter(fw0)
c.compress = true c.compress = true
c.flateWriter = fw c.flateWriter = fw

View File

@ -2,7 +2,11 @@ package imapserver
import ( import (
"crypto/tls" "crypto/tls"
"encoding/base64"
"io"
mathrand "math/rand/v2"
"testing" "testing"
"time"
) )
func TestCompress(t *testing.T) { func TestCompress(t *testing.T) {
@ -37,3 +41,42 @@ func TestCompressStartTLS(t *testing.T) {
tc.transactf("ok", "noop") tc.transactf("ok", "noop")
tc.transactf("ok", "fetch 1 body.peek[1]") tc.transactf("ok", "fetch 1 body.peek[1]")
} }
// TestCompressBreak exercises a client that disappears while the server is
// writing a large response over a connection with compress=deflate enabled.
func TestCompressBreak(t *testing.T) {
// Close the client connection when the server is writing. That causes writes in
// the server to fail (panic), jumping out of the flate writer and leaving its
// state inconsistent. We must not call into the flate writer again because due to
// its broken internal state it may cause array out of bounds accesses.
tc := start(t)
defer tc.close()
msg := exampleMsg
// Add random data (so it is not compressible). Don't know why, but only
// reproducible with large writes. As if setting socket buffers had no effect.
buf := make([]byte, 64*1024)
// The ChaCha8 source is seeded with an all-zero seed, so the data is the same on
// every run.
_, err := io.ReadFull(mathrand.NewChaCha8([32]byte{}), buf)
tcheck(t, err, "read random")
text := base64.StdEncoding.EncodeToString(buf)
// Append the base64 text to the message as CRLF-terminated lines of at most 78
// characters.
for len(text) > 0 {
n := min(78, len(text))
msg += text[:n] + "\r\n"
text = text[n:]
}
tc.client.Login("mjl@mox.example", password0)
tc.client.CompressDeflate()
tc.client.Select("inbox")
tc.transactf("ok", "append inbox (\\seen) {%d+}\r\n%s", len(msg), msg)
tc.transactf("ok", "noop")
// Write request. Close connection instead of reading data. Write will panic,
// coming through flate writer leaving its state inconsistent. Server must not try
// to Flush/Write again on flate writer or it may panic.
tc.client.Writelinef("x fetch 1 body.peek[1]")
// Close client connection and prevent cleanup from closing the client again.
// NOTE(review): the sleep presumably gives the server time to start writing the
// response before the connection goes away - confirm.
time.Sleep(time.Second / 10)
tc.client = nil
tc.conn.Close() // Simulate client disappearing.
}

View File

@ -211,10 +211,10 @@ func (c *conn) cmdxReplace(isUID bool, tag, cmd string, p *parser) {
c.xtrace(mlog.LevelTrace) // Restore. c.xtrace(mlog.LevelTrace) // Restore.
if err != nil { if err != nil {
// Cannot use xcheckf due to %w handling of errIO. // Cannot use xcheckf due to %w handling of errIO.
xserverErrorf("reading literal message: %s (%w)", err, errIO) c.xbrokenf("reading literal message: %s (%w)", err, errIO)
} }
if msize != size { if msize != size {
xserverErrorf("read %d bytes for message, expected %d (%w)", msize, size, errIO) c.xbrokenf("read %d bytes for message, expected %d (%w)", msize, size, errIO)
} }
// Finish reading the command. // Finish reading the command.

View File

@ -56,6 +56,7 @@ import (
"sort" "sort"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@ -102,6 +103,8 @@ var (
) )
) )
var unhandledPanics atomic.Int64 // For tests.
var limiterConnectionrate, limiterConnections *ratelimit.Limiter var limiterConnectionrate, limiterConnections *ratelimit.Limiter
func init() { func init() {
@ -177,6 +180,7 @@ type conn struct {
cid int64 cid int64
state state state state
conn net.Conn conn net.Conn
connBroken bool // Once broken, we won't flush any more data.
tls bool // Whether TLS has been initialized. tls bool // Whether TLS has been initialized.
viaHTTPS bool // Whether this connection came in via HTTPS (using TLS ALPN). viaHTTPS bool // Whether this connection came in via HTTPS (using TLS ALPN).
br *bufio.Reader // From remote, with TLS unwrapped in case of TLS, and possibly wrapping inflate. br *bufio.Reader // From remote, with TLS unwrapped in case of TLS, and possibly wrapping inflate.
@ -197,7 +201,7 @@ type conn struct {
log mlog.Log // Used for all synchronous logging on this connection, see logbg for logging in a separate goroutine. log mlog.Log // Used for all synchronous logging on this connection, see logbg for logging in a separate goroutine.
enabled map[capability]bool // All upper-case. enabled map[capability]bool // All upper-case.
compress bool // Whether compression is enabled, via compress command. compress bool // Whether compression is enabled, via compress command.
flateWriter *flate.Writer // For flushing output after flushing conn.bw, and for closing. flateWriter *moxio.FlateWriter // For flushing output after flushing conn.bw, and for closing.
flateBW *bufio.Writer // Wraps raw connection writes, flateWriter writes here, also needs flushing. flateBW *bufio.Writer // Wraps raw connection writes, flateWriter writes here, also needs flushing.
// Set by SEARCH with SAVE. Can be used by commands accepting a sequence-set with // Set by SEARCH with SAVE. Can be used by commands accepting a sequence-set with
@ -343,6 +347,11 @@ func (c *conn) xsanity(err error, format string, args ...any) {
c.log.Errorx(fmt.Sprintf(format, args...), err) c.log.Errorx(fmt.Sprintf(format, args...), err)
} }
// xbrokenf marks the connection as broken and then panics with an error built
// from format and args by fmt.Errorf, so a %w verb wraps its operand (e.g.
// errIO). The connection is marked before the panic is raised, so code that
// checks connBroken during unwinding already sees the broken state.
func (c *conn) xbrokenf(format string, args ...any) {
	c.connBroken = true
	err := fmt.Errorf(format, args...)
	panic(err)
}
type msgseq uint32 type msgseq uint32
// Listen initializes all imap listeners for the configuration, and stores them for Serve to start them. // Listen initializes all imap listeners for the configuration, and stores them for Serve to start them.
@ -499,7 +508,7 @@ func (c *conn) Write(buf []byte) (int, error) {
nn, err := c.conn.Write(buf[:chunk]) nn, err := c.conn.Write(buf[:chunk])
if err != nil { if err != nil {
panic(fmt.Errorf("write: %s (%w)", err, errIO)) c.xbrokenf("write: %s (%w)", err, errIO)
} }
n += nn n += nn
buf = buf[chunk:] buf = buf[chunk:]
@ -542,6 +551,7 @@ func (c *conn) readline0() (string, error) {
if err != nil && errors.Is(err, moxio.ErrLineTooLong) { if err != nil && errors.Is(err, moxio.ErrLineTooLong) {
return "", fmt.Errorf("%s (%w)", err, errProtocol) return "", fmt.Errorf("%s (%w)", err, errProtocol)
} else if err != nil { } else if err != nil {
c.connBroken = true
return "", fmt.Errorf("%s (%w)", err, errIO) return "", fmt.Errorf("%s (%w)", err, errIO)
} }
return line, nil return line, nil
@ -576,7 +586,7 @@ func (c *conn) readline(readCmd bool) string {
c.writelinef("* BYE inactive") c.writelinef("* BYE inactive")
} }
if !errors.Is(err, errIO) && !errors.Is(err, errProtocol) { if !errors.Is(err, errIO) && !errors.Is(err, errProtocol) {
err = fmt.Errorf("%s (%w)", err, errIO) c.xbrokenf("%s (%w)", err, errIO)
} }
panic(err) panic(err)
} }
@ -628,6 +638,11 @@ func (c *conn) bwritelinef(format string, args ...any) {
} }
func (c *conn) xflush() { func (c *conn) xflush() {
// If the connection is already broken, we're not going to write more.
if c.connBroken {
return
}
err := c.bw.Flush() err := c.bw.Flush()
xcheckf(err, "flush") // Should never happen, the Write caused by the Flush should panic on i/o error. xcheckf(err, "flush") // Should never happen, the Write caused by the Flush should panic on i/o error.
@ -668,8 +683,7 @@ func (c *conn) xreadliteral(size int64, sync bool) []byte {
_, err := io.ReadFull(c.br, buf) _, err := io.ReadFull(c.br, buf)
if err != nil { if err != nil {
// Cannot use xcheckf due to %w handling of errIO. c.xbrokenf("reading literal: %s (%w)", err, errIO)
panic(fmt.Errorf("reading literal: %s (%w)", err, errIO))
} }
} }
return buf return buf
@ -780,6 +794,7 @@ func serve(listenerName string, cid int64, tlsConfig *tls.Config, nc net.Conn, x
c.log.Error("unhandled panic", slog.Any("err", x)) c.log.Error("unhandled panic", slog.Any("err", x))
debug.PrintStack() debug.PrintStack()
metrics.PanicInc(metrics.Imapserver) metrics.PanicInc(metrics.Imapserver)
unhandledPanics.Add(1) // For tests.
} }
}() }()
@ -1067,7 +1082,7 @@ func (c *conn) xtlsHandshakeAndAuthenticate(conn net.Conn) {
defer cancel() defer cancel()
c.log.Debug("starting tls server handshake") c.log.Debug("starting tls server handshake")
if err := tlsConn.HandshakeContext(ctx); err != nil { if err := tlsConn.HandshakeContext(ctx); err != nil {
panic(fmt.Errorf("tls handshake: %s (%w)", err, errIO)) c.xbrokenf("tls handshake: %s (%w)", err, errIO)
} }
cancel() cancel()
@ -1076,8 +1091,8 @@ func (c *conn) xtlsHandshakeAndAuthenticate(conn net.Conn) {
// Verify client after session resumption. // Verify client after session resumption.
err := c.tlsClientAuthVerifyPeerCertParsed(cs.PeerCertificates[0]) err := c.tlsClientAuthVerifyPeerCertParsed(cs.PeerCertificates[0])
if err != nil { if err != nil {
c.bwritelinef("* BYE [ALERT] Error verifying client certificate after TLS session resumption: %s", err) c.writelinef("* BYE [ALERT] Error verifying client certificate after TLS session resumption: %s", err)
panic(fmt.Errorf("tls verify client certificate after resumption: %s (%w)", err, errIO)) c.xbrokenf("tls verify client certificate after resumption: %s (%w)", err, errIO)
} }
} }
@ -1162,7 +1177,7 @@ func (c *conn) command() {
// stop processing because there is a good chance whatever they sent has multiple // stop processing because there is a good chance whatever they sent has multiple
// lines. // lines.
c.writelinef("* BYE please try again speaking imap") c.writelinef("* BYE please try again speaking imap")
panic(errIO) c.xbrokenf("not speaking imap (%w)", errIO)
} }
c.log.Debugx("imap command syntax error", sxerr.err, logFields...) c.log.Debugx("imap command syntax error", sxerr.err, logFields...)
c.log.Info("imap syntax error", slog.String("lastline", c.lastLine)) c.log.Info("imap syntax error", slog.String("lastline", c.lastLine))
@ -1215,7 +1230,7 @@ func (c *conn) command() {
case <-mox.Shutdown.Done(): case <-mox.Shutdown.Done():
// ../rfc/9051:5375 // ../rfc/9051:5375
c.writelinef("* BYE shutting down") c.writelinef("* BYE shutting down")
panic(errIO) c.xbrokenf("shutting down (%w)", errIO)
default: default:
} }
@ -1851,8 +1866,9 @@ func (c *conn) cmdCompress(tag, cmd string, p *parser) {
c.ok(tag, cmd) c.ok(tag, cmd)
c.flateBW = bufio.NewWriter(c) c.flateBW = bufio.NewWriter(c)
fw, err := flate.NewWriter(c.flateBW, flate.DefaultCompression) fw0, err := flate.NewWriter(c.flateBW, flate.DefaultCompression)
xcheckf(err, "deflate") // Cannot happen. xcheckf(err, "deflate") // Cannot happen.
fw := moxio.NewFlateWriter(fw0)
c.compress = true c.compress = true
c.flateWriter = fw c.flateWriter = fw
@ -3452,10 +3468,10 @@ func (c *conn) cmdAppend(tag, cmd string, p *parser) {
c.xtrace(mlog.LevelTrace) // Restore. c.xtrace(mlog.LevelTrace) // Restore.
if err != nil { if err != nil {
// Cannot use xcheckf due to %w handling of errIO. // Cannot use xcheckf due to %w handling of errIO.
panic(fmt.Errorf("reading literal message: %s (%w)", err, errIO)) c.xbrokenf("reading literal message: %s (%w)", err, errIO)
} }
if msize != size { if msize != size {
xserverErrorf("read %d bytes for message, expected %d (%w)", msize, size, errIO) c.xbrokenf("read %d bytes for message, expected %d (%w)", msize, size, errIO)
} }
totalSize += msize totalSize += msize
@ -3610,7 +3626,7 @@ wait:
case <-mox.Shutdown.Done(): case <-mox.Shutdown.Done():
// ../rfc/9051:5375 // ../rfc/9051:5375
c.writelinef("* BYE shutting down") c.writelinef("* BYE shutting down")
panic(errIO) c.xbrokenf("shutting down (%w)", errIO)
} }
} }
@ -3621,7 +3637,7 @@ wait:
if strings.ToUpper(line) != "DONE" { if strings.ToUpper(line) != "DONE" {
// We just close the connection because our protocols are out of sync. // We just close the connection because our protocols are out of sync.
panic(fmt.Errorf("%w: in IDLE, expected DONE", errIO)) c.xbrokenf("%w: in IDLE, expected DONE", errIO)
} }
c.ok(tag, cmd) c.ok(tag, cmd)

View File

@ -309,6 +309,12 @@ func (tc *testconn) waitDone() {
} }
func (tc *testconn) close() { func (tc *testconn) close() {
defer func() {
if unhandledPanics.Swap(0) > 0 {
tc.t.Fatalf("handled panic in server")
}
}()
if tc.account == nil { if tc.account == nil {
// Already closed, we are not strict about closing multiple times. // Already closed, we are not strict about closing multiple times.
return return
@ -317,7 +323,9 @@ func (tc *testconn) close() {
tc.check(err, "close account") tc.check(err, "close account")
// no account.CheckClosed(), the tests open accounts multiple times. // no account.CheckClosed(), the tests open accounts multiple times.
tc.account = nil tc.account = nil
tc.client.Close() if tc.client != nil {
tc.client.Close()
}
tc.serverConn.Close() tc.serverConn.Close()
tc.waitDone() tc.waitDone()
if tc.switchStop != nil { if tc.switchStop != nil {
@ -381,9 +389,9 @@ func (c namedConn) RemoteAddr() net.Addr {
func startArgsMore(t *testing.T, first, immediateTLS bool, serverConfig, clientConfig *tls.Config, allowLoginWithoutTLS, noCloseSwitchboard, setPassword bool, accname string, afterInit func() error) *testconn { func startArgsMore(t *testing.T, first, immediateTLS bool, serverConfig, clientConfig *tls.Config, allowLoginWithoutTLS, noCloseSwitchboard, setPassword bool, accname string, afterInit func() error) *testconn {
limitersInit() // Reset rate limiters. limitersInit() // Reset rate limiters.
mox.ConfigStaticPath = filepath.FromSlash("../testdata/imap/mox.conf")
mox.MustLoadConfig(true, false)
if first { if first {
mox.ConfigStaticPath = filepath.FromSlash("../testdata/imap/mox.conf")
mox.MustLoadConfig(true, false)
store.Close() // May not be open, we ignore error. store.Close() // May not be open, we ignore error.
os.RemoveAll("../testdata/imap/data") os.RemoveAll("../testdata/imap/data")
err := store.Init(ctxbg) err := store.Init(ctxbg)
@ -418,7 +426,15 @@ func startArgsMore(t *testing.T, first, immediateTLS bool, serverConfig, clientC
tcheck(t, err, "fileconn") tcheck(t, err, "fileconn")
err = f.Close() err = f.Close()
tcheck(t, err, "close file for conn") tcheck(t, err, "close file for conn")
return namedConn{fc}
// Small read/write buffers, for detecting closed/broken connections quickly.
uc := fc.(*net.UnixConn)
err = uc.SetReadBuffer(512)
tcheck(t, err, "set read buffer")
uc.SetWriteBuffer(512)
tcheck(t, err, "set write buffer")
return namedConn{uc}
} }
serverConn := xfdconn(fds[0], "server") serverConn := xfdconn(fds[0], "server")
clientConn := xfdconn(fds[1], "client") clientConn := xfdconn(fds[1], "client")

48
moxio/flatewriter.go Normal file
View File

@ -0,0 +1,48 @@
package moxio
import (
"github.com/mjl-/flate"
)
// FlateWriter wraps a flate.Writer and ensures no Write/Flush/Close calls are made
// again on the underlying flate writer when a panic came out of the flate writer
// (e.g. raised by the destination writer of the flate writer). After a panic
// "through" a flate.Writer, its state is inconsistent and further calls could
// panic with out of bounds slice accesses.
type FlateWriter struct {
// Underlying compressor that Write, Flush and Close delegate to.
w *flate.Writer
// Value of the first panic that came out of an operation on w, nil if none so
// far. Once set, every further operation panics with this value again without
// calling into w.
panic any
}
// NewFlateWriter returns a FlateWriter that delegates to w. No panic is
// recorded yet, so operations are passed on to w until one of them panics.
func NewFlateWriter(w *flate.Writer) *FlateWriter {
	return &FlateWriter{w: w}
}
// checkBroken guards a single operation on the underlying flate writer. It is
// meant to be used as "defer w.checkBroken()()".
//
// The call itself runs immediately: if an earlier operation panicked, it
// panics again with that same value, before the underlying writer is touched.
// Otherwise it returns a function to be deferred around the operation. That
// function records the value of any panic raised during the operation and
// then lets the panic continue with the same value.
func (w *FlateWriter) checkBroken() func() {
	if prev := w.panic; prev != nil {
		panic(prev)
	}
	return func() {
		// recover must be called directly by the deferred function, which is this
		// closure.
		if x := recover(); x != nil {
			w.panic = x
			panic(x)
		}
	}
}
// Write passes data to the Write method of the wrapped flate writer and
// returns its result. If an earlier operation on w panicked, Write panics with
// that same value without calling the wrapped writer. A panic raised during
// this call is recorded and propagated.
func (w *FlateWriter) Write(data []byte) (int, error) {
	guard := w.checkBroken()
	defer guard()
	return w.w.Write(data)
}
// Flush calls Flush on the wrapped flate writer and returns its result. If an
// earlier operation on w panicked, Flush panics with that same value without
// calling the wrapped writer. A panic raised during this call is recorded and
// propagated.
func (w *FlateWriter) Flush() error {
	guard := w.checkBroken()
	defer guard()
	return w.w.Flush()
}
// Close calls Close on the wrapped flate writer and returns its result. If an
// earlier operation on w panicked, Close panics with that same value without
// calling the wrapped writer. A panic raised during this call is recorded and
// propagated.
func (w *FlateWriter) Close() error {
	guard := w.checkBroken()
	defer guard()
	return w.w.Close()
}