Skip to content

Commit 4f07d24

Browse files
authored
Merge pull request #501 from liamg/patch-1
fix: Avoid panic when s3 URL is invalid
2 parents 5a63fd9 + 8339301 commit 4f07d24

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

get_s3.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -268,23 +268,35 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
268268
region = "us-east-1"
269269
}
270270
pathParts := strings.SplitN(u.Path, "/", 3)
271+
if len(pathParts) < 3 {
272+
err = fmt.Errorf("URL is not a valid S3 URL")
273+
return
274+
}
271275
bucket = pathParts[1]
272276
path = pathParts[2]
273277
// vhost-style, dash region indication
274278
case 4:
275-
// Parse the region out of the first part of the host
279+
// Parse the region out of the second part of the host
276280
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3")
277281
if region == "" {
278282
err = fmt.Errorf("URL is not a valid S3 URL")
279283
return
280284
}
281285
pathParts := strings.SplitN(u.Path, "/", 2)
286+
if len(pathParts) < 2 {
287+
err = fmt.Errorf("URL is not a valid S3 URL")
288+
return
289+
}
282290
bucket = hostParts[0]
283291
path = pathParts[1]
284292
//vhost-style, dot region indication
285293
case 5:
286294
region = hostParts[2]
287295
pathParts := strings.SplitN(u.Path, "/", 2)
296+
if len(pathParts) < 2 {
297+
err = fmt.Errorf("URL is not a valid S3 URL")
298+
return
299+
}
288300
bucket = hostParts[0]
289301
path = pathParts[1]
290302

get_s3_test.go

+58-7
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,13 @@ func TestS3Getter_ClientMode_collision(t *testing.T) {
165165

166166
func TestS3Getter_Url(t *testing.T) {
167167
var s3tests = []struct {
168-
name string
169-
url string
170-
region string
171-
bucket string
172-
path string
173-
version string
168+
name string
169+
url string
170+
region string
171+
bucket string
172+
path string
173+
version string
174+
expectedErr string
174175
}{
175176
{
176177
name: "AWSv1234",
@@ -220,6 +221,11 @@ func TestS3Getter_Url(t *testing.T) {
220221
path: "hello.txt",
221222
version: "",
222223
},
224+
{
225+
name: "malformed s3 url",
226+
url: "s3::https://s3.amazonaws.com/bucket",
227+
expectedErr: "URL is not a valid S3 URL",
228+
},
223229
}
224230

225231
for i, pt := range s3tests {
@@ -238,7 +244,15 @@ func TestS3Getter_Url(t *testing.T) {
238244
region, bucket, path, version, creds, err := g.parseUrl(u)
239245

240246
if err != nil {
241-
t.Fatalf("err: %s", err)
247+
if pt.expectedErr == "" {
248+
t.Fatalf("err: %s", err)
249+
}
250+
if err.Error() != pt.expectedErr {
251+
t.Fatalf("expected %s, got %s", pt.expectedErr, err.Error())
252+
}
253+
return
254+
} else if pt.expectedErr != "" {
255+
t.Fatalf("expected error, got none")
242256
}
243257
if region != pt.region {
244258
t.Fatalf("expected %s, got %s", pt.region, region)
@@ -258,3 +272,40 @@ func TestS3Getter_Url(t *testing.T) {
258272
})
259273
}
260274
}
275+
276+
func Test_S3Getter_ParseUrl_Malformed(t *testing.T) {
277+
tests := []struct {
278+
name string
279+
url string
280+
}{
281+
{
282+
name: "path style",
283+
url: "https://s3.amazonaws.com/bucket",
284+
},
285+
{
286+
name: "vhost-style, dash region indication",
287+
url: "https://bucket.s3-us-east-1.amazonaws.com",
288+
},
289+
{
290+
name: "vhost-style, dot region indication",
291+
url: "https://bucket.s3.us-east-1.amazonaws.com",
292+
},
293+
}
294+
for _, tt := range tests {
295+
t.Run(tt.name, func(t *testing.T) {
296+
g := new(S3Getter)
297+
u, err := url.Parse(tt.url)
298+
if err != nil {
299+
t.Fatalf("unexpected error: %s", err)
300+
}
301+
_, _, _, _, _, err = g.parseUrl(u)
302+
if err == nil {
303+
t.Fatalf("expected error, got none")
304+
}
305+
if err.Error() != "URL is not a valid S3 URL" {
306+
t.Fatalf("expected error 'URL is not a valid S3 URL', got %s", err.Error())
307+
}
308+
})
309+
}
310+
311+
}

0 commit comments

Comments
 (0)