Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 29 additions & 34 deletions wrappers/awskms/awskms.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ import (
"os"
"sync/atomic"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/kms"
"github.com/aws/aws-sdk-go/service/kms/kmsiface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kms"
cleanhttp "github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/go-secure-stdlib/awsutil/v2"
wrapping "github.com/openbao/go-kms-wrapping/v2"
)

Expand All @@ -34,6 +32,11 @@ const (
AwsKmsEnvelopeAesGcmEncrypt
)

type KMSAPI interface {
Decrypt(ctx context.Context, params *kms.DecryptInput, optFns ...func(*kms.Options)) (*kms.DecryptOutput, error)
Encrypt(ctx context.Context, params *kms.EncryptInput, optFns ...func(*kms.Options)) (*kms.EncryptOutput, error)
}

// Wrapper represents credentials and Key information for the KMS Key used to
// encryption and decryption
type Wrapper struct {
Expand All @@ -52,7 +55,7 @@ type Wrapper struct {

currentKeyId *atomic.Value

client kmsiface.KMSAPI
client KMSAPI

logger hclog.Logger
}
Expand All @@ -77,7 +80,7 @@ func NewWrapper() *Wrapper {
// * Passed in config map
// * Instance metadata role (access key and secret key)
// * Default values
func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) {
func (k *Wrapper) SetConfig(ctx context.Context, opt ...wrapping.Option) (*wrapping.WrapperConfig, error) {
opts, err := getOpts(opt...)
if err != nil {
return nil, err
Expand All @@ -103,7 +106,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin
k.currentKeyId.Store(k.keyId)

// Please see GetRegion for an explanation of the order in which region is parsed.
k.region, err = awsutil.GetRegion(opts.withRegion)
k.region, err = awsutil.GetRegion(ctx, opts.withRegion)
if err != nil {
return nil, err
}
Expand All @@ -127,14 +130,14 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin

// Check and set k.client
if k.client == nil {
client, err := k.GetAwsKmsClient()
client, err := k.GetAwsKmsClient(ctx)
if err != nil {
return nil, fmt.Errorf("error initializing AWS KMS wrapping client: %w", err)
}

if !k.keyNotRequired {
// Test the client connection using provided key ID
keyInfo, err := client.DescribeKey(&kms.DescribeKeyInput{
keyInfo, err := client.DescribeKey(ctx, &kms.DescribeKeyInput{
KeyId: aws.String(k.keyId),
})
if err != nil {
Expand All @@ -143,7 +146,7 @@ func (k *Wrapper) SetConfig(_ context.Context, opt ...wrapping.Option) (*wrappin
if keyInfo == nil || keyInfo.KeyMetadata == nil || keyInfo.KeyMetadata.KeyId == nil {
return nil, errors.New("no key information returned")
}
k.currentKeyId.Store(aws.StringValue(keyInfo.KeyMetadata.KeyId))
k.currentKeyId.Store(*keyInfo.KeyMetadata.KeyId)
}

k.client = client
Expand Down Expand Up @@ -174,7 +177,7 @@ func (k *Wrapper) KeyId(_ context.Context) (string, error) {
// Encrypt is used to encrypt the master key using the the AWS CMK.
// This returns the ciphertext, and/or any errors from this
// call. This should be called after the KMS client has been instantiated.
func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) {
func (k *Wrapper) Encrypt(ctx context.Context, plaintext []byte, opt ...wrapping.Option) (*wrapping.BlobInfo, error) {
if plaintext == nil {
return nil, fmt.Errorf("given plaintext for encryption is nil")
}
Expand All @@ -192,7 +195,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
KeyId: aws.String(k.keyId),
Plaintext: env.Key,
}
output, err := k.client.Encrypt(input)
output, err := k.client.Encrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error encrypting data: %w", err)
}
Expand All @@ -203,8 +206,11 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
// used for encryption. This is helpful if you are looking to reencyrpt
// your data when it is not using the latest key id. See these docs relating
// to key rotation https://docs.aws.amazon.com/kms/latest/developerguide/rotate-keys.html
keyId := aws.StringValue(output.KeyId)
k.currentKeyId.Store(keyId)
var keyId string
if output.KeyId != nil {
keyId = *output.KeyId
k.currentKeyId.Store(keyId)
}

ret := &wrapping.BlobInfo{
Ciphertext: env.Ciphertext,
Expand All @@ -223,7 +229,7 @@ func (k *Wrapper) Encrypt(_ context.Context, plaintext []byte, opt ...wrapping.O
}

// Decrypt is used to decrypt the ciphertext. This should be called after Init.
func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) {
func (k *Wrapper) Decrypt(ctx context.Context, in *wrapping.BlobInfo, opt ...wrapping.Option) ([]byte, error) {
if in == nil {
return nil, fmt.Errorf("given input for decryption is nil")
}
Expand All @@ -242,7 +248,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
CiphertextBlob: in.Ciphertext,
}

output, err := k.client.Decrypt(input)
output, err := k.client.Decrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error decrypting data: %w", err)
}
Expand All @@ -254,7 +260,7 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
input := &kms.DecryptInput{
CiphertextBlob: in.KeyInfo.WrappedKey,
}
output, err := k.client.Decrypt(input)
output, err := k.client.Decrypt(ctx, input)
if err != nil {
return nil, fmt.Errorf("error decrypting data encryption key: %w", err)
}
Expand All @@ -277,12 +283,12 @@ func (k *Wrapper) Decrypt(_ context.Context, in *wrapping.BlobInfo, opt ...wrapp
}

// Client returns the AWS KMS client used by the wrapper.
func (k *Wrapper) Client() kmsiface.KMSAPI {
func (k *Wrapper) Client() KMSAPI {
return k.client
}

// GetAwsKmsClient returns an instance of the KMS client.
func (k *Wrapper) GetAwsKmsClient() (*kms.KMS, error) {
func (k *Wrapper) GetAwsKmsClient(ctx context.Context) (*kms.Client, error) {
credsConfig := &awsutil.CredentialsConfig{}

credsConfig.AccessKey = k.accessKey
Expand All @@ -298,27 +304,16 @@ func (k *Wrapper) GetAwsKmsClient() (*kms.KMS, error) {

credsConfig.HTTPClient = cleanhttp.DefaultClient()

creds, err := credsConfig.GenerateCredentialChain()
awsConfig, err := credsConfig.GenerateCredentialChain(ctx)
if err != nil {
return nil, err
}

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(credsConfig.Region),
HTTPClient: cleanhttp.DefaultClient(),
}

if k.endpoint != "" {
awsConfig.Endpoint = aws.String(k.endpoint)
}

sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
awsConfig.BaseEndpoint = aws.String(k.endpoint)
}

client := kms.New(sess)
client := kms.NewFromConfig(*awsConfig)

return client, nil
}
46 changes: 24 additions & 22 deletions wrappers/awskms/awskms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
package awskms

import (
"context"
"os"
"reflect"
"testing"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go-v2/aws"
wrapping "github.com/openbao/go-kms-wrapping/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -21,16 +20,18 @@ func TestAwsKmsWrapper(t *testing.T) {
keyId: aws.String(awsTestKeyId),
}

_, err := s.SetConfig(nil)
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)

os.Unsetenv(EnvAwsKmsWrapperKeyId)
_, err := s.SetConfig(t.Context(), WithRegion("dummy"))
if err == nil {
t.Fatal("expected error when AwsKms wrapping key ID is not provided")
}

// Set the key
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId)
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)
_, err = s.SetConfig(nil)
_, err = s.SetConfig(t.Context(), WithRegion("dummy"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -54,7 +55,7 @@ func TestAwsKmsWrapper_IgnoreEnv(t *testing.T) {
"endpoint": "my-endpoint",
}

_, err := wrapper.SetConfig(context.Background(), wrapping.WithConfigMap(config))
_, err := wrapper.SetConfig(t.Context(), wrapping.WithConfigMap(config), WithRegion("dummy"))
assert.NoError(t, err)

require.Equal(t, config["access_key"], wrapper.accessKey)
Expand All @@ -64,17 +65,14 @@ func TestAwsKmsWrapper_IgnoreEnv(t *testing.T) {
}

func TestAwsKmsWrapper_Lifecycle(t *testing.T) {
if os.Getenv(EnvAwsKmsWrapperKeyId) == "" && os.Getenv(EnvVaultAwsKmsSealKeyId) == "" {
t.SkipNow()
}
s := NewWrapper()
s.client = &mockClient{
keyId: aws.String(awsTestKeyId),
}
oldKeyId := os.Getenv(EnvAwsKmsWrapperKeyId)
os.Setenv(EnvAwsKmsWrapperKeyId, awsTestKeyId)
defer os.Setenv(EnvAwsKmsWrapperKeyId, oldKeyId)
testEncryptionRoundTrip(t, s)
testEncryptionRoundTrip(t, s, WithRegion("dummy"))
}

// This test executes real calls. The calls themselves should be free,
Expand All @@ -94,15 +92,19 @@ func TestAccAwsKmsWrapper_Lifecycle(t *testing.T) {
testEncryptionRoundTrip(t, s)
}

func testEncryptionRoundTrip(t *testing.T, w *Wrapper) {
w.SetConfig(context.Background())
func testEncryptionRoundTrip(t *testing.T, w *Wrapper, opt ...wrapping.Option) {
_, err := w.SetConfig(t.Context(), opt...)
if err != nil {
t.Fatalf("err: %s", err.Error())
}

input := []byte("foo")
swi, err := w.Encrypt(context.Background(), input, nil)
swi, err := w.Encrypt(t.Context(), input, nil)
if err != nil {
t.Fatalf("err: %s", err.Error())
}

pt, err := w.Decrypt(context.Background(), swi, nil)
pt, err := w.Decrypt(t.Context(), swi, nil)
if err != nil {
t.Fatalf("err: %s", err.Error())
}
Expand Down Expand Up @@ -178,27 +180,27 @@ func TestAwsKmsWrapper_custom_endpoint(t *testing.T) {
if tc.Config != nil {
cfg = tc.Config
}
if _, err := s.SetConfig(context.Background(), wrapping.WithConfigMap(cfg)); err != nil {
if _, err := s.SetConfig(t.Context(), wrapping.WithConfigMap(cfg), WithRegion("dummy")); err != nil {
t.Fatalf("error setting config: %s", err)
}

// call GetAwsKmsClient() to get the configured client and verify it's
// endpoint
k, err := s.GetAwsKmsClient()
k, err := s.GetAwsKmsClient(t.Context())
if err != nil {
t.Fatal(err)
}

if tc.Expected == nil && k.Config.Endpoint != nil {
t.Fatalf("Expected nil endpoint, got: (%s)", *k.Config.Endpoint)
if tc.Expected == nil && k.Options().BaseEndpoint != nil {
t.Fatalf("Expected nil endpoint, got: (%s)", *k.Options().BaseEndpoint)
}

if tc.Expected != nil {
if k.Config.Endpoint == nil {
if k.Options().BaseEndpoint == nil {
t.Fatal("expected custom endpoint, but config was nil")
}
if *k.Config.Endpoint != *tc.Expected {
t.Fatalf("expected custom endpoint (%s), got: (%s)", *tc.Expected, *k.Config.Endpoint)
if *k.Options().BaseEndpoint != *tc.Expected {
t.Fatalf("expected custom endpoint (%s), got: (%s)", *tc.Expected, *k.Options().BaseEndpoint)
}
}

Expand Down
21 changes: 17 additions & 4 deletions wrappers/awskms/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,29 @@ go 1.24.0
replace github.com/openbao/go-kms-wrapping/v2 => ../../

require (
github.com/aws/aws-sdk-go v1.55.5
github.com/aws/aws-sdk-go-v2 v1.41.1
github.com/aws/aws-sdk-go-v2/service/kms v1.49.5
github.com/hashicorp/go-cleanhttp v0.5.2
github.com/hashicorp/go-hclog v1.6.3
github.com/hashicorp/go-secure-stdlib/awsutil v0.3.0
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.1.2
github.com/openbao/go-kms-wrapping/v2 v2.2.0
github.com/stretchr/testify v1.10.0
)

require (
github.com/aws/aws-sdk-go-v2/config v1.28.5 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.46 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.20 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect
github.com/aws/aws-sdk-go-v2/service/iam v1.38.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.24.6 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.1 // indirect
github.com/aws/smithy-go v1.24.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
Expand All @@ -23,11 +37,10 @@ require (
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.6 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
golang.org/x/sys v0.29.0 // indirect
Expand Down
Loading
Loading