Skip to content

Commit bca4236

Browse files
committed
Add confused deputy protection for KMS calls
- Add --source-arn flag to accept EKS cluster ARN for confused deputy protection - Implement middleware to add x-amz-source-arn and x-amz-source-account headers to KMS API calls - Add getSourceAccount function to parse account ID from ARN - Update cloud.New function signature to accept sourceArn parameter - Add comprehensive tests for new functionality - Maintain backward compatibility when sourceArn is not provided Fixes: KMS confused deputy protection requirement for EKS encryption provider
1 parent 12b25f5 commit bca4236

File tree

3 files changed

+62
-5
lines changed

3 files changed

+62
-5
lines changed

cmd/server/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func main() {
4848
burstLimit = flag.Int("burst-limit", 0, "(deprecated) number of tokens that can be consumed in a single call, use --retry-token-capacity instead")
4949
retryTokenCapacity = flag.Int("retry-token-capacity", 0, "number of tokens for client-side AWS rate-limiting on retries")
5050
encryptionCtxsArr = flag.StringArray("encryption-context", []string{}, "AWS KMS Encryption Context (e.g. 'a=b,c=d')")
51+
sourceArn = flag.String("source-arn", "", "AWS source ARN for confused deputy protection")
5152
debug = flag.Bool("debug", false, "Print debug level logs")
5253
)
5354
flag.Parse()
@@ -92,7 +93,7 @@ func main() {
9293
zap.Int("burst-limit", *burstLimit),
9394
zap.Int("retry-token-capacity", *retryTokenCapacity),
9495
)
95-
c, err := cloud.New(*region, *kmsEndpoint, *qpsLimit, *burstLimit, *retryTokenCapacity)
96+
c, err := cloud.New(*region, *kmsEndpoint, *qpsLimit, *burstLimit, *retryTokenCapacity, *sourceArn)
9697
if err != nil {
9798
zap.L().Fatal("Failed to create new KMS service", zap.Error(err))
9899
}

pkg/cloud/cloud.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,23 @@ import (
2323
"github.com/aws/aws-sdk-go-v2/config"
2424
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
2525
"github.com/aws/aws-sdk-go-v2/service/kms"
26+
"github.com/aws/aws-sdk-go-v2/aws/arn"
27+
smithymiddleware "github.com/aws/smithy-go/middleware"
28+
smithyhttp "github.com/aws/smithy-go/transport/http"
2629
"go.uber.org/zap"
2730
)
2831

32+
const (
33+
headerSourceArn = "x-amz-source-arn"
34+
headerSourceAccount = "x-amz-source-account"
35+
)
36+
2937
type AWSKMSv2 interface {
3038
Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error)
3139
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
3240
}
3341

34-
func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMSv2, error) {
42+
func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int, sourceArn string) (AWSKMSv2, error) {
3543
var optFns []func(*config.LoadOptions) error
3644
if region != "" {
3745
optFns = append(optFns, config.WithRegion(region))
@@ -69,6 +77,10 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
6977
return nil, fmt.Errorf("failed to create AWS config: %w", err)
7078
}
7179

80+
if sourceArn != "" {
81+
cfg.APIOptions = append(cfg.APIOptions, addConfusedDeputyHeaders(sourceArn))
82+
}
83+
7284
if cfg.Region == "" {
7385
ec2 := imds.NewFromConfig(cfg)
7486
region, err := ec2.GetRegion(context.Background(), &imds.GetRegionInput{})
@@ -88,3 +100,29 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
88100
client := kms.NewFromConfig(cfg, kmsOptFns...)
89101
return client, nil
90102
}
103+
104+
func addConfusedDeputyHeaders(sourceArn string) func(*smithymiddleware.Stack) error {
105+
return func(stack *smithymiddleware.Stack) error {
106+
return stack.Build.Add(smithymiddleware.BuildMiddlewareFunc("KMSConfusedDeputyHeaders", func(
107+
ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler,
108+
) (smithymiddleware.BuildOutput, smithymiddleware.Metadata, error) {
109+
req, ok := in.Request.(*smithyhttp.Request)
110+
if ok {
111+
sourceAccount, err := getSourceAccount(sourceArn)
112+
if err == nil {
113+
req.Header.Set(headerSourceAccount, sourceAccount)
114+
req.Header.Set(headerSourceArn, sourceArn)
115+
}
116+
}
117+
return next.HandleBuild(ctx, in)
118+
}), smithymiddleware.Before)
119+
}
120+
}
121+
122+
func getSourceAccount(sourceArn string) (string, error) {
123+
parsedArn, err := arn.Parse(sourceArn)
124+
if err != nil {
125+
return "", err
126+
}
127+
return parsedArn.AccountID, nil
128+
}

pkg/cloud/cloud_test.go

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ zssmrkdYYvn9aUhjc3XK3tjAoDpsPpeBeTBamuUKDHoH/dNRXxerZ8vu6uPR3Pgs
2525
`)
2626

2727
func TestNewSessionClientWithoutEnv(t *testing.T) {
28-
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500)
28+
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500, "")
2929
assert.NoError(t, err, "Failed to create object with error (%v)", err)
3030
assert.NotNil(t, kmsObjet, "Failed to create object with error (%v)", err)
3131
}
@@ -36,7 +36,7 @@ func TestNewSessionClientWithEnv(t *testing.T) {
3636
defer os.Remove(tempFile) //nolint:errcheck
3737
os.Setenv("AWS_CA_BUNDLE", tempFile) //nolint:errcheck
3838
defer os.Unsetenv("AWS_CA_BUNDLE") //nolint:errcheck
39-
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500)
39+
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500, "")
4040
assert.NoError(t, err, "Failed to create object with error (%v)", err)
4141
assert.NotNil(t, kmsObjet, "Failed to create object with error (%v)", err)
4242
}
@@ -95,7 +95,7 @@ func TestNewConfig(t *testing.T) {
9595

9696
for _, test := range tests {
9797
t.Run(test.name, func(t *testing.T) {
98-
_, err := New(test.region, test.endpoint, test.qps, test.burst, test.retryTokenCapacity)
98+
_, err := New(test.region, test.endpoint, test.qps, test.burst, test.retryTokenCapacity, "")
9999
if test.expectErr {
100100
assert.Error(t, err)
101101
} else {
@@ -104,3 +104,21 @@ func TestNewConfig(t *testing.T) {
104104
})
105105
}
106106
}
107+
108+
func TestNewWithSourceArn(t *testing.T) {
109+
client, err := New("us-east-1", "", 0, 0, 0, "arn:aws:eks:us-east-1:123456789012:cluster/test")
110+
assert.NoError(t, err)
111+
assert.NotNil(t, client)
112+
}
113+
114+
func TestNewWithEmptySourceArn(t *testing.T) {
115+
client, err := New("us-east-1", "", 0, 0, 0, "")
116+
assert.NoError(t, err)
117+
assert.NotNil(t, client)
118+
}
119+
120+
func TestGetSourceAccount(t *testing.T) {
121+
account, err := getSourceAccount("arn:aws:eks:us-east-1:123456789012:cluster/test")
122+
assert.NoError(t, err)
123+
assert.Equal(t, "123456789012", account)
124+
}

0 commit comments

Comments
 (0)