Skip to content

Commit e841dc1

Browse files
committed
Add confused deputy protection for KMS calls
1 parent 12b25f5 commit e841dc1

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

cmd/server/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ func main() {
4848
burstLimit = flag.Int("burst-limit", 0, "(deprecated) number of tokens that can be consumed in a single call, use --retry-token-capacity instead")
4949
retryTokenCapacity = flag.Int("retry-token-capacity", 0, "number of tokens for client-side AWS rate-limiting on retries")
5050
encryptionCtxsArr = flag.StringArray("encryption-context", []string{}, "AWS KMS Encryption Context (e.g. 'a=b,c=d')")
51+
sourceArn = flag.String("source-arn", "", "AWS source ARN for confused deputy protection")
5152
debug = flag.Bool("debug", false, "Print debug level logs")
5253
)
5354
flag.Parse()
@@ -92,7 +93,7 @@ func main() {
9293
zap.Int("burst-limit", *burstLimit),
9394
zap.Int("retry-token-capacity", *retryTokenCapacity),
9495
)
95-
c, err := cloud.New(*region, *kmsEndpoint, *qpsLimit, *burstLimit, *retryTokenCapacity)
96+
c, err := cloud.New(*region, *kmsEndpoint, *qpsLimit, *burstLimit, *retryTokenCapacity, *sourceArn)
9697
if err != nil {
9798
zap.L().Fatal("Failed to create new KMS service", zap.Error(err))
9899
}

pkg/cloud/cloud.go

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,23 @@ import (
2323
"github.com/aws/aws-sdk-go-v2/config"
2424
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
2525
"github.com/aws/aws-sdk-go-v2/service/kms"
26+
"github.com/aws/aws-sdk-go-v2/aws/arn"
27+
smithymiddleware "github.com/aws/smithy-go/middleware"
28+
smithyhttp "github.com/aws/smithy-go/transport/http"
2629
"go.uber.org/zap"
2730
)
2831

32+
const (
33+
headerSourceArn = "x-amz-source-arn"
34+
headerSourceAccount = "x-amz-source-account"
35+
)
36+
2937
type AWSKMSv2 interface {
3038
Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error)
3139
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
3240
}
3341

34-
func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMSv2, error) {
42+
func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int, sourceArn string) (AWSKMSv2, error) {
3543
var optFns []func(*config.LoadOptions) error
3644
if region != "" {
3745
optFns = append(optFns, config.WithRegion(region))
@@ -69,6 +77,8 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
6977
return nil, fmt.Errorf("failed to create AWS config: %w", err)
7078
}
7179

80+
addConfusedDeputyHeaders(&cfg, sourceArn)
81+
7282
if cfg.Region == "" {
7383
ec2 := imds.NewFromConfig(cfg)
7484
region, err := ec2.GetRegion(context.Background(), &imds.GetRegionInput{})
@@ -88,3 +98,45 @@ func New(region, kmsEndpoint string, qps, burst, retryTokenCapacity int) (AWSKMS
8898
client := kms.NewFromConfig(cfg, kmsOptFns...)
8999
return client, nil
90100
}
101+
102+
func addConfusedDeputyHeaders(cfg *aws.Config, sourceArn string) {
103+
if sourceArn != "" {
104+
sourceAccount, err := getSourceAccount(sourceArn)
105+
if err != nil {
106+
panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceArn, err))
107+
}
108+
109+
cfg.APIOptions = append(cfg.APIOptions, func(stack *smithymiddleware.Stack) error {
110+
return stack.Build.Add(smithymiddleware.BuildMiddlewareFunc("KMSConfusedDeputyHeaders", func(
111+
ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler,
112+
) (smithymiddleware.BuildOutput, smithymiddleware.Metadata, error) {
113+
req, ok := in.Request.(*smithyhttp.Request)
114+
if ok {
115+
req.Header.Set(headerSourceAccount, sourceAccount)
116+
req.Header.Set(headerSourceArn, sourceArn)
117+
}
118+
return next.HandleBuild(ctx, in)
119+
}), smithymiddleware.Before)
120+
})
121+
122+
zap.L().Info("configuring KMS client with confused deputy headers",
123+
zap.String("sourceArn", sourceArn), zap.String("sourceAccount", sourceAccount))
124+
}
125+
}
126+
127+
// getSourceAccount constructs source account and return them for use
128+
func getSourceAccount(sourceArn string) (string, error) {
129+
// ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html)
130+
// arn:partition:service:region:account-id:resource-type/resource-id
131+
// arn:aws:eks:region:account:cluster/cluster-name
132+
if !arn.IsARN(sourceArn) {
133+
return "", fmt.Errorf("incorrect ARN format for source arn: %s", sourceArn)
134+
}
135+
136+
parsedArn, err := arn.Parse(sourceArn)
137+
if err != nil {
138+
return "", err
139+
}
140+
141+
return parsedArn.AccountID, nil
142+
}

pkg/cloud/cloud_test.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ zssmrkdYYvn9aUhjc3XK3tjAoDpsPpeBeTBamuUKDHoH/dNRXxerZ8vu6uPR3Pgs
2525
`)
2626

2727
func TestNewSessionClientWithoutEnv(t *testing.T) {
28-
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500)
28+
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500, "")
2929
assert.NoError(t, err, "Failed to create object with error (%v)", err)
3030
assert.NotNil(t, kmsObjet, "Failed to create object with error (%v)", err)
3131
}
@@ -36,7 +36,7 @@ func TestNewSessionClientWithEnv(t *testing.T) {
3636
defer os.Remove(tempFile) //nolint:errcheck
3737
os.Setenv("AWS_CA_BUNDLE", tempFile) //nolint:errcheck
3838
defer os.Unsetenv("AWS_CA_BUNDLE") //nolint:errcheck
39-
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500)
39+
kmsObjet, err := New("us-west-2", "https://kms.us-west-2.amazonaws.com", 0, 0, 500, "")
4040
assert.NoError(t, err, "Failed to create object with error (%v)", err)
4141
assert.NotNil(t, kmsObjet, "Failed to create object with error (%v)", err)
4242
}
@@ -95,7 +95,7 @@ func TestNewConfig(t *testing.T) {
9595

9696
for _, test := range tests {
9797
t.Run(test.name, func(t *testing.T) {
98-
_, err := New(test.region, test.endpoint, test.qps, test.burst, test.retryTokenCapacity)
98+
_, err := New(test.region, test.endpoint, test.qps, test.burst, test.retryTokenCapacity, "")
9999
if test.expectErr {
100100
assert.Error(t, err)
101101
} else {
@@ -104,3 +104,27 @@ func TestNewConfig(t *testing.T) {
104104
})
105105
}
106106
}
107+
108+
func TestNewWithSourceArn(t *testing.T) {
109+
client, err := New("us-east-1", "", 0, 0, 0, "arn:aws:eks:us-east-1:123456789012:cluster/test")
110+
assert.NoError(t, err)
111+
assert.NotNil(t, client)
112+
}
113+
114+
func TestNewWithEmptySourceArn(t *testing.T) {
115+
client, err := New("us-east-1", "", 0, 0, 0, "")
116+
assert.NoError(t, err)
117+
assert.NotNil(t, client)
118+
}
119+
120+
func TestNewWithMalformedSourceArn(t *testing.T) {
121+
assert.Panics(t, func() {
122+
New("us-east-1", "", 0, 0, 0, "invalid-arn-format")
123+
})
124+
}
125+
126+
func TestGetSourceAccount(t *testing.T) {
127+
account, err := getSourceAccount("arn:aws:eks:us-east-1:123456789012:cluster/test")
128+
assert.NoError(t, err)
129+
assert.Equal(t, "123456789012", account)
130+
}

0 commit comments

Comments
 (0)