diff --git a/providers/aws/resources/aws_kms.go b/providers/aws/resources/aws_kms.go index 07ea8ff1e1..7a9743e5ff 100644 --- a/providers/aws/resources/aws_kms.go +++ b/providers/aws/resources/aws_kms.go @@ -7,10 +7,12 @@ import ( "context" "errors" "fmt" + "strings" "sync" "time" "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/google/uuid" "github.com/aws/aws-sdk-go-v2/service/kms" "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/rs/zerolog/log" @@ -27,6 +29,37 @@ const ( kmsKeyArnPattern = "arn:aws:kms:%s:%s:key/%s" ) +// NormalizeKmsKeyRef normalizes a KMS key reference to ARN format. +// KeyID supports multiple formats: Key ID (UUID), Key ARN, Alias Name, Alias ARN +func NormalizeKmsKeyRef(s, region, accountId string) (arn.ARN, error) { + // Try ARN parse first (common case) + parsed, arnErr := arn.Parse(s) + if arnErr == nil { + return parsed, nil + } + + // Fallback: check if it's a key ID (UUID format) + // Example: 7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2 + // This is an edge case where Secrets Manager returns just the key ID + if _, uuidErr := uuid.Parse(s); uuidErr == nil { + if region == "" { + return arn.ARN{}, fmt.Errorf("cannot normalize KMS key UUID %q without a region", s) + } + return arn.ARN{ + Partition: "aws", + Service: "kms", + Region: region, + AccountID: accountId, + Resource: "key/" + s, + }, nil + } + + // Todo add alias format handling here for Alias name and Alias ARN + + // If both checks fail, propagate the ARN parse error for better diagnostics + return arn.ARN{}, fmt.Errorf("invalid KMS key reference %q: %w", s, arnErr) +} + func (a *mqlAwsKms) id() (string, error) { return "aws.kms", nil } @@ -96,7 +129,23 @@ func (a *mqlAwsKms) getKeys(conn *connection.AwsConnection) []*jobpool.Job { return tasks } +// isCrossAccountKey returns true if this KMS key belongs to a different AWS account. +// Cross-account keys cannot be queried for details like metadata, tags, or aliases. +func (a *mqlAwsKmsKey) isCrossAccountKey() bool { + conn := a.MqlRuntime.Connection.(*connection.AwsConnection) + keyArn := a.Arn.Data + parsed, err := arn.Parse(keyArn) + if err != nil { + log.Warn().Err(err).Str("arn", keyArn).Msg("unable to parse KMS key ARN, treating as same-account") + return false + } + return parsed.AccountID != conn.AccountId() +} + func (a *mqlAwsKmsKey) metadata() (any, error) { + if a.isCrossAccountKey() { + return nil, nil + } md, err := a.getKeyMetadata() if err != nil { return nil, err @@ -105,6 +154,9 @@ func (a *mqlAwsKmsKey) metadata() (any, error) { } func (a *mqlAwsKmsKey) keyRotationEnabled() (bool, error) { + if a.isCrossAccountKey() { + return false, nil + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) keyId := a.Id.Data @@ -119,6 +171,9 @@ func (a *mqlAwsKmsKey) keyRotationEnabled() (bool, error) { } func (a *mqlAwsKmsKey) tags() (map[string]any, error) { + if a.isCrossAccountKey() { + return nil, nil + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) keyArn := a.Arn.Data @@ -142,6 +197,9 @@ func (a *mqlAwsKmsKey) tags() (map[string]any, error) { } func (a *mqlAwsKmsKey) aliases() ([]any, error) { + if a.isCrossAccountKey() { + return []any{}, nil + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) keyArn := a.Arn.Data @@ -164,6 +222,9 @@ func (a *mqlAwsKmsKey) aliases() ([]any, error) { } func (a *mqlAwsKmsKey) keyState() (string, error) { + if a.isCrossAccountKey() { + return "", nil + } md, err := a.getKeyMetadata() if err != nil { return "", err @@ -201,6 +262,9 @@ func (a *mqlAwsKmsKey) getKeyMetadata() (*types.KeyMetadata, error) { } func (a *mqlAwsKmsKey) createdAt() (*time.Time, error) { + if a.isCrossAccountKey() { + return nil, nil + } md, err := a.getKeyMetadata() if err != nil { return nil, err @@ -209,6 +273,9 @@ func (a *mqlAwsKmsKey) createdAt() (*time.Time, error) { } func (a *mqlAwsKmsKey) deletedAt() (*time.Time, error) { + if a.isCrossAccountKey() { + return nil, nil + } md, err := a.getKeyMetadata() if err != nil { return nil, err @@ -217,6 +284,9 @@ func (a *mqlAwsKmsKey) deletedAt() (*time.Time, error) { } func (a *mqlAwsKmsKey) enabled() (bool, error) { + if a.isCrossAccountKey() { + return false, nil + } md, err := a.getKeyMetadata() if err != nil { return false, err @@ -225,6 +295,9 @@ func (a *mqlAwsKmsKey) enabled() (bool, error) { } func (a *mqlAwsKmsKey) description() (string, error) { + if a.isCrossAccountKey() { + return "", nil + } md, err := a.getKeyMetadata() if err != nil { return "", err @@ -237,6 +310,9 @@ func (a *mqlAwsKmsKey) id() (string, error) { } func (a *mqlAwsKmsKey) grants() ([]any, error) { + if a.isCrossAccountKey() { + return []any{}, nil + } conn := a.MqlRuntime.Connection.(*connection.AwsConnection) keyArn := a.Arn.Data @@ -300,10 +376,38 @@ func initAwsKmsKey(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[s if !ok { return nil, nil, errors.New("invalid arn") } - arnVal, err := arn.Parse(a) - if arnVal.AccountID != runtime.Connection.(*connection.AwsConnection).AccountId() { - // sometimes there are references to keys in other accounts, like the master account of an org - return nil, nil, fmt.Errorf("cannot access key from different AWS account %q", arnVal.AccountID) + + conn := runtime.Connection.(*connection.AwsConnection) + + // Get region from args if provided (needed when input is just a key ID, not full ARN) + var region string + if regionArg := args["region"]; regionArg != nil { + if r, ok := regionArg.Value.(string); ok { + region = r + } + } + + // KMS Keys API calls use KeyID = support multiple formats aliases, UUID(ID) or arn format we would not need to normalize here but there seems to be some edge cases + // for Example Secrets Manager can returns just the key ID instead of full ARN in KmsKeyId for EventBride connection secrets + // Current code only provides arn to kms function + arnVal, err := NormalizeKmsKeyRef(a, region, conn.AccountId()) + if err != nil { + return nil, nil, err + } + + // Use normalized ARN for cache lookup and resource creation + normalizedArn := arnVal.String() + args["arn"] = llx.StringData(normalizedArn) + args["region"] = llx.StringData(arnVal.Region) + + if arnVal.AccountID != conn.AccountId() { + // Cross-account key: we can't fetch details, but we should still return the ARN + // so security tools can see which KMS key is referenced + log.Warn().Str("arn", normalizedArn).Str("currentAccount", conn.AccountId()).Str("keyAccount", arnVal.AccountID).Msg("cross-account KMS keys are not supported yet, returning ARN only") + // Extract key ID from the ARN resource part (e.g., "key/uuid" -> "uuid") + keyId := strings.TrimPrefix(arnVal.Resource, "key/") + args["id"] = llx.StringData(keyId) + return args, nil, nil } obj, err := CreateResource(runtime, ResourceAwsKms, map[string]*llx.RawData{}) @@ -316,11 +420,18 @@ func initAwsKmsKey(runtime *plugin.Runtime, args map[string]*llx.RawData) (map[s if rawResources.Error != nil { return nil, nil, rawResources.Error } + for _, rawResource := range rawResources.Data { key := rawResource.(*mqlAwsKmsKey) - if key.Arn.Data == a { + // Match by ARN or by key ID (for UUID-only inputs) + if key.Arn.Data == normalizedArn || key.Id.Data == a { + // Use actual values from cache + args["arn"] = llx.StringData(key.Arn.Data) + args["region"] = llx.StringData(key.Region.Data) + args["id"] = llx.StringData(key.Id.Data) return args, key, nil } } + return nil, nil, errors.New("key not found") } diff --git a/providers/aws/resources/aws_kms_test.go b/providers/aws/resources/aws_kms_test.go new file mode 100644 index 0000000000..60fd1b6b37 --- /dev/null +++ b/providers/aws/resources/aws_kms_test.go @@ -0,0 +1,71 @@ +// Copyright (c) Mondoo, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package resources + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNormalizeKmsKeyRef(t *testing.T) { + tests := []struct { + name string + input string + region string + accountId string + wantARN string + wantErr string + }{ + { + name: "full ARN is returned as-is", + input: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2", + region: "us-west-2", + accountId: "999999999999", + wantARN: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2", + }, + { + name: "bare UUID is normalized to ARN", + input: "7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2", + region: "us-east-1", + accountId: "123456789012", + wantARN: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2", + }, + { + name: "bare UUID with empty region returns error", + input: "7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2", + region: "", + accountId: "123456789012", + wantErr: "cannot normalize KMS key UUID", + }, + { + name: "invalid input returns error", + input: "not-a-valid-key-ref", + region: "us-east-1", + accountId: "123456789012", + wantErr: "invalid KMS key reference", + }, + { + name: "alias ARN is returned as-is", + input: "arn:aws:kms:us-east-1:123456789012:alias/my-key", + region: "us-east-1", + accountId: "123456789012", + wantARN: "arn:aws:kms:us-east-1:123456789012:alias/my-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NormalizeKmsKeyRef(tt.input, tt.region, tt.accountId) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantARN, got.String()) + }) + } +}