Skip to content

Commit c9f4177

Browse files
authored
Consistently parse and validate user-provided status codes (#137)
In testing out error handling after #135, I happened to stumble across an unexpected panic for requests like `/status/1024` where the user-provided status code is outside the legal bounds. So, here we take a quick pass to ensure we're parsing and validating status codes the same way everywhere.
1 parent 9e30640 commit c9f4177

File tree

3 files changed

+58
-17
lines changed

3 files changed

+58
-17
lines changed

httpbin/handlers.go

+10-17
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,9 @@ func (h *HTTPBin) Status(w http.ResponseWriter, r *http.Request) {
254254
writeError(w, http.StatusNotFound, nil)
255255
return
256256
}
257-
code, err := strconv.Atoi(parts[2])
257+
code, err := parseStatusCode(parts[2])
258258
if err != nil {
259-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status %q: %w", parts[2], err))
259+
writeError(w, http.StatusBadRequest, err)
260260
return
261261
}
262262

@@ -297,7 +297,7 @@ func (h *HTTPBin) Unstable(w http.ResponseWriter, r *http.Request) {
297297
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %w", err))
298298
return
299299
} else if failureRate < 0 || failureRate > 1 {
300-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in interval [0, 1]", err))
300+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid failure rate: %d not in range [0, 1]", err))
301301
return
302302
}
303303
}
@@ -414,14 +414,10 @@ func (h *HTTPBin) RedirectTo(w http.ResponseWriter, r *http.Request) {
414414
}
415415

416416
statusCode := http.StatusFound
417-
rawStatusCode := q.Get("status_code")
418-
if rawStatusCode != "" {
419-
statusCode, err = strconv.Atoi(q.Get("status_code"))
417+
if userStatusCode := q.Get("status_code"); userStatusCode != "" {
418+
statusCode, err = parseBoundedStatusCode(userStatusCode, 300, 399)
420419
if err != nil {
421-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid status code: %w", err))
422-
return
423-
} else if statusCode < 300 || statusCode > 399 {
424-
writeError(w, http.StatusBadRequest, errors.New("invalid status code: must be in range [300, 399]"))
420+
writeError(w, http.StatusBadRequest, err)
425421
return
426422
}
427423
}
@@ -617,18 +613,15 @@ func (h *HTTPBin) Drip(w http.ResponseWriter, r *http.Request) {
617613
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %w", err))
618614
return
619615
} else if numBytes < 1 || numBytes > h.MaxBodySize {
620-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
616+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid numbytes: %d not in range [1, %d]", numBytes, h.MaxBodySize))
621617
return
622618
}
623619
}
624620

625621
if userCode := q.Get("code"); userCode != "" {
626-
code, err = strconv.Atoi(userCode)
622+
code, err = parseStatusCode(userCode)
627623
if err != nil {
628-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %w", err))
629-
return
630-
} else if code < 100 || code >= 600 {
631-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid code: %d not in interval [100, 599]", code))
624+
writeError(w, http.StatusBadRequest, err)
632625
return
633626
}
634627
}
@@ -713,7 +706,7 @@ func (h *HTTPBin) Range(w http.ResponseWriter, r *http.Request) {
713706
w.Header().Add("Accept-Ranges", "bytes")
714707

715708
if numBytes <= 0 || numBytes > h.MaxBodySize {
716-
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in interval [1, %d]", numBytes, h.MaxBodySize))
709+
writeError(w, http.StatusBadRequest, fmt.Errorf("invalid count: %d not in range [1, %d]", numBytes, h.MaxBodySize))
717710
return
718711
}
719712

httpbin/handlers_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -799,6 +799,7 @@ func TestStatus(t *testing.T) {
799799
headers map[string]string
800800
body string
801801
}{
802+
// 100 is tested as a special case below
802803
{200, nil, ""},
803804
{300, map[string]string{"Location": "/image/jpeg"}, `<!doctype html>
804805
<head>
@@ -822,6 +823,8 @@ func TestStatus(t *testing.T) {
822823
</html>`},
823824
{401, unauthorizedHeaders, ""},
824825
{418, nil, "I'm a teapot!"},
826+
{500, nil, ""}, // maximum allowed status code
827+
{599, nil, ""}, // maximum allowed status code
825828
}
826829

827830
for _, test := range tests {
@@ -848,6 +851,8 @@ func TestStatus(t *testing.T) {
848851
{"/status/200/foo", http.StatusNotFound},
849852
{"/status/3.14", http.StatusBadRequest},
850853
{"/status/foo", http.StatusBadRequest},
854+
{"/status/600", http.StatusBadRequest},
855+
{"/status/1024", http.StatusBadRequest},
851856
}
852857

853858
for _, test := range errorTests {
@@ -860,6 +865,34 @@ func TestStatus(t *testing.T) {
860865
assert.StatusCode(t, resp, test.status)
861866
})
862867
}
868+
869+
t.Run("HTTP 100 Continue status code supported", func(t *testing.T) {
870+
// The stdlib http client automagically handles 100 Continue responses
871+
// by continuing the request until a "final" 200 OK response is
872+
// received, which prevents us from confirming that a 100 Continue
873+
// response is sent when using the http client directly.
874+
//
875+
// So, here we instead manally write the request to the wire and read
876+
// the initial response, which will give us access to the 100 Continue
877+
// indication we need.
878+
t.Parallel()
879+
880+
conn, err := net.Dial("tcp", srv.Listener.Addr().String())
881+
assert.NilError(t, err)
882+
defer conn.Close()
883+
884+
req := newTestRequest(t, "GET", "/status/100")
885+
reqBytes, err := httputil.DumpRequestOut(req, false)
886+
assert.NilError(t, err)
887+
888+
n, err := conn.Write(reqBytes)
889+
assert.NilError(t, err)
890+
assert.Equal(t, n, len(reqBytes), "incorrect number of bytes written")
891+
892+
resp, err := http.ReadResponse(bufio.NewReader(conn), req)
893+
assert.NilError(t, err)
894+
assert.StatusCode(t, resp, http.StatusContinue)
895+
})
863896
}
864897

865898
func TestUnstable(t *testing.T) {

httpbin/helpers.go

+15
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,21 @@ func encodeData(body []byte, contentType string) string {
234234
return string("data:" + contentType + ";base64," + data)
235235
}
236236

237+
func parseStatusCode(input string) (int, error) {
238+
return parseBoundedStatusCode(input, 100, 599)
239+
}
240+
241+
func parseBoundedStatusCode(input string, min, max int) (int, error) {
242+
code, err := strconv.Atoi(input)
243+
if err != nil {
244+
return 0, fmt.Errorf("invalid status code: %q: %w", input, err)
245+
}
246+
if code < min || code > max {
247+
return 0, fmt.Errorf("invalid status code: %d not in range [%d, %d]", code, min, max)
248+
}
249+
return code, nil
250+
}
251+
237252
// parseDuration takes a user's input as a string and attempts to convert it
238253
// into a time.Duration. If not given as a go-style duration string, the input
239254
// is assumed to be seconds as a float.

0 commit comments

Comments
 (0)