@@ -23,15 +23,23 @@ import (
2323 "github.com/aws/aws-sdk-go-v2/config"
2424 "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
2525 "github.com/aws/aws-sdk-go-v2/service/kms"
26+ "github.com/aws/aws-sdk-go-v2/aws/arn"
27+ smithymiddleware "github.com/aws/smithy-go/middleware"
28+ smithyhttp "github.com/aws/smithy-go/transport/http"
2629 "go.uber.org/zap"
2730)
2831
32+ const (
33+ headerSourceArn = "x-amz-source-arn"
34+ headerSourceAccount = "x-amz-source-account"
35+ )
36+
2937type AWSKMSv2 interface {
3038 Encrypt (ctx context.Context , params * kms.EncryptInput , optFns ... func (* kms.Options )) (* kms.EncryptOutput , error )
3139 Decrypt (ctx context.Context , params * kms.DecryptInput , optFns ... func (* kms.Options )) (* kms.DecryptOutput , error )
3240}
3341
34- func New (region , kmsEndpoint string , qps , burst , retryTokenCapacity int ) (AWSKMSv2 , error ) {
42+ func New (region , kmsEndpoint string , qps , burst , retryTokenCapacity int , sourceArn string ) (AWSKMSv2 , error ) {
3543 var optFns []func (* config.LoadOptions ) error
3644 if region != "" {
3745 optFns = append (optFns , config .WithRegion (region ))
@@ -69,6 +77,8 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
6977 return nil , fmt .Errorf ("failed to create AWS config: %w" , err )
7078 }
7179
80+ addConfusedDeputyHeaders (& cfg , sourceArn )
81+
7282 if cfg .Region == "" {
7383 ec2 := imds .NewFromConfig (cfg )
7484 region , err := ec2 .GetRegion (context .Background (), & imds.GetRegionInput {})
@@ -88,3 +98,45 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
8898 client := kms .NewFromConfig (cfg , kmsOptFns ... )
8999 return client , nil
90100}
101+
102+ func addConfusedDeputyHeaders (cfg * aws.Config , sourceArn string ) {
103+ if sourceArn != "" {
104+ sourceAccount , err := getSourceAccount (sourceArn )
105+ if err != nil {
106+ panic (fmt .Sprintf ("%s is not a valid arn, err: %v" , sourceArn , err ))
107+ }
108+
109+ cfg .APIOptions = append (cfg .APIOptions , func (stack * smithymiddleware.Stack ) error {
110+ return stack .Build .Add (smithymiddleware .BuildMiddlewareFunc ("KMSConfusedDeputyHeaders" , func (
111+ ctx context.Context , in smithymiddleware.BuildInput , next smithymiddleware.BuildHandler ,
112+ ) (smithymiddleware.BuildOutput , smithymiddleware.Metadata , error ) {
113+ req , ok := in .Request .(* smithyhttp.Request )
114+ if ok {
115+ req .Header .Set (headerSourceAccount , sourceAccount )
116+ req .Header .Set (headerSourceArn , sourceArn )
117+ }
118+ return next .HandleBuild (ctx , in )
119+ }), smithymiddleware .Before )
120+ })
121+
122+ zap .L ().Info ("configuring KMS client with confused deputy headers" ,
123+ zap .String ("sourceArn" , sourceArn ), zap .String ("sourceAccount" , sourceAccount ))
124+ }
125+ }
126+
127+ // getSourceAccount constructs source account and return them for use
128+ func getSourceAccount (sourceArn string ) (string , error ) {
129+ // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
130+ // arn:partition:service:region:account-id:resource-type/resource-id
131+ // arn:aws:eks:region:account:cluster/cluster-name
132+ if ! arn .IsARN (sourceArn ) {
133+ return "" , fmt .Errorf ("incorrect ARN format for source arn: %s" , sourceArn )
134+ }
135+
136+ parsedArn , err := arn .Parse (sourceArn )
137+ if err != nil {
138+ return "" , err
139+ }
140+
141+ return parsedArn .AccountID , nil
142+ }
0 commit comments