Skip to content

Commit 13e4ded

Browse files
authored
Merge pull request #497 from stefanprodan/harden-web
Harden web server against resource exhaustion
2 parents 349544d + 5037bed commit 13e4ded

10 files changed

Lines changed: 179 additions & 25 deletions

File tree

pkg/api/http/body_limit_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package http
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
)
10+
11+
// TestRequestBodySizeLimit verifies body-reading handlers reject payloads larger
12+
// than maxRequestBodySize instead of buffering them into memory. The echo
13+
// handler is used because with no backends configured it simply reflects the
14+
// body and needs no external dependencies.
15+
func TestRequestBodySizeLimit(t *testing.T) {
16+
srv := NewMockServer()
17+
srv.router.HandleFunc("/echo", srv.echoHandler)
18+
19+
// A body within the limit is accepted (202).
20+
within := httptest.NewRequest("POST", "/echo", bytes.NewReader(make([]byte, 1024)))
21+
rr := httptest.NewRecorder()
22+
srv.router.ServeHTTP(rr, within)
23+
if rr.Code != http.StatusAccepted {
24+
t.Errorf("within-limit body: got status %d want %d", rr.Code, http.StatusAccepted)
25+
}
26+
27+
// A body over the limit is rejected with a 413 code in the response body.
28+
over := httptest.NewRequest("POST", "/echo", bytes.NewReader(make([]byte, maxRequestBodySize+1)))
29+
rr = httptest.NewRecorder()
30+
srv.router.ServeHTTP(rr, over)
31+
32+
var resp struct {
33+
Code int `json:"code"`
34+
Message string `json:"message"`
35+
}
36+
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
37+
t.Fatalf("oversize body: response is not the expected error JSON: %v (body=%q)", err, rr.Body.String())
38+
}
39+
if resp.Code != http.StatusRequestEntityTooLarge {
40+
t.Errorf("oversize body: got code %d want %d", resp.Code, http.StatusRequestEntityTooLarge)
41+
}
42+
}

pkg/api/http/cache.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package http
22

33
import (
44
"fmt"
5-
"io"
65
"net/http"
76
"net/url"
87
"time"
@@ -33,15 +32,14 @@ func (s *Server) cacheWriteHandler(w http.ResponseWriter, r *http.Request) {
3332
}
3433

3534
key := mux.Vars(r)["key"]
36-
body, err := io.ReadAll(r.Body)
37-
if err != nil {
38-
s.ErrorResponse(w, r, span, "reading the request body failed", http.StatusBadRequest)
35+
body, ok := s.readLimitedBody(w, r, span)
36+
if !ok {
3937
return
4038
}
4139

4240
conn := s.pool.Get()
4341
defer conn.Close()
44-
_, err = conn.Do("SET", key, string(body))
42+
_, err := conn.Do("SET", key, string(body))
4543
if err != nil {
4644
s.logger.Warn("cache set failed", zap.Error(err))
4745
s.ErrorResponse(w, r, span, "cache set failed", http.StatusInternalServerError)

pkg/api/http/chunked.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (s *Server) chunkedHandler(w http.ResponseWriter, r *http.Request) {
2626

2727
delay, err := strconv.Atoi(vars["wait"])
2828
if err != nil {
29-
delay = rand.Intn(int(s.config.HttpServerTimeout*time.Second)-10) + 10
29+
delay = randomDelaySeconds(s.config.HttpServerTimeout)
3030
}
3131

3232
flusher, ok := w.(http.Flusher)
@@ -46,3 +46,15 @@ func (s *Server) chunkedHandler(w http.ResponseWriter, r *http.Request) {
4646

4747
flusher.Flush()
4848
}
49+
50+
// randomDelaySeconds returns a random delay in seconds within [10, timeout),
51+
// used when no explicit wait is provided. timeout is a time.Duration, so it is
52+
// converted to whole seconds; the upper bound is clamped to keep rand.Intn's
53+
// argument positive (it panics on a non-positive argument).
54+
func randomDelaySeconds(timeout time.Duration) int {
55+
maxDelay := int(timeout / time.Second)
56+
if maxDelay <= 11 {
57+
maxDelay = 12
58+
}
59+
return rand.Intn(maxDelay-10) + 10
60+
}

pkg/api/http/chunked_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net/http/httptest"
66
"regexp"
77
"testing"
8+
"time"
89
)
910

1011
func TestChunkedHandler(t *testing.T) {
@@ -33,3 +34,19 @@ func TestChunkedHandler(t *testing.T) {
3334
rr.Body.String(), expected)
3435
}
3536
}
37+
38+
// TestRandomDelaySeconds covers the default-delay branch taken by the bare
39+
// /chunked route (no {wait} value). This used to panic because the
40+
// duration-to-seconds math overflowed int64 and handed rand.Intn a negative
41+
// argument. Every timeout must yield a valid delay in [10, max] without panicking.
42+
func TestRandomDelaySeconds(t *testing.T) {
43+
timeouts := []time.Duration{30 * time.Second, 12 * time.Second, time.Second, 0, -1}
44+
for _, timeout := range timeouts {
45+
for range 100 {
46+
d := randomDelaySeconds(timeout)
47+
if d < 10 {
48+
t.Fatalf("timeout %s: delay %d below floor of 10", timeout, d)
49+
}
50+
}
51+
}
52+
}

pkg/api/http/echo.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ func (s *Server) echoHandler(w http.ResponseWriter, r *http.Request) {
2727
ctx, span := s.tracer.Start(r.Context(), "echoHandler")
2828
defer span.End()
2929

30-
body, err := io.ReadAll(r.Body)
31-
if err != nil {
32-
s.logger.Error("reading the request body failed", zap.Error(err))
33-
s.ErrorResponse(w, r, span, "invalid request body", http.StatusBadRequest)
30+
body, ok := s.readLimitedBody(w, r, span)
31+
if !ok {
3432
return
3533
}
3634
defer r.Body.Close()

pkg/api/http/echows.go

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ import (
1212

1313
var wsCon = websocket.Upgrader{}
1414

15+
const (
16+
// wsMaxMessageSize caps a single inbound websocket message so one large
17+
// frame cannot exhaust process memory.
18+
wsMaxMessageSize = 1 << 20 // 1 MiB
19+
// wsReadTimeout is how long the server waits for the next client message or
20+
// pong before closing an idle connection.
21+
wsReadTimeout = 60 * time.Second
22+
// wsWriteTimeout bounds a single write so a slow reader cannot block forever.
23+
wsWriteTimeout = 10 * time.Second
24+
// wsPingInterval is how often the server pings the client to keep the
25+
// connection alive and detect dead peers; it must be shorter than wsReadTimeout.
26+
wsPingInterval = 30 * time.Second
27+
)
28+
1529
// EchoWS godoc
1630
// @Summary Echo over websockets
1731
// @Description echos content via websockets
@@ -24,11 +38,18 @@ var wsCon = websocket.Upgrader{}
2438
func (s *Server) echoWsHandler(w http.ResponseWriter, r *http.Request) {
2539
c, err := wsCon.Upgrade(w, r, nil)
2640
if err != nil {
27-
if err != nil {
28-
s.logger.Warn("websocket upgrade error", zap.Error(err))
29-
return
30-
}
41+
s.logger.Warn("websocket upgrade error", zap.Error(err))
42+
return
3143
}
44+
45+
// Bound per-message size and idle time; refresh the read deadline whenever
46+
// the client responds to a ping so live connections stay open.
47+
c.SetReadLimit(wsMaxMessageSize)
48+
_ = c.SetReadDeadline(time.Now().Add(wsReadTimeout))
49+
c.SetPongHandler(func(string) error {
50+
return c.SetReadDeadline(time.Now().Add(wsReadTimeout))
51+
})
52+
3253
var wg sync.WaitGroup
3354
wg.Add(1)
3455

@@ -84,15 +105,23 @@ func (s *Server) sendHostWs(ws *websocket.Conn, in chan interface{}, done chan s
84105
}
85106

86107
func (s *Server) writeWs(ws *websocket.Conn, in chan interface{}) {
108+
ping := time.NewTicker(wsPingInterval)
109+
defer ping.Stop()
87110
for {
88111
select {
89112
case msg := <-in:
113+
_ = ws.SetWriteDeadline(time.Now().Add(wsWriteTimeout))
90114
if err := ws.WriteJSON(msg); err != nil {
91115
if !strings.Contains(err.Error(), "close") {
92116
s.logger.Warn("websocket write error", zap.Error(err))
93117
}
94118
return
95119
}
120+
case <-ping.C:
121+
_ = ws.SetWriteDeadline(time.Now().Add(wsWriteTimeout))
122+
if err := ws.WriteMessage(websocket.PingMessage, nil); err != nil {
123+
return
124+
}
96125
}
97126
}
98127
}

pkg/api/http/echows_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http/httptest"
55
"strings"
66
"testing"
7+
"time"
78

89
"github.com/gorilla/websocket"
910
)
@@ -34,3 +35,40 @@ func TestEchoWsHandler(t *testing.T) {
3435
t.Error("received empty message")
3536
}
3637
}
38+
39+
// TestEchoWsReadLimit verifies the server caps inbound message size: a message
40+
// larger than wsMaxMessageSize must cause the server to close the connection
41+
// instead of buffering it into memory.
42+
func TestEchoWsReadLimit(t *testing.T) {
43+
srv := NewMockServer()
44+
srv.router.HandleFunc("/ws/echo", srv.echoWsHandler)
45+
server := httptest.NewServer(srv.router)
46+
defer server.Close()
47+
48+
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/ws/echo"
49+
ws, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
50+
if err != nil {
51+
t.Fatalf("websocket dial failed: %v", err)
52+
}
53+
defer ws.Close()
54+
55+
// A message just over the limit must not be echoed; the server closes it.
56+
oversize := make([]byte, wsMaxMessageSize+1)
57+
if err := ws.WriteMessage(websocket.TextMessage, oversize); err != nil {
58+
t.Fatalf("write failed: %v", err)
59+
}
60+
61+
// Read until the connection errors. With the read limit in place the server
62+
// sends a 1009 (message too big) close frame; without it the server would
63+
// instead echo the oversize payload back. The 5s deadline bounds the loop.
64+
// Status frames from the periodic ticker (err == nil) are skipped.
65+
ws.SetReadDeadline(time.Now().Add(5 * time.Second))
66+
for {
67+
if _, _, err = ws.ReadMessage(); err != nil {
68+
break
69+
}
70+
}
71+
if !websocket.IsCloseError(err, websocket.CloseMessageTooBig) {
72+
t.Errorf("expected close code %d (message too big), got: %v", websocket.CloseMessageTooBig, err)
73+
}
74+
}

pkg/api/http/http.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package http
33
import (
44
"bytes"
55
"encoding/json"
6+
"errors"
7+
"io"
68
"math/rand"
79
"net/http"
810
"time"
@@ -13,6 +15,11 @@ import (
1315
"go.uber.org/zap"
1416
)
1517

18+
// maxRequestBodySize caps how much of a request body the body-reading handlers
19+
// buffer into memory. Without this bound an unauthenticated client can POST an
20+
// arbitrarily large body and exhaust process memory (and, for /store, disk).
21+
const maxRequestBodySize = 10 << 20 // 10 MiB
22+
1623
func randomErrorMiddleware(next http.Handler) http.Handler {
1724
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1825
rand.Seed(time.Now().Unix())
@@ -87,6 +94,25 @@ func (s *Server) ErrorResponse(w http.ResponseWriter, r *http.Request, span trac
8794
w.Write(prettyJSON(body))
8895
}
8996

97+
// readLimitedBody reads the request body up to maxRequestBodySize. It returns
98+
// the body and true on success. On an oversized body it writes a 413 response,
99+
// on any other read error a 400, and returns false so the caller returns early.
100+
func (s *Server) readLimitedBody(w http.ResponseWriter, r *http.Request, span trace.Span) ([]byte, bool) {
101+
r.Body = http.MaxBytesReader(w, r.Body, maxRequestBodySize)
102+
body, err := io.ReadAll(r.Body)
103+
if err != nil {
104+
var maxErr *http.MaxBytesError
105+
if errors.As(err, &maxErr) {
106+
s.ErrorResponse(w, r, span, "request body too large", http.StatusRequestEntityTooLarge)
107+
return nil, false
108+
}
109+
s.logger.Error("reading the request body failed", zap.Error(err))
110+
s.ErrorResponse(w, r, span, "invalid request body", http.StatusBadRequest)
111+
return nil, false
112+
}
113+
return body, true
114+
}
115+
90116
// setRawResponseHeaders prevents XSS by ensuring browsers never interpret raw responses as HTML.
91117
func setRawResponseHeaders(w http.ResponseWriter) {
92118
w.Header().Set("Content-Type", "application/octet-stream")

pkg/api/http/store.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package http
33
import (
44
"crypto/sha1"
55
"encoding/hex"
6-
"io"
76
"net/http"
87
"os"
98
"path"
@@ -27,14 +26,13 @@ func (s *Server) storeWriteHandler(w http.ResponseWriter, r *http.Request) {
2726
_, span := s.tracer.Start(r.Context(), "storeWriteHandler")
2827
defer span.End()
2928

30-
body, err := io.ReadAll(r.Body)
31-
if err != nil {
32-
s.ErrorResponse(w, r, span, "reading the request body failed", http.StatusBadRequest)
29+
body, ok := s.readLimitedBody(w, r, span)
30+
if !ok {
3331
return
3432
}
3533

3634
hash := hash(string(body))
37-
err = os.WriteFile(path.Join(s.config.DataPath, hash), body, 0644)
35+
err := os.WriteFile(path.Join(s.config.DataPath, hash), body, 0644)
3836
if err != nil {
3937
s.logger.Warn("writing file failed", zap.Error(err), zap.String("file", path.Join(s.config.DataPath, hash)))
4038
s.ErrorResponse(w, r, span, "writing file failed", http.StatusInternalServerError)

pkg/api/http/token.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@ package http
22

33
import (
44
"fmt"
5-
"io"
65
"net/http"
76
"strings"
87
"time"
98

109
"github.com/golang-jwt/jwt/v4"
11-
"go.uber.org/zap"
1210
)
1311

1412
type jwtCustomClaims struct {
@@ -28,10 +26,8 @@ func (s *Server) tokenGenerateHandler(w http.ResponseWriter, r *http.Request) {
2826
_, span := s.tracer.Start(r.Context(), "tokenGenerateHandler")
2927
defer span.End()
3028

31-
body, err := io.ReadAll(r.Body)
32-
if err != nil {
33-
s.logger.Error("reading the request body failed", zap.Error(err))
34-
s.ErrorResponse(w, r, span, "invalid request body", http.StatusBadRequest)
29+
body, ok := s.readLimitedBody(w, r, span)
30+
if !ok {
3531
return
3632
}
3733
defer r.Body.Close()

0 commit comments

Comments
 (0)