Skip to content

Commit 7cfc1f3

Browse files
committed
Test with gzip compression
1 parent 764c760 commit 7cfc1f3

File tree

2 files changed

+100
-12
lines changed

2 files changed

+100
-12
lines changed

prometheus/promhttp/http.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,11 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
190190
}
191191
rsp.Header().Set(contentTypeHeader, string(contentType))
192192

193-
w, encodingHeader, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
193+
w, encodingHeader, closeWriter, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
194+
195+
if closeWriter != nil {
196+
defer closeWriter()
197+
}
194198
if err != nil {
195199
if opts.ErrorLog != nil {
196200
opts.ErrorLog.Println("error getting writer", err)
@@ -432,11 +436,11 @@ func httpError(rsp http.ResponseWriter, err error) {
432436
// selects the right compression based on an allow-list of supported
433437
// compressions. It returns a writer implementing the compression and an the
434438
// correct value that the caller can set in the response header.
435-
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, _ error) {
439+
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, closeWriter func() error, _ error) {
436440
w := rw
437441

438442
if disableCompression {
439-
return w, string(Identity), nil
443+
return w, string(Identity), nil, nil
440444
}
441445

442446
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
@@ -447,26 +451,26 @@ func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression b
447451
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
448452
z, err := zstd.NewWriter(rw, zstd.WithEncoderLevel(zstd.SpeedFastest))
449453
if err != nil {
450-
return nil, "", err
454+
return nil, "", nil, err
451455
}
452456

453457
z.Reset(w)
454-
defer z.Close()
455-
456458
w = z
459+
460+
return w, compression, z.Close, nil
457461
case "gzip":
458462
gz := gzipPool.Get().(*gzip.Writer)
459463
defer gzipPool.Put(gz)
460464

461465
gz.Reset(w)
462-
defer gz.Close()
463466

464467
w = gz
468+
return w, compression, gz.Close, nil
465469
case "identity":
466470
// This means the content is not compressed.
471+
return w, compression, nil, nil
467472
default:
468473
// The content encoding was not implemented yet.
469-
return nil, "", fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
474+
return nil, "", nil, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
470475
}
471-
return w, compression, nil
472476
}

prometheus/promhttp/http_test.go

+87-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ package promhttp
1515

1616
import (
1717
"bytes"
18+
"compress/gzip"
1819
"errors"
1920
"fmt"
21+
"io"
2022
"log"
2123
"net/http"
2224
"net/http/httptest"
@@ -31,6 +33,11 @@ import (
3133

3234
type errorCollector struct{}
3335

36+
const (
37+
acceptHeader = "Accept"
38+
acceptTextPlain = "text/plain"
39+
)
40+
3441
func (e errorCollector) Describe(ch chan<- *prometheus.Desc) {
3542
ch <- prometheus.NewDesc("invalid_metric", "not helpful", nil, nil)
3643
}
@@ -71,6 +78,16 @@ func (g *mockTransactionGatherer) Gather() (_ []*dto.MetricFamily, done func(),
7178
return mfs, func() { g.doneInvoked++ }, err
7279
}
7380

81+
func readGzippedBody(r io.Reader) (string, error) {
82+
reader, err := gzip.NewReader(r)
83+
if err != nil {
84+
return "", err
85+
}
86+
defer reader.Close()
87+
got, err := io.ReadAll(reader)
88+
return string(got), err
89+
}
90+
7491
func TestHandlerErrorHandling(t *testing.T) {
7592
// Create a registry that collects a MetricFamily with two elements,
7693
// another with one, and reports an error. Further down, we'll use the
@@ -223,7 +240,7 @@ func TestInstrumentMetricHandler(t *testing.T) {
223240
InstrumentMetricHandler(reg, HandlerForTransactional(mReg, HandlerOpts{}))
224241
writer := httptest.NewRecorder()
225242
request, _ := http.NewRequest("GET", "/", nil)
226-
request.Header.Add("Accept", "test/plain")
243+
request.Header.Add(acceptHeader, acceptTextPlain)
227244

228245
handler.ServeHTTP(writer, request)
229246
if got := mReg.gatherInvoked; got != 1 {
@@ -237,6 +254,10 @@ func TestInstrumentMetricHandler(t *testing.T) {
237254
t.Errorf("got HTTP status code %d, want %d", got, want)
238255
}
239256

257+
if got, want := writer.Header().Get(contentEncodingHeader), string(Identity); got != want {
258+
t.Errorf("got HTTP content encoding header %s, want %s", got, want)
259+
}
260+
240261
want := "promhttp_metric_handler_requests_in_flight 1\n"
241262
if got := writer.Body.String(); !strings.Contains(got, want) {
242263
t.Errorf("got body %q, does not contain %q", got, want)
@@ -278,7 +299,7 @@ func TestHandlerMaxRequestsInFlight(t *testing.T) {
278299
w2 := httptest.NewRecorder()
279300
w3 := httptest.NewRecorder()
280301
request, _ := http.NewRequest("GET", "/", nil)
281-
request.Header.Add("Accept", "test/plain")
302+
request.Header.Add(acceptHeader, acceptTextPlain)
282303

283304
c := blockingCollector{Block: make(chan struct{}), CollectStarted: make(chan struct{}, 1)}
284305
reg.MustRegister(c)
@@ -332,6 +353,69 @@ func TestHandlerTimeout(t *testing.T) {
332353
close(c.Block) // To not leak a goroutine.
333354
}
334355

356+
func TestInstrumentMetricHandlerWithCompression(t *testing.T) {
357+
reg := prometheus.NewRegistry()
358+
mReg := &mockTransactionGatherer{g: reg}
359+
handler := InstrumentMetricHandler(reg, HandlerForTransactional(mReg, HandlerOpts{DisableCompression: false}))
360+
writer := httptest.NewRecorder()
361+
request, _ := http.NewRequest("GET", "/", nil)
362+
request.Header.Add(acceptHeader, acceptTextPlain)
363+
request.Header.Add(acceptEncodingHeader, string(Gzip))
364+
365+
handler.ServeHTTP(writer, request)
366+
if got := mReg.gatherInvoked; got != 1 {
367+
t.Fatalf("unexpected number of gather invokes, want 1, got %d", got)
368+
}
369+
if got := mReg.doneInvoked; got != 1 {
370+
t.Fatalf("unexpected number of done invokes, want 1, got %d", got)
371+
}
372+
373+
if got, want := writer.Code, http.StatusOK; got != want {
374+
t.Errorf("got HTTP status code %d, want %d", got, want)
375+
}
376+
377+
if got, want := writer.Header().Get(contentEncodingHeader), string(Gzip); got != want {
378+
t.Errorf("got HTTP content encoding header %s, want %s", got, want)
379+
}
380+
381+
body, err := readGzippedBody(writer.Body)
382+
want := "promhttp_metric_handler_requests_in_flight 1\n"
383+
if got := body; !strings.Contains(got, want) {
384+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
385+
}
386+
387+
want = "promhttp_metric_handler_requests_total{code=\"200\"} 0\n"
388+
if got := body; !strings.Contains(got, want) {
389+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
390+
}
391+
392+
for i := 0; i < 100; i++ {
393+
writer.Body.Reset()
394+
handler.ServeHTTP(writer, request)
395+
396+
if got, want := mReg.gatherInvoked, i+2; got != want {
397+
t.Fatalf("unexpected number of gather invokes, want %d, got %d", want, got)
398+
}
399+
if got, want := mReg.doneInvoked, i+2; got != want {
400+
t.Fatalf("unexpected number of done invokes, want %d, got %d", want, got)
401+
}
402+
if got, want := writer.Code, http.StatusOK; got != want {
403+
t.Errorf("got HTTP status code %d, want %d", got, want)
404+
}
405+
body, err := readGzippedBody(writer.Body)
406+
407+
want := "promhttp_metric_handler_requests_in_flight 1\n"
408+
if got := body; !strings.Contains(got, want) {
409+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
410+
}
411+
412+
want = fmt.Sprintf("promhttp_metric_handler_requests_total{code=\"200\"} %d\n", i+1)
413+
if got := body; !strings.Contains(got, want) {
414+
t.Errorf("got body %q, does not contain %q, err: %v", got, want, err)
415+
}
416+
}
417+
}
418+
335419
func TestNegotiateEncodingWriter(t *testing.T) {
336420
var defaultCompressions []string
337421

@@ -393,7 +477,7 @@ func TestNegotiateEncodingWriter(t *testing.T) {
393477
request, _ := http.NewRequest("GET", "/", nil)
394478
request.Header.Add(acceptEncodingHeader, test.acceptEncoding)
395479
rr := httptest.NewRecorder()
396-
_, encodingHeader, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)
480+
_, encodingHeader, _, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)
397481

398482
if !errors.Is(err, test.err) {
399483
t.Errorf("got error: %v, expected: %v", err, test.err)

0 commit comments

Comments
 (0)