@@ -2,10 +2,11 @@ package awsssm
22
33import (
44 "errors"
5- "github.com/aws/aws-sdk-go/aws/awserr"
6- "github.com/aws/aws-sdk-go/service/ssm"
75 "reflect"
86 "testing"
7+
8+ "github.com/aws/aws-sdk-go/aws/awserr"
9+ "github.com/aws/aws-sdk-go/service/ssm"
910)
1011
1112var param1 = new (ssm.Parameter ).
@@ -18,6 +19,7 @@ var param2 = new(ssm.Parameter).
1819 SetValue ("rds.something.aws.com" ).
1920 SetARN ("arn:aws:ssm:us-east-2:aws-account-id:/my-service/dev/DB_HOST" )
2021
22+ // return s.GetParametersByPathOutput, s.GetParametersByPathError
2123var param3 = new (ssm.Parameter ).
2224 SetName ("/my-service/dev/DB_USERNAME" ).
2325 SetValue ("username" ).
@@ -31,13 +33,14 @@ type stubGetParametersByPathOutput struct {
3133}
3234
3335type stubSSMClient struct {
34- GetParametersByPathOutput []stubGetParametersByPathOutput
35- GetParametersByPathError error
36- GetParameterOutput * ssm.GetParameterOutput
37- GetParameterError error
36+ GetParametersByPathOutput []stubGetParametersByPathOutput
37+ GetParametersByPathError error
38+ GetParameterOutput * ssm.GetParameterOutput
39+ GetParameterError error
40+ PutParameterInputReceived * ssm.PutParameterInput
3841}
3942
40- func (s stubSSMClient ) GetParametersByPathPages (input * ssm.GetParametersByPathInput , fn func (* ssm.GetParametersByPathOutput , bool ) bool ) error {
43+ func (s * stubSSMClient ) GetParametersByPathPages (input * ssm.GetParametersByPathInput , fn func (* ssm.GetParametersByPathOutput , bool ) bool ) error {
4144 if s .GetParametersByPathError == nil {
4245 for _ , output := range s .GetParametersByPathOutput {
4346 done := fn (& output .Output , output .MoreParamsLeft )
@@ -49,10 +52,17 @@ func (s stubSSMClient) GetParametersByPathPages(input *ssm.GetParametersByPathIn
4952 return s .GetParametersByPathError
5053}
5154
52- func (s stubSSMClient ) GetParameter (input * ssm.GetParameterInput ) (* ssm.GetParameterOutput , error ) {
55+ func (s * stubSSMClient ) GetParameter (input * ssm.GetParameterInput ) (* ssm.GetParameterOutput , error ) {
5356 return s .GetParameterOutput , s .GetParameterError
5457}
5558
59+ // we return nothing becuase the actual response is pretty boring. Just a version number. We DO
60+ // want to track was is input because there is a _little_ business logic around that
61+ func (s * stubSSMClient ) PutParameter (input * ssm.PutParameterInput ) (* ssm.PutParameterOutput , error ) {
62+ s .PutParameterInputReceived = input
63+ return nil , nil
64+ }
65+
5666func TestClient_GetParametersByPath (t * testing.T ) {
5767 tests := []struct {
5868 name string
@@ -71,6 +81,12 @@ func TestClient_GetParametersByPath(t *testing.T) {
7181 Parameters : getParameters (),
7282 },
7383 },
84+ {
85+ MoreParamsLeft : true ,
86+ Output : ssm.GetParametersByPathOutput {
87+ Parameters : getParameters2 (),
88+ },
89+ },
7490 {
7591 MoreParamsLeft : false ,
7692 Output : ssm.GetParametersByPathOutput {
@@ -110,7 +126,7 @@ func TestClient_GetParametersByPath(t *testing.T) {
110126 t .Errorf (`Unexpected error: got %d, expected %d` , err , test .expectedError )
111127 }
112128 if ! reflect .DeepEqual (parameters , test .expectedOutput ) {
113- t .Error (`Unexpected parameters` , * parameters , * test .expectedOutput )
129+ t .Errorf (`Unexpected parameters: got: %+v, expected: %+v ` , * parameters , * test .expectedOutput )
114130 }
115131 })
116132 }
@@ -122,6 +138,12 @@ func getParameters() []*ssm.Parameter {
122138 }
123139}
124140
141+ func getParameters2 () []* ssm.Parameter {
142+ return []* ssm.Parameter {
143+ param3 ,
144+ }
145+ }
146+
125147func TestParameterStore_GetParameter (t * testing.T ) {
126148 value := "something-secure"
127149 tests := []struct {
@@ -172,3 +194,130 @@ func TestParameterStore_GetParameter(t *testing.T) {
172194 })
173195 }
174196}
197+
198+ func TestParameterStore_PutSecureParameter (t * testing.T ) {
199+ paramName := "foo"
200+ paramValue := "baz"
201+ paramType := "SecureString"
202+ overwriteTrue := true
203+ overwriteFalse := false
204+
205+ tests := []struct {
206+ name string
207+ ssmClient * stubSSMClient
208+ parameterName string
209+ parameterValue string
210+ overwrite bool
211+ expectedError error
212+ expectedInput * ssm.PutParameterInput
213+ }{
214+ {
215+ name : "Failed Empty name" ,
216+ ssmClient : & stubSSMClient {},
217+ parameterName : "" ,
218+ parameterValue : "" ,
219+ expectedError : ErrParameterInvalidName ,
220+ },
221+ {
222+ name : "Set Correct Defaults" ,
223+ ssmClient : & stubSSMClient {},
224+ parameterName : paramName ,
225+ parameterValue : paramValue ,
226+ expectedInput : & ssm.PutParameterInput {
227+ Name : & paramName ,
228+ Type : & paramType ,
229+ Value : & paramValue ,
230+ Overwrite : & overwriteFalse ,
231+ },
232+ },
233+ {
234+ name : "Overwrite Changes Propagate" ,
235+ ssmClient : & stubSSMClient {},
236+ parameterName : paramName ,
237+ parameterValue : paramValue ,
238+ overwrite : overwriteTrue ,
239+ expectedInput : & ssm.PutParameterInput {
240+ Name : & paramName ,
241+ Type : & paramType ,
242+ Value : & paramValue ,
243+ Overwrite : & overwriteTrue ,
244+ },
245+ },
246+ }
247+ for _ , test := range tests {
248+ t .Run (test .name , func (t * testing.T ) {
249+ client := NewParameterStoreWithClient (test .ssmClient )
250+ err := client .PutSecureParameter (test .parameterName , test .parameterValue , test .overwrite )
251+ if err != test .expectedError {
252+ t .Errorf (`Unexpected error: got %d, expected %d` , err , test .expectedError )
253+ }
254+ if ! reflect .DeepEqual (test .ssmClient .PutParameterInputReceived , test .expectedInput ) {
255+ t .Errorf (`Unexpected parameter: got %v, expected %v` , test .ssmClient .PutParameterInputReceived , test .expectedInput )
256+ }
257+ })
258+ }
259+ }
260+
261+ func TestParameterStore_PutSecureParameterWithCMK (t * testing.T ) {
262+ paramName := "foo"
263+ paramValue := "baz"
264+ paramType := "SecureString"
265+ overwriteFalse := false
266+ kmsID := "super-secret-kms"
267+ tests := []struct {
268+ name string
269+ ssmClient * stubSSMClient
270+ parameterName string
271+ parameterValue string
272+ overwrite bool
273+ kmsID string
274+ expectedError error
275+ expectedInput * ssm.PutParameterInput
276+ }{
277+ {
278+ name : "Failed Empty name" ,
279+ ssmClient : & stubSSMClient {},
280+ parameterName : "" ,
281+ parameterValue : "" ,
282+ expectedError : ErrParameterInvalidName ,
283+ },
284+ {
285+ name : "Set Correct Defaults" ,
286+ ssmClient : & stubSSMClient {},
287+ parameterName : paramName ,
288+ parameterValue : paramValue ,
289+ expectedInput : & ssm.PutParameterInput {
290+ Name : & paramName ,
291+ Overwrite : & overwriteFalse ,
292+ Type : & paramType ,
293+ Value : & paramValue ,
294+ },
295+ },
296+ {
297+ name : "KMS ID Changes Propagate" ,
298+ ssmClient : & stubSSMClient {},
299+ parameterName : paramName ,
300+ parameterValue : paramValue ,
301+ kmsID : kmsID ,
302+ expectedInput : & ssm.PutParameterInput {
303+ KeyId : & kmsID ,
304+ Name : & paramName ,
305+ Overwrite : & overwriteFalse ,
306+ Type : & paramType ,
307+ Value : & paramValue ,
308+ },
309+ },
310+ }
311+ for _ , test := range tests {
312+ t .Run (test .name , func (t * testing.T ) {
313+ client := NewParameterStoreWithClient (test .ssmClient )
314+ err := client .PutSecureParameterWithCMK (test .parameterName , test .parameterValue , test .overwrite , test .kmsID )
315+ if err != test .expectedError {
316+ t .Errorf (`Unexpected error: got %d, expected %d` , err , test .expectedError )
317+ }
318+ if ! reflect .DeepEqual (test .ssmClient .PutParameterInputReceived , test .expectedInput ) {
319+ t .Errorf (`Unexpected parameter: got %v, expected %v` , test .ssmClient .PutParameterInputReceived , test .expectedInput )
320+ }
321+ })
322+ }
323+ }
0 commit comments