Skip to content

Commit 457696c

Browse files
authored
PutSecureParameter and recursive GetAllParametersByPath (#27)
1 parent 40470bb commit 457696c

File tree

2 files changed

+211
-9
lines changed

2 files changed

+211
-9
lines changed

parameter_store_client.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package awsssm
22

33
import (
44
"errors"
5+
56
"github.com/aws/aws-sdk-go/aws"
67
"github.com/aws/aws-sdk-go/aws/awserr"
78
"github.com/aws/aws-sdk-go/aws/session"
@@ -18,6 +19,7 @@ var (
1819
type ssmClient interface {
1920
GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error
2021
GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error)
22+
PutParameter(input *ssm.PutParameterInput) (*ssm.PutParameterOutput, error)
2123
}
2224

2325
//ParameterStore holds all the methods tha are supported against AWS Parameter Store
@@ -30,6 +32,8 @@ type ParameterStore struct {
3032
//Will return /my-service/dev/param-a, /my-service/dev/param-b, etc... but will not return recursive paths
3133
//the `ssm:GetAllParametersByPath` permission is required
3234
//to the `arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/*`
35+
//
36+
//This will also page through and return all elements in the hierarchy, non-recursively
3337
func (ps *ParameterStore) GetAllParametersByPath(path string, decrypt bool) (*Parameters, error) {
3438
var input = &ssm.GetParametersByPathInput{}
3539
input.SetWithDecryption(decrypt)
@@ -81,6 +85,55 @@ func (ps *ParameterStore) getParameter(input *ssm.GetParameterInput) (*Parameter
8185
}, nil
8286
}
8387

88+
//PutSecureParameter is setting the parameter with the given name to a passed in value.
89+
//Allow overwriting the value of the parameter already exists, otherwise an error is returned
90+
//For example a request with name as '/my-service/dev/param-1':
91+
//Will set the parameter value if exists or ErrParameterInvalidName if parameter already exists or is empty
92+
// and `overwrite` is false. The `ssm:PutParameter` permission is required to the
93+
//`arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/param-1` resource
94+
func (ps *ParameterStore) PutSecureParameter(name, value string, overwrite bool) error {
95+
return ps.putSecureParameterWrapper(name, value, "", overwrite)
96+
}
97+
98+
//PutSecureParameterWithCMK is the same as PutSecureParameter but with a passed in CMK (Customer Master Key)
99+
//For example a request with name as '/my-service/dev/param-1' and a `kmsID` of 'foo':
100+
//Will set the parameter value if exists or ErrParameterInvalidName if parameter already exists or is empty
101+
// and `overwrite` is false. The `ssm:PutParameter` permission is required to the
102+
//`arn:aws:ssm:aws-region:aws-account-id:/my-service/dev/param-1` resource
103+
// The `kms:Encrypt` permission is required to the `arn:aws:kms:us-east-1:710015040892:key/foo`
104+
func (ps *ParameterStore) PutSecureParameterWithCMK(name, value string, overwrite bool, kmsID string) error {
105+
return ps.putSecureParameterWrapper(name, value, kmsID, overwrite)
106+
}
107+
func (ps *ParameterStore) putSecureParameterWrapper(name, value, kmsID string, overwrite bool) error {
108+
if name == "" {
109+
return ErrParameterInvalidName
110+
}
111+
input := &ssm.PutParameterInput{}
112+
input.SetName(name)
113+
input.SetType("SecureString")
114+
input.SetValue(value)
115+
if kmsID != "" {
116+
input.SetKeyId(kmsID)
117+
}
118+
input.SetOverwrite(overwrite)
119+
120+
if err := input.Validate(); err != nil {
121+
return err
122+
}
123+
124+
return ps.putParameter(input)
125+
}
126+
func (ps *ParameterStore) putParameter(input *ssm.PutParameterInput) error {
127+
_, err := ps.ssm.PutParameter(input)
128+
if err != nil {
129+
if awsError, ok := err.(awserr.Error); ok && awsError.Code() == ssm.ErrCodeParameterAlreadyExists {
130+
return ErrParameterInvalidName
131+
}
132+
return err
133+
}
134+
return nil
135+
}
136+
84137
//NewParameterStoreWithClient is creating a new ParameterStore with the given ssm Client
85138
func NewParameterStoreWithClient(client ssmClient) *ParameterStore {
86139
return &ParameterStore{ssm: client}

parameter_store_client_test.go

Lines changed: 158 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package awsssm
22

33
import (
44
"errors"
5-
"github.com/aws/aws-sdk-go/aws/awserr"
6-
"github.com/aws/aws-sdk-go/service/ssm"
75
"reflect"
86
"testing"
7+
8+
"github.com/aws/aws-sdk-go/aws/awserr"
9+
"github.com/aws/aws-sdk-go/service/ssm"
910
)
1011

1112
var param1 = new(ssm.Parameter).
@@ -18,6 +19,7 @@ var param2 = new(ssm.Parameter).
1819
SetValue("rds.something.aws.com").
1920
SetARN("arn:aws:ssm:us-east-2:aws-account-id:/my-service/dev/DB_HOST")
2021

22+
// return s.GetParametersByPathOutput, s.GetParametersByPathError
2123
var param3 = new(ssm.Parameter).
2224
SetName("/my-service/dev/DB_USERNAME").
2325
SetValue("username").
@@ -31,13 +33,14 @@ type stubGetParametersByPathOutput struct {
3133
}
3234

3335
type stubSSMClient struct {
34-
GetParametersByPathOutput []stubGetParametersByPathOutput
35-
GetParametersByPathError error
36-
GetParameterOutput *ssm.GetParameterOutput
37-
GetParameterError error
36+
GetParametersByPathOutput []stubGetParametersByPathOutput
37+
GetParametersByPathError error
38+
GetParameterOutput *ssm.GetParameterOutput
39+
GetParameterError error
40+
PutParameterInputReceived *ssm.PutParameterInput
3841
}
3942

40-
func (s stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error {
43+
func (s *stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error {
4144
if s.GetParametersByPathError == nil {
4245
for _, output := range s.GetParametersByPathOutput {
4346
done := fn(&output.Output, output.MoreParamsLeft)
@@ -49,10 +52,17 @@ func (s stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathIn
4952
return s.GetParametersByPathError
5053
}
5154

52-
func (s stubSSMClient) GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) {
55+
func (s *stubSSMClient) GetParameter(input *ssm.GetParameterInput) (*ssm.GetParameterOutput, error) {
5356
return s.GetParameterOutput, s.GetParameterError
5457
}
5558

59+
// we return nothing becuase the actual response is pretty boring. Just a version number. We DO
60+
// want to track was is input because there is a _little_ business logic around that
61+
func (s *stubSSMClient) PutParameter(input *ssm.PutParameterInput) (*ssm.PutParameterOutput, error) {
62+
s.PutParameterInputReceived = input
63+
return nil, nil
64+
}
65+
5666
func TestClient_GetParametersByPath(t *testing.T) {
5767
tests := []struct {
5868
name string
@@ -71,6 +81,12 @@ func TestClient_GetParametersByPath(t *testing.T) {
7181
Parameters: getParameters(),
7282
},
7383
},
84+
{
85+
MoreParamsLeft: true,
86+
Output: ssm.GetParametersByPathOutput{
87+
Parameters: getParameters2(),
88+
},
89+
},
7490
{
7591
MoreParamsLeft: false,
7692
Output: ssm.GetParametersByPathOutput{
@@ -110,7 +126,7 @@ func TestClient_GetParametersByPath(t *testing.T) {
110126
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
111127
}
112128
if !reflect.DeepEqual(parameters, test.expectedOutput) {
113-
t.Error(`Unexpected parameters`, *parameters, *test.expectedOutput)
129+
t.Errorf(`Unexpected parameters: got: %+v, expected: %+v`, *parameters, *test.expectedOutput)
114130
}
115131
})
116132
}
@@ -122,6 +138,12 @@ func getParameters() []*ssm.Parameter {
122138
}
123139
}
124140

141+
func getParameters2() []*ssm.Parameter {
142+
return []*ssm.Parameter{
143+
param3,
144+
}
145+
}
146+
125147
func TestParameterStore_GetParameter(t *testing.T) {
126148
value := "something-secure"
127149
tests := []struct {
@@ -172,3 +194,130 @@ func TestParameterStore_GetParameter(t *testing.T) {
172194
})
173195
}
174196
}
197+
198+
func TestParameterStore_PutSecureParameter(t *testing.T) {
199+
paramName := "foo"
200+
paramValue := "baz"
201+
paramType := "SecureString"
202+
overwriteTrue := true
203+
overwriteFalse := false
204+
205+
tests := []struct {
206+
name string
207+
ssmClient *stubSSMClient
208+
parameterName string
209+
parameterValue string
210+
overwrite bool
211+
expectedError error
212+
expectedInput *ssm.PutParameterInput
213+
}{
214+
{
215+
name: "Failed Empty name",
216+
ssmClient: &stubSSMClient{},
217+
parameterName: "",
218+
parameterValue: "",
219+
expectedError: ErrParameterInvalidName,
220+
},
221+
{
222+
name: "Set Correct Defaults",
223+
ssmClient: &stubSSMClient{},
224+
parameterName: paramName,
225+
parameterValue: paramValue,
226+
expectedInput: &ssm.PutParameterInput{
227+
Name: &paramName,
228+
Type: &paramType,
229+
Value: &paramValue,
230+
Overwrite: &overwriteFalse,
231+
},
232+
},
233+
{
234+
name: "Overwrite Changes Propagate",
235+
ssmClient: &stubSSMClient{},
236+
parameterName: paramName,
237+
parameterValue: paramValue,
238+
overwrite: overwriteTrue,
239+
expectedInput: &ssm.PutParameterInput{
240+
Name: &paramName,
241+
Type: &paramType,
242+
Value: &paramValue,
243+
Overwrite: &overwriteTrue,
244+
},
245+
},
246+
}
247+
for _, test := range tests {
248+
t.Run(test.name, func(t *testing.T) {
249+
client := NewParameterStoreWithClient(test.ssmClient)
250+
err := client.PutSecureParameter(test.parameterName, test.parameterValue, test.overwrite)
251+
if err != test.expectedError {
252+
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
253+
}
254+
if !reflect.DeepEqual(test.ssmClient.PutParameterInputReceived, test.expectedInput) {
255+
t.Errorf(`Unexpected parameter: got %v, expected %v`, test.ssmClient.PutParameterInputReceived, test.expectedInput)
256+
}
257+
})
258+
}
259+
}
260+
261+
func TestParameterStore_PutSecureParameterWithCMK(t *testing.T) {
262+
paramName := "foo"
263+
paramValue := "baz"
264+
paramType := "SecureString"
265+
overwriteFalse := false
266+
kmsID := "super-secret-kms"
267+
tests := []struct {
268+
name string
269+
ssmClient *stubSSMClient
270+
parameterName string
271+
parameterValue string
272+
overwrite bool
273+
kmsID string
274+
expectedError error
275+
expectedInput *ssm.PutParameterInput
276+
}{
277+
{
278+
name: "Failed Empty name",
279+
ssmClient: &stubSSMClient{},
280+
parameterName: "",
281+
parameterValue: "",
282+
expectedError: ErrParameterInvalidName,
283+
},
284+
{
285+
name: "Set Correct Defaults",
286+
ssmClient: &stubSSMClient{},
287+
parameterName: paramName,
288+
parameterValue: paramValue,
289+
expectedInput: &ssm.PutParameterInput{
290+
Name: &paramName,
291+
Overwrite: &overwriteFalse,
292+
Type: &paramType,
293+
Value: &paramValue,
294+
},
295+
},
296+
{
297+
name: "KMS ID Changes Propagate",
298+
ssmClient: &stubSSMClient{},
299+
parameterName: paramName,
300+
parameterValue: paramValue,
301+
kmsID: kmsID,
302+
expectedInput: &ssm.PutParameterInput{
303+
KeyId: &kmsID,
304+
Name: &paramName,
305+
Overwrite: &overwriteFalse,
306+
Type: &paramType,
307+
Value: &paramValue,
308+
},
309+
},
310+
}
311+
for _, test := range tests {
312+
t.Run(test.name, func(t *testing.T) {
313+
client := NewParameterStoreWithClient(test.ssmClient)
314+
err := client.PutSecureParameterWithCMK(test.parameterName, test.parameterValue, test.overwrite, test.kmsID)
315+
if err != test.expectedError {
316+
t.Errorf(`Unexpected error: got %d, expected %d`, err, test.expectedError)
317+
}
318+
if !reflect.DeepEqual(test.ssmClient.PutParameterInputReceived, test.expectedInput) {
319+
t.Errorf(`Unexpected parameter: got %v, expected %v`, test.ssmClient.PutParameterInputReceived, test.expectedInput)
320+
}
321+
})
322+
}
323+
}

0 commit comments

Comments
 (0)