@@ -23,15 +23,23 @@ import (
2323 "github.com/aws/aws-sdk-go-v2/config"
2424 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
2525 "github.com/aws/aws-sdk-go-v2/service/kms"
26+ "github.com/aws/aws-sdk-go-v2/aws/arn"
27+ smithymiddleware "github.com/aws/smithy-go/middleware"
28+ smithyhttp "github.com/aws/smithy-go/transport/http"
2629 "go.uber.org/zap"
2730)
2831
32+ const (
33+ headerSourceArn = "x-amz-source-arn"
34+ headerSourceAccount = "x-amz-source-account"
35+ )
36+
2937type AWSKMSv2 interface {
3038 Encrypt (ctx context.Context , params * kms.EncryptInput , optFns ... func (* kms.Options )) (* kms.EncryptOutput , error )
3139 Decrypt (ctx context.Context , params * kms.DecryptInput , optFns ... func (* kms.Options )) (* kms.DecryptOutput , error )
3240}
3341
34- func New (region , kmsEndpoint string , qps , burst , retryTokenCapacity int ) (AWSKMSv2 , error ) {
42+ func New (region , kmsEndpoint string , qps , burst , retryTokenCapacity int , sourceArn string ) (AWSKMSv2 , error ) {
3543 var optFns []func (* config.LoadOptions ) error
3644 if region != "" {
3745 optFns = append (optFns , config .WithRegion (region ))
@@ -69,6 +77,11 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
6977 return nil , fmt .Errorf ("failed to create AWS config: %w" , err )
7078 }
7179
80+ err = addConfusedDeputyHeaders (& cfg , sourceArn )
81+ if err != nil {
82+ return nil , err
83+ }
84+
7285 if cfg .Region == "" {
7386 ec2 := imds .NewFromConfig (cfg )
7487 region , err := ec2 .GetRegion (context .Background (), & imds.GetRegionInput {})
@@ -88,3 +101,48 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
88101 client := kms .NewFromConfig (cfg , kmsOptFns ... )
89102 return client , nil
90103}
104+
105+ func addConfusedDeputyHeaders (cfg * aws.Config , sourceArn string ) error {
106+ if sourceArn != "" {
107+ sourceAccount , err := getSourceAccount (sourceArn )
108+ if err != nil {
109+ return err
110+ }
111+
112+ cfg .APIOptions = append (cfg .APIOptions , func (stack * smithymiddleware.Stack ) error {
113+ return stack .Build .Add (smithymiddleware .BuildMiddlewareFunc ("KMSConfusedDeputyHeaders" , func (
114+ ctx context.Context , in smithymiddleware.BuildInput , next smithymiddleware.BuildHandler ,
115+ ) (smithymiddleware.BuildOutput , smithymiddleware.Metadata , error ) {
116+ req , ok := in .Request .(* smithyhttp.Request )
117+ if ok {
118+ req .Header .Set (headerSourceAccount , sourceAccount )
119+ req .Header .Set (headerSourceArn , sourceArn )
120+ }
121+ return next .HandleBuild (ctx , in )
122+ }), smithymiddleware .Before )
123+ })
124+
125+ zap .L ().Info ("configuring KMS client with confused deputy headers" ,
126+ zap .String ("sourceArn" , sourceArn ),
127+ zap .String ("sourceAccount" , sourceAccount ),
128+ )
129+ }
130+ return nil
131+ }
132+
133+ // getSourceAccount constructs source account and return them for use
134+ func getSourceAccount (sourceArn string ) (string , error ) {
135+ // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
136+ // arn:partition:service:region:account-id:resource-type/resource-id
137+ // arn:aws:eks:region:account:cluster/cluster-name
138+ if ! arn .IsARN (sourceArn ) {
139+ return "" , fmt .Errorf ("incorrect ARN format for source arn: %s" , sourceArn )
140+ }
141+
142+ parsedArn , err := arn .Parse (sourceArn )
143+ if err != nil {
144+ return "" , err
145+ }
146+
147+ return parsedArn .AccountID , nil
148+ }
0 commit comments