Skip to content

Commit 81ed59a

Browse files
committed
Integrate review comments
* String typed enum Signed-off-by: Manuel Rüger <[email protected]>
1 parent adcec1e commit 81ed59a

File tree

2 files changed

+75
-63
lines changed

2 files changed

+75
-63
lines changed

prometheus/promhttp/http.go

+59-53
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,14 @@ const (
5555
processStartTimeHeader = "Process-Start-Time-Unix"
5656
)
5757

58-
type Compression int
58+
type Compression string
5959

6060
const (
61-
Identity Compression = iota
62-
Gzip
63-
Zstd
61+
Identity Compression = "identity"
62+
Gzip Compression = "gzip"
63+
Zstd Compression = "zstd"
6464
)
6565

66-
var compressions = [...]string{
67-
"identity",
68-
"gzip",
69-
"zstd",
70-
}
71-
72-
func (c Compression) String() string {
73-
return compressions[c]
74-
}
75-
7666
var defaultCompressionFormats = []Compression{Identity, Gzip, Zstd}
7767

7868
var gzipPool = sync.Pool{
@@ -143,6 +133,18 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
143133
}
144134
}
145135

136+
// Select all supported compression formats
137+
var compressions []string
138+
if !opts.DisableCompression {
139+
offers := defaultCompressionFormats
140+
if len(opts.OfferedCompressions) > 0 {
141+
offers = opts.OfferedCompressions
142+
}
143+
for _, comp := range offers {
144+
compressions = append(compressions, string(comp))
145+
}
146+
}
147+
146148
h := http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) {
147149
if !opts.ProcessStartTime.IsZero() {
148150
rsp.Header().Set(processStartTimeHeader, strconv.FormatInt(opts.ProcessStartTime.Unix(), 10))
@@ -188,12 +190,19 @@ func HandlerForTransactional(reg prometheus.TransactionalGatherer, opts HandlerO
188190
}
189191
rsp.Header().Set(contentTypeHeader, string(contentType))
190192

191-
w, err := GetWriter(req, rsp, opts.DisableCompression, opts.OfferedCompressions)
193+
w, encodingHeader, err := NegotiateEncodingWriter(req, rsp, opts.DisableCompression, compressions)
192194
if err != nil {
193195
if opts.ErrorLog != nil {
194196
opts.ErrorLog.Println("error getting writer", err)
195197
}
198+
// Since the writer received from NegotiateEncodingWriter will be nil, in case there's an error, we set it here
199+
w = io.Writer(rsp)
200+
}
201+
202+
if encodingHeader == "" {
203+
encodingHeader = Identity
196204
}
205+
rsp.Header().Set(contentEncodingHeader, encodingHeader)
197206

198207
enc := expfmt.NewEncoder(w, contentType)
199208

@@ -419,48 +428,45 @@ func httpError(rsp http.ResponseWriter, err error) {
419428
)
420429
}
421430

422-
func GetWriter(r *http.Request, rsp http.ResponseWriter, disableCompression bool, offeredCompressions []Compression) (io.Writer, error) {
423-
w := io.Writer(rsp)
424-
rsp.Header().Set(contentEncodingHeader, "identity")
425-
if !disableCompression {
426-
offers := defaultCompressionFormats
427-
if len(offeredCompressions) > 0 {
428-
offers = offeredCompressions
429-
}
430-
var compressions []string
431-
for _, comp := range offers {
432-
compressions = append(compressions, comp.String())
431+
// NegotiateEncodingWriter reads the Accept-Encoding header from a request and
432+
// selects the right compression based on an allow-list of supported
433+
// compressions. It returns a writer implementing the compression and an the
434+
// correct value that the caller can set in the response header.
435+
func NegotiateEncodingWriter(r *http.Request, rw io.Writer, disableCompression bool, compressions []string) (_ io.Writer, encodingHeaderValue string, _ error) {
436+
w := rw
437+
438+
if disableCompression {
439+
return w, string(Identity), nil
440+
}
441+
442+
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
443+
compression := httputil.NegotiateContentEncoding(r, compressions)
444+
445+
switch compression {
446+
case "zstd":
447+
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
448+
z, err := zstd.NewWriter(rw, zstd.WithEncoderLevel(zstd.SpeedFastest))
449+
if err != nil {
450+
return nil, "", err
433451
}
434-
// TODO(mrueg): Replace internal/github.com/gddo once https://github.com/golang/go/issues/19307 is implemented.
435-
compression := httputil.NegotiateContentEncoding(r, compressions)
436-
switch compression {
437-
case "zstd":
438-
rsp.Header().Set(contentEncodingHeader, "zstd")
439-
// TODO(mrueg): Replace klauspost/compress with stdlib implementation once https://github.com/golang/go/issues/62513 is implemented.
440-
z, err := zstd.NewWriter(rsp, zstd.WithEncoderLevel(zstd.SpeedFastest))
441-
if err != nil {
442-
return nil, err
443-
}
444452

445-
z.Reset(w)
446-
defer z.Close()
453+
z.Reset(w)
454+
defer z.Close()
447455

448-
w = z
449-
case "gzip":
450-
rsp.Header().Set(contentEncodingHeader, "gzip")
451-
gz := gzipPool.Get().(*gzip.Writer)
452-
defer gzipPool.Put(gz)
456+
w = z
457+
case "gzip":
458+
gz := gzipPool.Get().(*gzip.Writer)
459+
defer gzipPool.Put(gz)
453460

454-
gz.Reset(w)
455-
defer gz.Close()
461+
gz.Reset(w)
462+
defer gz.Close()
456463

457-
w = gz
458-
case "identity":
459-
// This means the content is not compressed.
460-
default:
461-
// The content encoding was not implemented yet.
462-
return w, fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
463-
}
464+
w = gz
465+
case "identity":
466+
// This means the content is not compressed.
467+
default:
468+
// The content encoding was not implemented yet.
469+
return nil, "", fmt.Errorf("content compression format not recognized: %s. Valid formats are: %s", compression, defaultCompressionFormats)
464470
}
465-
return w, nil
471+
return w, compression, nil
466472
}

prometheus/promhttp/http_test.go

+16-10
Original file line numberDiff line numberDiff line change
@@ -332,51 +332,57 @@ func TestHandlerTimeout(t *testing.T) {
332332
close(c.Block) // To not leak a goroutine.
333333
}
334334

335-
func TestGetWriter(t *testing.T) {
335+
func TestNegotiateEncodingWriter(t *testing.T) {
336+
var defaultCompressions []string
337+
338+
for _, comp := range defaultCompressionFormats {
339+
defaultCompressions = append(defaultCompressions, string(comp))
340+
}
341+
336342
testCases := []struct {
337343
name string
338344
disableCompression bool
339-
offeredCompressions []Compression
345+
offeredCompressions []string
340346
acceptEncoding string
341347
expectedCompression string
342348
err error
343349
}{
344350
{
345351
name: "test without compression enabled",
346352
disableCompression: true,
347-
offeredCompressions: defaultCompressionFormats,
353+
offeredCompressions: defaultCompressions,
348354
acceptEncoding: "",
349355
expectedCompression: "identity",
350356
err: nil,
351357
},
352358
{
353359
name: "test with compression enabled with empty accept-encoding header",
354360
disableCompression: false,
355-
offeredCompressions: defaultCompressionFormats,
361+
offeredCompressions: defaultCompressions,
356362
acceptEncoding: "",
357363
expectedCompression: "identity",
358364
err: nil,
359365
},
360366
{
361367
name: "test with gzip compression requested",
362368
disableCompression: false,
363-
offeredCompressions: defaultCompressionFormats,
369+
offeredCompressions: defaultCompressions,
364370
acceptEncoding: "gzip",
365371
expectedCompression: "gzip",
366372
err: nil,
367373
},
368374
{
369375
name: "test with gzip, zstd compression requested",
370376
disableCompression: false,
371-
offeredCompressions: defaultCompressionFormats,
377+
offeredCompressions: defaultCompressions,
372378
acceptEncoding: "gzip,zstd",
373379
expectedCompression: "gzip",
374380
err: nil,
375381
},
376382
{
377383
name: "test with zstd, gzip compression requested",
378384
disableCompression: false,
379-
offeredCompressions: defaultCompressionFormats,
385+
offeredCompressions: defaultCompressions,
380386
acceptEncoding: "zstd,gzip",
381387
expectedCompression: "gzip",
382388
err: nil,
@@ -387,14 +393,14 @@ func TestGetWriter(t *testing.T) {
387393
request, _ := http.NewRequest("GET", "/", nil)
388394
request.Header.Add(acceptEncodingHeader, test.acceptEncoding)
389395
rr := httptest.NewRecorder()
390-
_, err := GetWriter(request, rr, test.disableCompression, test.offeredCompressions)
396+
_, encodingHeader, err := NegotiateEncodingWriter(request, rr, test.disableCompression, test.offeredCompressions)
391397

392398
if !errors.Is(err, test.err) {
393399
t.Errorf("got error: %v, expected: %v", err, test.err)
394400
}
395401

396-
if rr.Header().Get(contentEncodingHeader) != test.expectedCompression {
397-
t.Errorf("got different compression type: %v, expected: %v", rr.Header().Get(contentEncodingHeader), test.expectedCompression)
402+
if encodingHeader != test.expectedCompression {
403+
t.Errorf("got different compression type: %v, expected: %v", encodingHeader, test.expectedCompression)
398404
}
399405
}
400406
}

0 commit comments

Comments
 (0)