diff --git a/providers/aws/resources/aws_s3.go b/providers/aws/resources/aws_s3.go index 2007e0ec5a..ef3af0cfe3 100644 --- a/providers/aws/resources/aws_s3.go +++ b/providers/aws/resources/aws_s3.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/s3control" @@ -166,17 +167,18 @@ func initAwsS3Bucket(runtime *plugin.Runtime, args map[string]*llx.RawData) (map } // construct arn of bucket name if misssing - var arn string + var arnVal string if args["arn"] != nil { - arn = args["arn"].Value.(string) - if !strings.HasPrefix(arn, "arn:aws:s3:") { - return nil, nil, errors.Newf("not a valid bucket ARN '%s'", arn) + arnVal = args["arn"].Value.(string) + parsed, err := arn.Parse(arnVal) + if err != nil || parsed.Service != "s3" { + return nil, nil, errors.Newf("not a valid bucket ARN '%s'", arnVal) } } else { nameVal := args["name"].Value.(string) - arn = fmt.Sprintf(s3ArnPattern, nameVal) + arnVal = fmt.Sprintf(s3ArnPattern, nameVal) } - log.Debug().Str("arn", arn).Msg("init s3 bucket with arn") + log.Debug().Str("arn", arnVal).Msg("init s3 bucket with arn") // load all s3 buckets obj, err := runtime.CreateResource(runtime, "aws.s3", map[string]*llx.RawData{}) @@ -193,21 +195,21 @@ func initAwsS3Bucket(runtime *plugin.Runtime, args map[string]*llx.RawData) (map // iterate over security groups and find the one with the arn for _, rawResource := range rawResources.Data { bucket := rawResource.(*mqlAwsS3Bucket) - if bucket.Arn.Data == arn { + if bucket.Arn.Data == arnVal { return args, bucket, nil } } // it is possible for a resource to reference a non-existent/deleted bucket, so here we // create the object, noting that it no longer exists but is still recorded as part of some resources - splitArn := strings.Split(arn, ":::") + splitArn := strings.Split(arnVal, ":::") if len(splitArn) != 2 { return args, nil, nil } name := splitArn[1] - log.Debug().Msgf("no bucket found for %s", arn) + log.Debug().Msgf("no bucket found for %s", arnVal) mqlAwsS3Bucket, err := CreateResource(runtime, "aws.s3.bucket", map[string]*llx.RawData{ - "arn": llx.StringData(arn), + "arn": llx.StringData(arnVal), "name": llx.StringData(name), "exists": llx.BoolData(false), }) diff --git a/providers/aws/resources/aws_s3_test.go b/providers/aws/resources/aws_s3_test.go new file mode 100644 index 0000000000..e2abbe8c38 --- /dev/null +++ b/providers/aws/resources/aws_s3_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package resources + +import ( + "testing" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/stretchr/testify/assert" +) + +func TestS3BucketArnValidation(t *testing.T) { + tests := []struct { + name string + arnStr string + valid bool + }{ + {"standard partition", "arn:aws:s3:::my-bucket", true}, + {"govcloud partition", "arn:aws-us-gov:s3:::my-bucket", true}, + {"china partition", "arn:aws-cn:s3:::my-bucket", true}, + {"wrong service", "arn:aws:ec2:us-east-1:123456789012:instance/i-1234", false}, + {"not an ARN", "not-an-arn", false}, + {"empty string", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := arn.Parse(tt.arnStr) + isValidS3 := err == nil && parsed.Service == "s3" + assert.Equal(t, tt.valid, isValidS3) + }) + } +}