Skip to content
121 changes: 116 additions & 5 deletions providers/aws/resources/aws_kms.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"context"
"errors"
"fmt"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws/arn"
"github.com/google/uuid"

Check failure on line 15 in providers/aws/resources/aws_kms.go

View workflow job for this annotation

GitHub Actions / golangci-lint-providers / lint

File is not properly formatted (gofmt)
"github.com/aws/aws-sdk-go-v2/service/kms"
"github.com/aws/aws-sdk-go-v2/service/kms/types"
"github.com/rs/zerolog/log"
Expand All @@ -27,6 +29,37 @@
kmsKeyArnPattern = "arn:aws:kms:%s:%s:key/%s"
)

// NormalizeKmsKeyRef normalizes a KMS key reference to ARN format.
// KeyID supports multiple formats: Key ID (UUID), Key ARN, Alias Name, Alias ARN
func NormalizeKmsKeyRef(s, region, accountId string) (arn.ARN, error) {
// Try ARN parse first (common case)
parsed, arnErr := arn.Parse(s)
if arnErr == nil {
return parsed, nil
}

// Fallback: check if it's a key ID (UUID format)
// Example: 7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2
// This is an edge case where Secrets Manager returns just the key ID
if _, uuidErr := uuid.Parse(s); uuidErr == nil {
if region == "" {
return arn.ARN{}, fmt.Errorf("cannot normalize KMS key UUID %q without a region", s)
}
return arn.ARN{
Partition: "aws",
Service: "kms",
Region: region,
AccountID: accountId,
Resource: "key/" + s,
}, nil
}

// Todo add alias format handling here for Alias name and Alias ARN

// If both checks fail, propagate the ARN parse error for better diagnostics
return arn.ARN{}, fmt.Errorf("invalid KMS key reference %q: %w", s, arnErr)
}

func (a *mqlAwsKms) id() (string, error) {
return "aws.kms", nil
}
Expand Down Expand Up @@ -96,7 +129,23 @@
return tasks
}

// isCrossAccountKey returns true if this KMS key belongs to a different AWS account.
// Cross-account keys cannot be queried for details like metadata, tags, or aliases.
func (a *mqlAwsKmsKey) isCrossAccountKey() bool {
conn := a.MqlRuntime.Connection.(*connection.AwsConnection)
keyArn := a.Arn.Data
parsed, err := arn.Parse(keyArn)
if err != nil {
log.Warn().Err(err).Str("arn", keyArn).Msg("unable to parse KMS key ARN, treating as same-account")
return false
}
return parsed.AccountID != conn.AccountId()
}

func (a *mqlAwsKmsKey) metadata() (any, error) {
if a.isCrossAccountKey() {
return nil, nil
}
md, err := a.getKeyMetadata()
if err != nil {
return nil, err
Expand All @@ -105,6 +154,9 @@
}

func (a *mqlAwsKmsKey) keyRotationEnabled() (bool, error) {
if a.isCrossAccountKey() {
return false, nil
}
conn := a.MqlRuntime.Connection.(*connection.AwsConnection)
keyId := a.Id.Data

Expand All @@ -119,6 +171,9 @@
}

func (a *mqlAwsKmsKey) tags() (map[string]any, error) {
if a.isCrossAccountKey() {
return nil, nil
}
conn := a.MqlRuntime.Connection.(*connection.AwsConnection)
keyArn := a.Arn.Data

Expand All @@ -142,6 +197,9 @@
}

func (a *mqlAwsKmsKey) aliases() ([]any, error) {
if a.isCrossAccountKey() {
return []any{}, nil
}
conn := a.MqlRuntime.Connection.(*connection.AwsConnection)
keyArn := a.Arn.Data

Expand All @@ -164,6 +222,9 @@
}

func (a *mqlAwsKmsKey) keyState() (string, error) {
if a.isCrossAccountKey() {
return "", nil
}
md, err := a.getKeyMetadata()
if err != nil {
return "", err
Expand Down Expand Up @@ -201,6 +262,9 @@
}

func (a *mqlAwsKmsKey) createdAt() (*time.Time, error) {
if a.isCrossAccountKey() {
return nil, nil
}
md, err := a.getKeyMetadata()
if err != nil {
return nil, err
Expand All @@ -209,6 +273,9 @@
}

func (a *mqlAwsKmsKey) deletedAt() (*time.Time, error) {
if a.isCrossAccountKey() {
return nil, nil
}
md, err := a.getKeyMetadata()
if err != nil {
return nil, err
Expand All @@ -217,6 +284,9 @@
}

func (a *mqlAwsKmsKey) enabled() (bool, error) {
if a.isCrossAccountKey() {
return false, nil
}
md, err := a.getKeyMetadata()
if err != nil {
return false, err
Expand All @@ -225,6 +295,9 @@
}

func (a *mqlAwsKmsKey) description() (string, error) {
if a.isCrossAccountKey() {
return "", nil
}
md, err := a.getKeyMetadata()
if err != nil {
return "", err
Expand All @@ -237,6 +310,9 @@
}

func (a *mqlAwsKmsKey) grants() ([]any, error) {
if a.isCrossAccountKey() {
return []any{}, nil
}
conn := a.MqlRuntime.Connection.(*connection.AwsConnection)
keyArn := a.Arn.Data

Expand Down Expand Up @@ -300,10 +376,38 @@
if !ok {
return nil, nil, errors.New("invalid arn")
}
arnVal, err := arn.Parse(a)
if arnVal.AccountID != runtime.Connection.(*connection.AwsConnection).AccountId() {
// sometimes there are references to keys in other accounts, like the master account of an org
return nil, nil, fmt.Errorf("cannot access key from different AWS account %q", arnVal.AccountID)

conn := runtime.Connection.(*connection.AwsConnection)

// Get region from args if provided (needed when input is just a key ID, not full ARN)
var region string
if regionArg := args["region"]; regionArg != nil {
if r, ok := regionArg.Value.(string); ok {
region = r
}
}

// KMS Keys API calls use KeyID = support multiple formats aliases, UUID(ID) or arn format we would not need to normalize here but there seems to be some edge cases
// for Example Secrets Manager can returns just the key ID instead of full ARN in KmsKeyId for EventBride connection secrets
// Current code only provides arn to kms function
arnVal, err := NormalizeKmsKeyRef(a, region, conn.AccountId())
if err != nil {
return nil, nil, err
}

// Use normalized ARN for cache lookup and resource creation
normalizedArn := arnVal.String()
args["arn"] = llx.StringData(normalizedArn)
args["region"] = llx.StringData(arnVal.Region)

if arnVal.AccountID != conn.AccountId() {
// Cross-account key: we can't fetch details, but we should still return the ARN
// so security tools can see which KMS key is referenced
log.Warn().Str("arn", normalizedArn).Str("currentAccount", conn.AccountId()).Str("keyAccount", arnVal.AccountID).Msg("cross-account KMS keys are not supported yet, returning ARN only")
// Extract key ID from the ARN resource part (e.g., "key/uuid" -> "uuid")
keyId := strings.TrimPrefix(arnVal.Resource, "key/")
args["id"] = llx.StringData(keyId)
return args, nil, nil
}

obj, err := CreateResource(runtime, ResourceAwsKms, map[string]*llx.RawData{})
Expand All @@ -316,11 +420,18 @@
if rawResources.Error != nil {
return nil, nil, rawResources.Error
}

for _, rawResource := range rawResources.Data {
key := rawResource.(*mqlAwsKmsKey)
if key.Arn.Data == a {
// Match by ARN or by key ID (for UUID-only inputs)
if key.Arn.Data == normalizedArn || key.Id.Data == a {
// Use actual values from cache
args["arn"] = llx.StringData(key.Arn.Data)
args["region"] = llx.StringData(key.Region.Data)
args["id"] = llx.StringData(key.Id.Data)
return args, key, nil
}
}

return nil, nil, errors.New("key not found")
}
71 changes: 71 additions & 0 deletions providers/aws/resources/aws_kms_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) Mondoo, Inc.
// SPDX-License-Identifier: BUSL-1.1

package resources

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestNormalizeKmsKeyRef(t *testing.T) {
tests := []struct {
name string
input string
region string
accountId string
wantARN string
wantErr string
}{
{
name: "full ARN is returned as-is",
input: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2",
region: "us-west-2",
accountId: "999999999999",
wantARN: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2",
},
{
name: "bare UUID is normalized to ARN",
input: "7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2",
region: "us-east-1",
accountId: "123456789012",
wantARN: "arn:aws:kms:us-east-1:123456789012:key/7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2",
},
{
name: "bare UUID with empty region returns error",
input: "7a4eb143-c07b-4e24-b0b7-f3abfdbbb2c2",
region: "",
accountId: "123456789012",
wantErr: "cannot normalize KMS key UUID",
},
{
name: "invalid input returns error",
input: "not-a-valid-key-ref",
region: "us-east-1",
accountId: "123456789012",
wantErr: "invalid KMS key reference",
},
{
name: "alias ARN is returned as-is",
input: "arn:aws:kms:us-east-1:123456789012:alias/my-key",
region: "us-east-1",
accountId: "123456789012",
wantARN: "arn:aws:kms:us-east-1:123456789012:alias/my-key",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeKmsKeyRef(tt.input, tt.region, tt.accountId)
if tt.wantErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantARN, got.String())
})
}
}
Loading