add reverse proxying websocket connections

if we recognize that a request for a WebForward is trying to turn the
connection into a websocket, we forward it to the backend and check if the
backend understands the websocket request. if so, we pass back the upgrade
response and get out of the way, copying bytes between the two. we do log the
total amount of bytes read from the client and written to the client. if the
backend doesn't respond with a websocke response, or an invalid one, we respond
with a regular non-websocket response. and we log details about the failed
connection, should help with debugging and any bug reports.

we don't try to parse the websocket framing, that's between the client and the
backend.  we could try to parse it, in part to protect the backend from bad
frames, but it would be a lot of work and could be brittle in the face of
extensions.

this doesn't yet handle websocket connections when a http proxy is configured.
we'll implement it when someone needs it. we do recognize it and fail the
connection.

for issue #25
This commit is contained in:
Mechiel Lukkien
2023-05-30 22:11:31 +02:00
parent aca64828bd
commit 259928ab62
15 changed files with 1966 additions and 49 deletions

View File

@ -42,7 +42,7 @@ var (
},
[]string{
"handler", // Name from webhandler, can be empty.
"proto", // "http" or "https"
"proto", // "http", "https", "ws", "wss"
"method", // "(unknown)" and otherwise only common verbs
"code",
},
@ -58,7 +58,7 @@ var (
},
[]string{
"handler", // Name from webhandler, can be empty.
"proto", // "http" or "https"
"proto", // "http", "https", "ws", "wss"
"method", // "(unknown)" and otherwise only common verbs
"code",
},
@ -69,22 +69,37 @@ var (
// http.ResponseWriter that writes access log and tracks metrics at end of response.
type loggingWriter struct {
W http.ResponseWriter // Calls are forwarded.
Start time.Time
R *http.Request
W http.ResponseWriter // Calls are forwarded.
Start time.Time
R *http.Request
WebsocketRequest bool // Whether request from was websocket.
Handler string // Set by router.
// Set by handlers.
StatusCode int
Size int64
WriteErr error
StatusCode int
Size int64 // Of data served, for non-websocket responses.
Err error
WebsocketResponse bool // If this was a successful websocket connection with backend.
SizeFromClient, SizeToClient int64 // Websocket data.
}
func (w *loggingWriter) Header() http.Header {
return w.W.Header()
}
// protocol, for logging.
func (w *loggingWriter) proto(websocket bool) string {
proto := "http"
if websocket {
proto = "ws"
}
if w.R.TLS != nil {
proto += "s"
}
return proto
}
func (w *loggingWriter) setStatusCode(statusCode int) {
if w.StatusCode != 0 {
return
@ -92,11 +107,7 @@ func (w *loggingWriter) setStatusCode(statusCode int) {
w.StatusCode = statusCode
method := metricHTTPMethod(w.R.Method)
proto := "http"
if w.R.TLS != nil {
proto = "https"
}
metricRequest.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
metricRequest.WithLabelValues(w.Handler, w.proto(w.WebsocketRequest), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
}
func (w *loggingWriter) Write(buf []byte) (int, error) {
@ -108,8 +119,8 @@ func (w *loggingWriter) Write(buf []byte) (int, error) {
if n > 0 {
w.Size += int64(n)
}
if err != nil && w.WriteErr == nil {
w.WriteErr = err
if err != nil {
w.error(err)
}
return n, err
}
@ -136,13 +147,15 @@ func metricHTTPMethod(method string) string {
return "(other)"
}
func (w *loggingWriter) error(err error) {
if w.Err == nil {
w.Err = err
}
}
func (w *loggingWriter) Done() {
method := metricHTTPMethod(w.R.Method)
proto := "http"
if w.R.TLS != nil {
proto = "https"
}
metricResponse.WithLabelValues(w.Handler, proto, method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
metricResponse.WithLabelValues(w.Handler, w.proto(w.WebsocketResponse), method, fmt.Sprintf("%d", w.StatusCode)).Observe(float64(time.Since(w.Start)) / float64(time.Second))
tlsinfo := "plain"
if w.R.TLS != nil {
@ -152,25 +165,41 @@ func (w *loggingWriter) Done() {
tlsinfo = "(other)"
}
}
err := w.WriteErr
err := w.Err
if err == nil {
err = w.R.Context().Err()
}
xlog.WithContext(w.R.Context()).Debugx("http request", err,
fields := []mlog.Pair{
mlog.Field("httpaccess", ""),
mlog.Field("handler", w.Handler),
mlog.Field("method", method),
mlog.Field("url", w.R.URL),
mlog.Field("host", w.R.Host),
mlog.Field("duration", time.Since(w.Start)),
mlog.Field("size", w.Size),
mlog.Field("statuscode", w.StatusCode),
mlog.Field("proto", strings.ToLower(w.R.Proto)),
mlog.Field("remoteaddr", w.R.RemoteAddr),
mlog.Field("tlsinfo", tlsinfo),
mlog.Field("useragent", w.R.Header.Get("User-Agent")),
mlog.Field("referrr", w.R.Header.Get("Referrer")),
)
}
if w.WebsocketRequest {
fields = append(fields,
mlog.Field("websocketrequest", true),
)
}
if w.WebsocketResponse {
fields = append(fields,
mlog.Field("websocket", true),
mlog.Field("sizetoclient", w.SizeToClient),
mlog.Field("sizefromclient", w.SizeFromClient),
)
} else {
fields = append(fields,
mlog.Field("size", w.Size),
)
}
xlog.WithContext(w.R.Context()).Debugx("http request", err, fields...)
}
// Set some http headers that should prevent potential abuse. Better safe than sorry.