Skip to content

Commit 0831a48

Browse files
committed
Add confused deputy protection for KMS calls
- Add --source-arn flag to accept EKS cluster ARN for confused deputy protection - Implement middleware to add x-amz-source-arn and x-amz-source-account headers to KMS API calls - Add getSourceAccount function to parse account ID from ARN - Update cloud.New function signature to accept sourceArn parameter - Add comprehensive tests for new functionality - Maintain backward compatibility when sourceArn is not provided Fixes: KMS confused deputy protection requirement for EKS encryption provider
1 parent bca4236 commit 0831a48

File tree

2 files changed

+38
-16
lines changed

2 files changed

+38
-16
lines changed

pkg/cloud/cloud.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int, sourceA
7777
return nil, fmt.Errorf("failed to create AWS config: %w", err)
7878
}
7979

80-
if sourceArn != "" {
81-
cfg.APIOptions = append(cfg.APIOptions, addConfusedDeputyHeaders(sourceArn))
82-
}
80+
addConfusedDeputyHeaders(&cfg, sourceArn)
8381

8482
if cfg.Region == "" {
8583
ec2 := imds.NewFromConfig(cfg)
@@ -101,28 +99,46 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int, sourceA
10199
return client, nil
102100
}
103101

104-
func addConfusedDeputyHeaders(sourceArn string) func(*smithymiddleware.Stack) error {
105-
return func(stack *smithymiddleware.Stack) error {
106-
return stack.Build.Add(smithymiddleware.BuildMiddlewareFunc("KMSConfusedDeputyHeaders", func(
107-
ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler,
108-
) (smithymiddleware.BuildOutput, smithymiddleware.Metadata, error) {
109-
req, ok := in.Request.(*smithyhttp.Request)
110-
if ok {
111-
sourceAccount, err := getSourceAccount(sourceArn)
112-
if err == nil {
102+
func addConfusedDeputyHeaders(cfg *aws.Config, sourceArn string) {
103+
if sourceArn != "" {
104+
sourceAccount, err := getSourceAccount(sourceArn)
105+
if err != nil {
106+
panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceArn, err))
107+
}
108+
109+
cfg.APIOptions = append(cfg.APIOptions, func(stack *smithymiddleware.Stack) error {
110+
return stack.Build.Add(smithymiddleware.BuildMiddlewareFunc("KMSConfusedDeputyHeaders", func(
111+
ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler,
112+
) (smithymiddleware.BuildOutput, smithymiddleware.Metadata, error) {
113+
req, ok := in.Request.(*smithyhttp.Request)
114+
if ok {
113115
req.Header.Set(headerSourceAccount, sourceAccount)
114116
req.Header.Set(headerSourceArn, sourceArn)
115117
}
116-
}
117-
return next.HandleBuild(ctx, in)
118-
}), smithymiddleware.Before)
118+
return next.HandleBuild(ctx, in)
119+
}), smithymiddleware.Before)
120+
})
119121
}
122+
123+
zap.L().Info("configuring KMS client with confused deputy headers",
124+
zap.String("sourceArn", sourceArn),
125+
zap.String("sourceAccount", sourceAccount)
126+
)
120127
}
121128

129+
// getSourceAccount constructs source account and return them for use
122130
func getSourceAccount(sourceArn string) (string, error) {
131+
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
132+
// arn:partition:service:region:account-id:resource-type/resource-id
133+
// arn:aws:eks:region:account:cluster/cluster-name
134+
if !arn.IsARN(roleARN) {
135+
return "", fmt.Errorf("incorrect ARN format for role %s", roleARN)
136+
}
137+
123138
parsedArn, err := arn.Parse(sourceArn)
124139
if err != nil {
125140
return "", err
126141
}
142+
127143
return parsedArn.AccountID, nil
128144
}

pkg/cloud/cloud_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,14 @@ func TestNewWithEmptySourceArn(t *testing.T) {
117117
assert.NotNil(t, client)
118118
}
119119

120+
func TestNewWithMalformedSourceArn(t *testing.T) {
121+
assert.Panics(t, func() {
122+
New("us-east-1", "", 0, 0, 0, "invalid-arn-format")
123+
})
124+
}
125+
120126
func TestGetSourceAccount(t *testing.T) {
121127
account, err := getSourceAccount("arn:aws:eks:us-east-1:123456789012:cluster/test")
122128
assert.NoError(t, err)
123129
assert.Equal(t, "123456789012", account)
124-
}
130+
}

0 commit comments

Comments
 (0)