Skip to content

Commit ad8f8a3

Browse files
committed
Test with gzip compression
1 parent 764c760 commit ad8f8a3

File tree

2 files changed

+107
-12
lines changed

2 files changed

+107
-12
lines changed

prometheus/promhttp/http.go

+20-9
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,18 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
190190
}
191191
rsp.Header().Set(contentTypeHeader, string(contentType))
192192

193-
w, encodingHeader, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
193+
w, encodingHeader, closeWriter, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
194+
195+
if closeWriter != nil {
196+
defer func() {
197+
err := closeWriter()
198+
if err != nil {
199+
if opts.ErrorLog != nil {
200+
opts.ErrorLog.Println("error closing writer:", err)
201+
}
202+
}
203+
}()
204+
}
194205
if err != nil {
195206
if opts.ErrorLog != nil {
196207
opts.ErrorLog.Println("error getting writer", err)
@@ -432,11 +443,11 @@ func httpError(rsp http.ResponseWriter, err error) {
432443
// selects the right compression based on an allow-list of supported
433444
// compressions. It returns a writer implementing the compression and an the
434445
// correct value that the caller can set in the response header.
435-
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, _ error) {
446+
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, closeWriter func() error, _ error) {
436447
w := rw
437448

438449
if disableCompression {
439-
return w, string(Identity), nil
450+
return w, string(Identity), nil, nil
440451
}
441452

442453
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
@@ -447,26 +458,26 @@ func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression b
447458
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
448459
z, err := zstd.NewWriter(rw, zstd.WithEncoderLevel(zstd.SpeedFastest))
449460
if err != nil {
450-
return nil, "", err
461+
return nil, "", nil, err
451462
}
452463

453464
z.Reset(w)
454-
defer z.Close()
455-
456465
w = z
466+
467+
return w, compression, z.Close, nil
457468
case "gzip":
458469
gz := gzipPool.Get().(*gzip.Writer)
459470
defer gzipPool.Put(gz)
460471

461472
gz.Reset(w)
462-
defer gz.Close()
463473

464474
w = gz
475+
return w, compression, gz.Close, nil
465476
case "identity":
466477
// This means the content is not compressed.
478+
return w, compression, nil, nil
467479
default:
468480
// The content encoding was not implemented yet.
469-
return nil, "", fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
481+
return nil, "", nil, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
470482
}
471-
return w, compression, nil
472483
}

prometheus/promhttp/http_test.go

+87-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ package promhttp
1515

1616
import (
1717
"bytes"
18+
"compress/gzip"
1819
"errors"
1920
"fmt"
21+
"io"
2022
"log"
2123
"net/http"
2224
"net/http/httptest"
@@ -31,6 +33,11 @@ import (
3133

3234
type errorCollector struct{}
3335

36+
const (
37+
acceptHeader = "Accept"
38+
acceptTextPlain = "text/plain"
39+
)
40+
3441
func (e errorCollector) Describe(ch chan<- *prometheus.Desc) {
3542
ch <- prometheus.NewDesc("invalid_metric", "not helpful", nil, nil)
3643
}
@@ -71,6 +78,16 @@ func (g *mockTransactionGatherer) Gather() (_ []*dto.MetricFamily, done func(),
7178
return mfs, func() { g.doneInvoked++ }, err
7279
}
7380

81+
func readGzippedBody(r io.Reader) (string, error) {
82+
reader, err := gzip.NewReader(r)
83+
if err != nil {
84+
return "", err
85+
}
86+
defer reader.Close()
87+
got, err := io.ReadAll(reader)
88+
return string(got), err
89+
}
90+
7491
func TestHandlerErrorHandling(t *testing.T) {
7592
// Create a registry that collects a MetricFamily with two elements,
7693
// another with one, and reports an error. Further down, we'll use the
@@ -223,7 +240,7 @@ func TestInstrumentMetricHandler(t *testing.T) {
223240
InstrumentMetricHandler(reg, HandlerForTransactional(mReg, HandlerOpts{}))
224241
writer := httptest.NewRecorder()
225242
request, _ := http.NewRequest("GET", "/", nil)
226-
request.Header.Add("Accept", "test/plain")
243+
request.Header.Add(acceptHeader, acceptTextPlain)
227244

228245
handler.ServeHTTP(writer, request)
229246
if got := mReg.gatherInvoked; got != 1 {
@@ -237,6 +254,10 @@ func TestInstrumentMetricHandler(t *testing.T) {
237254
t.Errorf("got HTTP status code %d, want %d", got, want)
238255
}
239256

257+
if got, want := writer.Header().Get(contentEncodingHeader), string(Identity); got != want {
258+
t.Errorf("got HTTP content encoding header %s, want %s", got, want)
259+
}
260+
240261
want := "promhttp_metric_handler_requests_in_flight 1\n"
241262
if got := writer.Body.String(); !strings.Contains(got, want) {
242263
t.Errorf("got body %q, does not contain %q", got, want)
@@ -278,7 +299,7 @@ func TestHandlerMaxRequestsInFlight(t *testing.T) {
278299
w2 := httptest.NewRecorder()
279300
w3 := httptest.NewRecorder()
280301
request, _ := http.NewRequest("GET", "/", nil)
281-
request.Header.Add("Accept", "test/plain")
302+
request.Header.Add(acceptHeader, acceptTextPlain)
282303

283304
c := blockingCollector{Block: make(chan struct{}), CollectStarted: make(chan struct{}, 1)}
284305
reg.MustRegister(c)
@@ -332,6 +353,69 @@ func TestHandlerTimeout(t *testing.T) {
332353
close(c.Block) // To not leak a goroutine.
333354
}
334355

356+
func TestInstrumentMetricHandlerWithCompression(t *testing.T) {
357+
reg := prometheus.NewRegistry()
358+
mReg := &mockTransactionGatherer{g: reg}
359+
handler := InstrumentMetricHandler(reg, HandlerForTransactional(mReg, HandlerOpts{DisableCompression: false}))
360+
writer := httptest.NewRecorder()
361+
request, _ := http.NewRequest("GET", "/", nil)
362+
request.Header.Add(acceptHeader, acceptTextPlain)
363+
request.Header.Add(acceptEncodingHeader, string(Gzip))
364+
365+
handler.ServeHTTP(writer, request)
366+
if got := mReg.gatherInvoked; got != 1 {
367+
t.Fatalf("unexpected number of gather invokes, want 1, got %d", got)
368+
}
369+
if got := mReg.doneInvoked; got != 1 {
370+
t.Fatalf("unexpected number of done invokes, want 1, got %d", got)
371+
}
372+
373+
if got, want := writer.Code, http.StatusOK; got != want {
374+
t.Errorf("got HTTP status code %d, want %d", got, want)
375+
}
376+
377+
if got, want := writer.Header().Get(contentEncodingHeader), string(Gzip); got != want {
378+
t.Errorf("got HTTP content encoding header %s, want %s", got, want)
379+
}
380+
381+
body, err := readGzippedBody(writer.Body)
382+
want := "promhttp_metric_handler_requests_in_flight 1\n"
383+
if got := body; !strings.Contains(got, want) {
384+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
385+
}
386+
387+
want = "promhttp_metric_handler_requests_total{code=\"200\"} 0\n"
388+
if got := body; !strings.Contains(got, want) {
389+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
390+
}
391+
392+
for i := 0; i < 100; i++ {
393+
writer.Body.Reset()
394+
handler.ServeHTTP(writer, request)
395+
396+
if got, want := mReg.gatherInvoked, i+2; got != want {
397+
t.Fatalf("unexpected number of gather invokes, want %d, got %d", want, got)
398+
}
399+
if got, want := mReg.doneInvoked, i+2; got != want {
400+
t.Fatalf("unexpected number of done invokes, want %d, got %d", want, got)
401+
}
402+
if got, want := writer.Code, http.StatusOK; got != want {
403+
t.Errorf("got HTTP status code %d, want %d", got, want)
404+
}
405+
body, err := readGzippedBody(writer.Body)
406+
407+
want := "promhttp_metric_handler_requests_in_flight 1\n"
408+
if got := body; !strings.Contains(got, want) {
409+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
410+
}
411+
412+
want = fmt.Sprintf("promhttp_metric_handler_requests_total{code=\"200\"} %d\n", i+1)
413+
if got := body; !strings.Contains(got, want) {
414+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
415+
}
416+
}
417+
}
418+
335419
func TestNegotiateEncodingWriter(t *testing.T) {
336420
var defaultCompressions []string
337421

@@ -393,7 +477,7 @@ func TestNegotiateEncodingWriter(t *testing.T) {
393477
request, _ := http.NewRequest("GET", "/", nil)
394478
request.Header.Add(acceptEncodingHeader, test.acceptEncoding)
395479
rr := httptest.NewRecorder()
396-
_, encodingHeader, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)
480+
_, encodingHeader, _, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)
397481

398482
if !errors.Is(err, test.err) {
399483
t.Errorf("got error: %v, expected: %v", err, test.err)

0 commit comments

Comments
 (0)