Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modules/caddyhttp/intercept/intercept.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func (ir Intercept) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddy

repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
rec := interceptedResponseHandler{replacer: repl}
rec.ResponseRecorder = caddyhttp.NewResponseRecorder(w, buf, func(status int, header http.Header) bool {
rec.ResponseRecorder = caddyhttp.NewResponseRecorder(w, r, buf, func(status int, header http.Header) bool {
// see if any response handler is configured for this original response
for i, rh := range ir.HandleResponse {
if rh.Match != nil && !rh.Match.Match(status, header) {
Expand Down
12 changes: 4 additions & 8 deletions modules/caddyhttp/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
h.metrics.httpMetrics.responseDuration.With(statusLabels).Observe(ttfb)
return false
})
wrec := NewResponseRecorder(w, nil, writeHeaderRecorder)
wrec := NewResponseRecorder(w, r, nil, writeHeaderRecorder)
err := h.mh.ServeHTTP(wrec, r, next)
dur := time.Since(start).Seconds()
h.metrics.httpMetrics.requestCount.With(labels).Inc()
Expand All @@ -168,7 +168,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
}

h.metrics.httpMetrics.requestDuration.With(statusLabels).Observe(dur)
h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(r)))
h.metrics.httpMetrics.requestSize.With(statusLabels).Observe(float64(computeApproximateRequestSize(wrec, r)))
h.metrics.httpMetrics.responseSize.With(statusLabels).Observe(float64(wrec.Size()))
}

Expand All @@ -189,7 +189,7 @@ func (h *metricsInstrumentedHandler) ServeHTTP(w http.ResponseWriter, r *http.Re
}

// taken from https://github.com/prometheus/client_golang/blob/6007b2b5cae01203111de55f753e76d8dac1f529/prometheus/promhttp/instrument_server.go#L298
func computeApproximateRequestSize(r *http.Request) int {
func computeApproximateRequestSize(wrec ResponseRecorder, r *http.Request) int {
s := 0
if r.URL != nil {
s += len(r.URL.String())
Expand All @@ -205,10 +205,6 @@ func computeApproximateRequestSize(r *http.Request) int {
}
s += len(r.Host)

// N.B. r.Form and r.MultipartForm are assumed to be included in r.URL.

if r.ContentLength != -1 {
s += int(r.ContentLength)
}
s += wrec.RequestSize()
return s
}
52 changes: 39 additions & 13 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ type responseRecorder struct {
wroteHeader bool
stream bool

readSize *int
reqBodyLengthReader lengthReader
}

// NewResponseRecorder returns a new ResponseRecorder that can be
Expand Down Expand Up @@ -101,7 +101,7 @@ type responseRecorder struct {
//
// Proper usage of a recorder looks like this:
//
// rec := caddyhttp.NewResponseRecorder(w, buf, shouldBuffer)
// rec := caddyhttp.NewResponseRecorder(w, req, buf, shouldBuffer)
// err := next.ServeHTTP(rec, req)
// if err != nil {
// return err
Expand Down Expand Up @@ -134,12 +134,19 @@ type responseRecorder struct {
// As a special case, 1xx responses are not buffered nor recorded
// because they are not the final response; they are passed through
// directly to the underlying ResponseWriter.
func NewResponseRecorder(w http.ResponseWriter, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder {
return &responseRecorder{
func NewResponseRecorder(w http.ResponseWriter, r *http.Request, buf *bytes.Buffer, shouldBuffer ShouldBufferFunc) ResponseRecorder {
rr := &responseRecorder{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: w},
buf: buf,
shouldBuffer: shouldBuffer,
reqBodyLengthReader: lengthReader{},
}
if r.Body != nil {
rr.reqBodyLengthReader.source = r.Body
r.Body = &rr.reqBodyLengthReader
}

return rr
}

// WriteHeader writes the headers with statusCode to the wrapped
Expand Down Expand Up @@ -211,6 +218,12 @@ func (rr *responseRecorder) Size() int {
return rr.size
}

// RequestSize returns the number of bytes read from the Request,
// not including the request headers.
func (rr *responseRecorder) RequestSize() int {
return rr.reqBodyLengthReader.length
}

// Buffer returns the body buffer that rr was created with.
// You should still have your original pointer, though.
func (rr *responseRecorder) Buffer() *bytes.Buffer {
Expand Down Expand Up @@ -246,12 +259,6 @@ func (rr *responseRecorder) FlushError() error {
return nil
}

// Private interface so it can only be used in this package
// #TODO: maybe export it later
func (rr *responseRecorder) setReadSize(size *int) {
rr.readSize = size
}

func (rr *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
//nolint:bodyclose
conn, brw, err := http.NewResponseController(rr.ResponseWriterWrapper).Hijack()
Expand Down Expand Up @@ -282,9 +289,7 @@ type hijackedConn struct {
}

func (hc *hijackedConn) updateReadSize(n int) {
if hc.rr.readSize != nil {
*hc.rr.readSize += n
}
hc.rr.reqBodyLengthReader.length += n
}

func (hc *hijackedConn) Read(p []byte) (int, error) {
Expand Down Expand Up @@ -320,6 +325,7 @@ type ResponseRecorder interface {
Buffer() *bytes.Buffer
Buffered() bool
Size() int
RequestSize() int
WriteResponse() error
}

Expand All @@ -342,3 +348,23 @@ var (

_ io.WriterTo = (*hijackedConn)(nil)
)

// lengthReader is an io.ReadCloser that keeps track of the
// number of bytes read from the request body.
// This wrapper is for http request process only. If the underlying
// conn hijacked by a websocket session. ResponseRecorder will
// update the Length field.
type lengthReader struct {
source io.ReadCloser
length int
}

func (r *lengthReader) Read(b []byte) (int, error) {
n, err := r.source.Read(b)
r.length += n
return n, err
}

func (r *lengthReader) Close() error {
return r.source.Close()
}
7 changes: 6 additions & 1 deletion modules/caddyhttp/responsewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,26 +104,31 @@ func TestResponseWriterWrapperUnwrap(t *testing.T) {
func TestResponseRecorderReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
req *http.Request
shouldBuffer bool
wantReadFrom bool
}{
"buffered plain": {
responseWriter: &baseRespWriter{},
req: &http.Request{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed plain": {
responseWriter: &baseRespWriter{},
req: &http.Request{},
shouldBuffer: false,
wantReadFrom: false,
},
"buffered ReadFrom": {
responseWriter: &readFromRespWriter{},
req: &http.Request{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed ReadFrom": {
responseWriter: &readFromRespWriter{},
req: &http.Request{},
shouldBuffer: false,
wantReadFrom: true,
},
Expand All @@ -132,7 +137,7 @@ func TestResponseRecorderReadFrom(t *testing.T) {
t.Run(name, func(t *testing.T) {
var buf bytes.Buffer

rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
rr := NewResponseRecorder(tt.responseWriter, tt.req, &buf, func(status int, header http.Header) bool {
return tt.shouldBuffer
})

Expand Down
50 changes: 8 additions & 42 deletions modules/caddyhttp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/netip"
Expand Down Expand Up @@ -335,26 +334,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var duration time.Duration

if s.shouldLogRequest(r) {
wrec := NewResponseRecorder(w, nil, nil)
wrec := NewResponseRecorder(w, r, nil, nil)
w = wrec

// wrap the request body in a LengthReader
// so we can track the number of bytes read from it
var bodyReader *lengthReader
if r.Body != nil {
bodyReader = &lengthReader{Source: r.Body}
r.Body = bodyReader

// should always be true, private interface can only be referenced in the same package
if setReadSizer, ok := wrec.(interface{ setReadSize(*int) }); ok {
setReadSizer.setReadSize(&bodyReader.Length)
}
}

// capture the original version of the request
accLog := s.accessLogger.With(loggableReq)

defer s.logRequest(accLog, r, wrec, &duration, repl, bodyReader, shouldLogCredentials)
defer s.logRequest(accLog, r, wrec, &duration, repl, shouldLogCredentials)
}

start := time.Now()
Expand Down Expand Up @@ -771,18 +757,20 @@ func (s *Server) logTrace(mh MiddlewareHandler) {
// logRequest logs the request to access logs, unless skipped.
func (s *Server) logRequest(
accLog *zap.Logger, r *http.Request, wrec ResponseRecorder, duration *time.Duration,
repl *caddy.Replacer, bodyReader *lengthReader, shouldLogCredentials bool,
repl *caddy.Replacer, shouldLogCredentials bool,
) {
// this request may be flagged as omitted from the logs
if skip, ok := GetVar(r.Context(), LogSkipVar).(bool); ok && skip {
return
}

status := wrec.Status()
size := wrec.Size()
respSize := wrec.Size()
reqSize := wrec.RequestSize()

repl.Set("http.response.status", status) // will be 0 if no response is written by us (Go will write 200 to client)
repl.Set("http.response.size", size)
repl.Set("http.response.size", respSize)
repl.Set("http.request.size", reqSize)
repl.Set("http.response.duration", duration)
repl.Set("http.response.duration_ms", duration.Seconds()*1e3) // multiply seconds to preserve decimal (see #4666)

Expand Down Expand Up @@ -811,17 +799,12 @@ func (s *Server) logRequest(
if fields == nil {
userID, _ := repl.GetString("http.auth.user.id")

reqBodyLength := 0
if bodyReader != nil {
reqBodyLength = bodyReader.Length
}

extra := r.Context().Value(ExtraLogFieldsCtxKey).(*ExtraLogFields)

fieldCount := 6
fields = make([]zapcore.Field, 0, fieldCount+len(extra.fields))
fields = append(fields,
zap.Int("bytes_read", reqBodyLength),
zap.Int("bytes_read", wrec.RequestSize()),
zap.String("user_id", userID),
zap.Duration("duration", *duration),
zap.Int("size", size),
Expand Down Expand Up @@ -1050,23 +1033,6 @@ func cloneURL(from, to *url.URL) {
}
}

// lengthReader is an io.ReadCloser that keeps track of the
// number of bytes read from the request body.
type lengthReader struct {
Source io.ReadCloser
Length int
}

func (r *lengthReader) Read(b []byte) (int, error) {
n, err := r.Source.Read(b)
r.Length += n
return n, err
}

func (r *lengthReader) Close() error {
return r.Source.Close()
}

// Context keys for HTTP request context values.
const (
// For referencing the server instance
Expand Down
Loading