Skip to content

Commit 9cdb082

Browse files
authored
do not allow for unbounded reads for user controlled input (#681)
Signed-off-by: Alex Goodman <wagoodman@users.noreply.github.com>
1 parent 80cf3fe commit 9cdb082

File tree

8 files changed

+303
-16
lines changed

8 files changed

+303
-16
lines changed

.golangci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ linters:
5050
- linters:
5151
- revive
5252
text: "var-naming: avoid package names that conflict"
53+
# utils is a commonly used helper package name
54+
- linters:
55+
- revive
56+
text: "var-naming: avoid meaningless package names"
5357
paths:
5458
- third_party$
5559
- builtin$

internal/utils/indent.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package utils //nolint:revive // existing package name
1+
package utils
22

33
import "strings"
44

internal/utils/io.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package utils
2+
3+
import (
4+
"fmt"
5+
"io"
6+
)
7+
8+
// ReadAllLimited reads up to maxBytes from r. Returns error if limit exceeded.
9+
func ReadAllLimited(r io.Reader, maxBytes int64) ([]byte, error) {
10+
limitedReader := io.LimitReader(r, maxBytes+1)
11+
data, err := io.ReadAll(limitedReader)
12+
if err != nil {
13+
return nil, err
14+
}
15+
if int64(len(data)) > maxBytes {
16+
return nil, fmt.Errorf("response size exceeds limit of %d bytes", maxBytes)
17+
}
18+
return data, nil
19+
}

internal/utils/io_test.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package utils
2+
3+
import (
4+
"bytes"
5+
"strings"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestReadAllLimited(t *testing.T) {
12+
tests := []struct {
13+
name string
14+
input string
15+
maxBytes int64
16+
want string
17+
wantErr require.ErrorAssertionFunc
18+
}{
19+
{
20+
name: "reads data under limit",
21+
input: "hello world",
22+
maxBytes: 100,
23+
want: "hello world",
24+
},
25+
{
26+
name: "reads data exactly at limit",
27+
input: "hello",
28+
maxBytes: 5,
29+
want: "hello",
30+
},
31+
{
32+
name: "returns error when data exceeds limit",
33+
input: "hello world",
34+
maxBytes: 5,
35+
wantErr: require.Error,
36+
},
37+
{
38+
name: "handles empty input",
39+
input: "",
40+
maxBytes: 100,
41+
want: "",
42+
},
43+
{
44+
name: "handles zero limit with empty input",
45+
input: "",
46+
maxBytes: 0,
47+
want: "",
48+
},
49+
{
50+
name: "returns error for zero limit with data",
51+
input: "a",
52+
maxBytes: 0,
53+
wantErr: require.Error,
54+
},
55+
{
56+
name: "handles large data at boundary minus one",
57+
input: strings.Repeat("x", 999),
58+
maxBytes: 1000,
59+
want: strings.Repeat("x", 999),
60+
},
61+
{
62+
name: "handles large data at exact boundary",
63+
input: strings.Repeat("x", 1000),
64+
maxBytes: 1000,
65+
want: strings.Repeat("x", 1000),
66+
},
67+
{
68+
name: "returns error for large data over boundary",
69+
input: strings.Repeat("x", 1001),
70+
maxBytes: 1000,
71+
wantErr: require.Error,
72+
},
73+
}
74+
75+
for _, tt := range tests {
76+
t.Run(tt.name, func(t *testing.T) {
77+
if tt.wantErr == nil {
78+
tt.wantErr = require.NoError
79+
}
80+
81+
reader := bytes.NewReader([]byte(tt.input))
82+
got, err := ReadAllLimited(reader, tt.maxBytes)
83+
tt.wantErr(t, err)
84+
85+
if err != nil {
86+
return
87+
}
88+
require.Equal(t, tt.want, string(got))
89+
})
90+
}
91+
}

quill/notary/api_client.go

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8-
"io"
98
"net/http"
109
"net/url"
1110
"path"
@@ -22,6 +21,12 @@ import (
2221
"github.com/anchore/quill/internal/log"
2322
"github.com/anchore/quill/internal/redact"
2423
"github.com/anchore/quill/internal/urlvalidate"
24+
"github.com/anchore/quill/internal/utils"
25+
)
26+
27+
const (
28+
maxAPIResponseSize = 5 * 1024 * 1024 // 5 MB for API JSON responses
29+
maxLogResponseSize = 50 * 1024 * 1024 // 50 MB for log files
2530
)
2631

2732
type api interface {
@@ -63,7 +68,7 @@ func (s APIClient) submissionRequest(ctx context.Context, request submissionRequ
6368
return nil, err
6469
}
6570

66-
response, err := s.http.post(ctx, s.api, bytes.NewReader(requestBytes))
71+
response, err := s.http.post(ctx, s.api, bytes.NewReader(requestBytes)) //nolint:bodyclose // body is closed in handleResponse
6772
body, err := s.handleResponse(response, err)
6873
if err != nil {
6974
return nil, err
@@ -119,7 +124,7 @@ func (s APIClient) uploadBinary(ctx context.Context, response submissionResponse
119124
}
120125

121126
func (s APIClient) submissionStatusRequest(ctx context.Context, id string) (*submissionStatusResponse, error) {
122-
response, err := s.http.get(ctx, joinURL(s.api, id), nil)
127+
response, err := s.http.get(ctx, joinURL(s.api, id), nil) //nolint:bodyclose // body is closed in handleResponse
123128
body, err := s.handleResponse(response, err)
124129
if err != nil {
125130
return nil, err
@@ -133,7 +138,7 @@ func (s APIClient) submissionStatusRequest(ctx context.Context, id string) (*sub
133138
}
134139

135140
func (s APIClient) submissionList(ctx context.Context) (*submissionListResponse, error) {
136-
response, err := s.http.get(ctx, s.api, nil)
141+
response, err := s.http.get(ctx, s.api, nil) //nolint:bodyclose // body is closed in handleResponse
137142
body, err := s.handleResponse(response, err)
138143
if err != nil {
139144
return nil, err
@@ -147,7 +152,7 @@ func (s APIClient) submissionList(ctx context.Context) (*submissionListResponse,
147152
}
148153

149154
func (s APIClient) submissionLogs(ctx context.Context, id string) (string, error) {
150-
metadataResp, err := s.http.get(ctx, joinURL(s.api, id, "logs"), nil)
155+
metadataResp, err := s.http.get(ctx, joinURL(s.api, id, "logs"), nil) //nolint:bodyclose // body is closed in handleResponse
151156
body, err := s.handleResponse(metadataResp, err)
152157
if err != nil {
153158
return "", fmt.Errorf("unable to fetch log metadata with ID=%s: %w", id, err)
@@ -160,9 +165,10 @@ func (s APIClient) submissionLogs(ctx context.Context, id string) (string, error
160165

161166
redactPresignedURLParams(resp.Data.Attributes.DeveloperLogURL)
162167

163-
// fetch logs without auth header (it's a presigned URL with its own auth)
168+
// fetch logs without auth (presigned URL), with redirect validation for SSRF protection.
169+
// use a larger size limit since log files can be bigger than typical API responses.
164170
logsResp, err := s.http.getUnauthenticated(ctx, resp.Data.Attributes.DeveloperLogURL)
165-
contents, err := s.handleResponse(logsResp, err)
171+
contents, err := s.handleResponseWithLimit(logsResp, err, maxLogResponseSize)
166172
if err != nil {
167173
return "", fmt.Errorf("unable to fetch log destination with ID=%s: %w", id, err)
168174
}
@@ -171,16 +177,28 @@ func (s APIClient) submissionLogs(ctx context.Context, id string) (string, error
171177
}
172178

173179
func (s APIClient) handleResponse(response *http.Response, err error) ([]byte, error) {
180+
return s.handleResponseWithLimit(response, err, maxAPIResponseSize)
181+
}
182+
183+
func (s APIClient) handleResponseWithLimit(response *http.Response, err error, maxBytes int64) ([]byte, error) {
184+
// ensure body is always closed, even if there's an error
185+
if response != nil && response.Body != nil {
186+
defer response.Body.Close()
187+
}
188+
174189
if err != nil {
175190
return nil, err
176191
}
177192

193+
if response == nil {
194+
return nil, fmt.Errorf("nil response")
195+
}
196+
178197
var body []byte
179198

180199
if response.Body != nil {
181-
defer response.Body.Close()
182-
183-
body, err = io.ReadAll(response.Body)
200+
// limit response size to prevent memory exhaustion from malicious responses
201+
body, err = utils.ReadAllLimited(response.Body, maxBytes)
184202
if err != nil {
185203
return nil, err
186204
}

quill/notary/api_client_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"net/http"
77
"net/http/httptest"
8+
"strings"
89
"testing"
910
"time"
1011

@@ -195,3 +196,151 @@ func Test_apiClient_submissionLogs_rejectsDeniedURLs(t *testing.T) {
195196
})
196197
}
197198
}
199+
200+
func Test_apiClient_handleResponse_enforcesMaxSize(t *testing.T) {
201+
tests := []struct {
202+
name string
203+
size int
204+
wantErr require.ErrorAssertionFunc
205+
errContains string
206+
}{
207+
{
208+
name: "accepts response under limit",
209+
size: 1024, // 1 KB
210+
},
211+
{
212+
name: "accepts response at limit",
213+
size: maxAPIResponseSize,
214+
},
215+
{
216+
name: "rejects response over limit",
217+
size: maxAPIResponseSize + 1,
218+
wantErr: require.Error,
219+
errContains: "exceeds limit",
220+
},
221+
}
222+
223+
for _, tt := range tests {
224+
t.Run(tt.name, func(t *testing.T) {
225+
if tt.wantErr == nil {
226+
tt.wantErr = require.NoError
227+
}
228+
229+
mux := http.NewServeMux()
230+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
231+
// write a response of the specified size
232+
data := strings.Repeat("x", tt.size)
233+
w.Write([]byte(data))
234+
})
235+
236+
s := httptest.NewServer(mux)
237+
defer s.Close()
238+
239+
resp, err := http.Get(s.URL)
240+
require.NoError(t, err)
241+
242+
c := APIClient{}
243+
_, err = c.handleResponse(resp, nil)
244+
tt.wantErr(t, err)
245+
246+
if tt.errContains != "" {
247+
require.Contains(t, err.Error(), tt.errContains)
248+
}
249+
})
250+
}
251+
}
252+
253+
func Test_apiClient_handleResponseWithLimit_enforcesCustomLimit(t *testing.T) {
254+
customLimit := int64(100)
255+
256+
tests := []struct {
257+
name string
258+
size int
259+
wantErr require.ErrorAssertionFunc
260+
errContains string
261+
}{
262+
{
263+
name: "accepts response under custom limit",
264+
size: 50,
265+
},
266+
{
267+
name: "accepts response at custom limit",
268+
size: int(customLimit),
269+
},
270+
{
271+
name: "rejects response over custom limit",
272+
size: int(customLimit) + 1,
273+
wantErr: require.Error,
274+
errContains: "exceeds limit",
275+
},
276+
}
277+
278+
for _, tt := range tests {
279+
t.Run(tt.name, func(t *testing.T) {
280+
if tt.wantErr == nil {
281+
tt.wantErr = require.NoError
282+
}
283+
284+
mux := http.NewServeMux()
285+
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
286+
data := strings.Repeat("x", tt.size)
287+
w.Write([]byte(data))
288+
})
289+
290+
s := httptest.NewServer(mux)
291+
defer s.Close()
292+
293+
resp, err := http.Get(s.URL)
294+
require.NoError(t, err)
295+
296+
c := APIClient{}
297+
_, err = c.handleResponseWithLimit(resp, nil, customLimit)
298+
tt.wantErr(t, err)
299+
300+
if tt.errContains != "" {
301+
require.Contains(t, err.Error(), tt.errContains)
302+
}
303+
})
304+
}
305+
}
306+
307+
func Test_apiClient_submissionLogs_usesLargerLimit(t *testing.T) {
308+
// create a log response larger than maxAPIResponseSize but smaller than maxLogResponseSize
309+
logSize := maxAPIResponseSize + 1024 // just over the API limit
310+
require.True(t, logSize < maxLogResponseSize, "test assumes logSize < maxLogResponseSize")
311+
312+
id := "the-id"
313+
expectedLogResponse := submissionLogsResponse{
314+
Data: submissionLogsResponseData{
315+
submissionResponseDescriptor: submissionResponseDescriptor{
316+
Type: "the-ty",
317+
ID: id,
318+
},
319+
},
320+
}
321+
322+
mux := http.NewServeMux()
323+
mux.HandleFunc("/"+id+"/logs", func(w http.ResponseWriter, r *http.Request) {
324+
by, err := json.Marshal(expectedLogResponse)
325+
require.NoError(t, err)
326+
w.Write(by)
327+
})
328+
329+
mux.HandleFunc("/place-where-the-logs-are", func(w http.ResponseWriter, r *http.Request) {
330+
// write a large log response
331+
data := strings.Repeat("x", logSize)
332+
w.Write([]byte(data))
333+
})
334+
335+
s := httptest.NewServer(mux)
336+
expectedLogResponse.Data.Attributes.DeveloperLogURL = s.URL + "/place-where-the-logs-are"
337+
defer s.Close()
338+
339+
c := newTestAPIClient("the-token", time.Second*30)
340+
c.api = s.URL
341+
342+
// this should succeed because logs use maxLogResponseSize, not maxAPIResponseSize
343+
actual, err := c.submissionLogs(context.Background(), id)
344+
require.NoError(t, err)
345+
require.Len(t, actual, logSize)
346+
}

0 commit comments

Comments
 (0)