Skip to content

Commit 16704fd

Browse files
authored
Merge pull request crossplane-contrib#2243 from crossplane-contrib/feat/userpool_customattributes
feat(userpool): add custom attributes when adding schema
2 parents e6cb878 + 176905d commit 16704fd

File tree

3 files changed

+180
-20
lines changed

3 files changed

+180
-20
lines changed

pkg/clients/cognitoidentityprovider/fake/fake.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type MockCognitoIdentityProviderClient struct {
3131

3232
MockGetUserPoolMfaConfig func(*cognitoidentityprovider.GetUserPoolMfaConfigInput) (*cognitoidentityprovider.GetUserPoolMfaConfigOutput, error)
3333
MockSetUserPoolMfaConfigWithContext func(context.Context, *cognitoidentityprovider.SetUserPoolMfaConfigInput, []request.Option) (*cognitoidentityprovider.SetUserPoolMfaConfigOutput, error)
34+
MockAddCustomAttributes func(*cognitoidentityprovider.AddCustomAttributesInput) (*cognitoidentityprovider.AddCustomAttributesOutput, error)
3435

3536
Called MockCognitoIdentityProviderClientCall
3637
}
@@ -42,6 +43,11 @@ type CallGetUserPoolMfaConfig struct {
4243
Opts []request.Option
4344
}
4445

46+
// CallAddCustomAttributes to log call
47+
type CallAddCustomAttributes struct {
48+
I *cognitoidentityprovider.AddCustomAttributesInput
49+
}
50+
4551
// GetUserPoolMfaConfig calls MockGetUserPoolMfaConfig
4652
func (m *MockCognitoIdentityProviderClient) GetUserPoolMfaConfig(i *cognitoidentityprovider.GetUserPoolMfaConfigInput) (*cognitoidentityprovider.GetUserPoolMfaConfigOutput, error) {
4753
m.Called.GetUserPoolMfaConfig = append(m.Called.GetUserPoolMfaConfig, &CallGetUserPoolMfaConfig{I: i})
@@ -63,8 +69,14 @@ func (m *MockCognitoIdentityProviderClient) SetUserPoolMfaConfigWithContext(ctx
6369
return m.MockSetUserPoolMfaConfigWithContext(ctx, i, opts)
6470
}
6571

72+
func (m *MockCognitoIdentityProviderClient) AddCustomAttributes(in *cognitoidentityprovider.AddCustomAttributesInput) (*cognitoidentityprovider.AddCustomAttributesOutput, error) {
73+
m.Called.MockAddCustomAttributes = append(m.Called.MockAddCustomAttributes, &CallAddCustomAttributes{I: in})
74+
return m.MockAddCustomAttributes(in)
75+
}
76+
6677
// MockCognitoIdentityProviderClientCall to log calls
6778
type MockCognitoIdentityProviderClientCall struct {
6879
GetUserPoolMfaConfig []*CallGetUserPoolMfaConfig
6980
SetUserPoolMfaConfigWithContext []*CallSetUserPoolMfaConfigWithContext
81+
MockAddCustomAttributes []*CallAddCustomAttributes
7082
}

pkg/controller/cognitoidentityprovider/userpool/setup.go

Lines changed: 106 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package userpool
1616
import (
1717
"context"
1818
"reflect"
19+
"strings"
1920

2021
svcsdk "github.com/aws/aws-sdk-go/service/cognitoidentityprovider"
2122
svcsdkapi "github.com/aws/aws-sdk-go/service/cognitoidentityprovider/cognitoidentityprovideriface"
@@ -53,6 +54,7 @@ func SetupUserPool(mgr ctrl.Manager, o controller.Options) error {
5354
e.preObserve = preObserve
5455
e.postObserve = postObserve
5556
e.preUpdate = h.preUpdate
57+
e.postUpdate = h.postUpdate
5658
e.preDelete = preDelete
5759
e.preCreate = preCreate
5860
e.postCreate = postCreate
@@ -93,7 +95,8 @@ func SetupUserPool(mgr ctrl.Manager, o controller.Options) error {
9395
}
9496

9597
type hooks struct {
96-
client svcsdkapi.CognitoIdentityProviderAPI
98+
client svcsdkapi.CognitoIdentityProviderAPI
99+
currentCustomAttributes []*svcsdk.SchemaAttributeType
97100
}
98101

99102
func preObserve(_ context.Context, cr *svcapitypes.UserPool, obj *svcsdk.DescribeUserPoolInput) error {
@@ -313,28 +316,43 @@ func arePoliciesEqual(spec *svcapitypes.UserPoolPolicyType, current *svcsdk.User
313316
return true
314317
}
315318

316-
func areSchemaEqual(spec []*svcapitypes.SchemaAttributeType, current []*svcsdk.SchemaAttributeType) bool {
319+
func areSchemaEqual(spec []*svcapitypes.SchemaAttributeType, current []*svcsdk.SchemaAttributeType) bool { //nolint:gocyclo
317320
if spec != nil && current != nil {
318-
if len(spec) > 0 && len(spec) != len(current) {
319-
return false
320-
}
321-
322-
for i, s := range spec {
323-
switch {
324-
case pointer.StringValue(s.AttributeDataType) != pointer.StringValue(current[i].AttributeDataType),
325-
pointer.BoolValue(s.DeveloperOnlyAttribute) != pointer.BoolValue(current[i].DeveloperOnlyAttribute),
326-
pointer.BoolValue(s.Mutable) != pointer.BoolValue(current[i].Mutable),
327-
pointer.StringValue(s.Name) != pointer.StringValue(current[i].Name),
328-
pointer.StringValue(s.NumberAttributeConstraints.MaxValue) != pointer.StringValue(current[i].NumberAttributeConstraints.MaxValue),
329-
pointer.StringValue(s.NumberAttributeConstraints.MinValue) != pointer.StringValue(current[i].NumberAttributeConstraints.MinValue),
330-
pointer.BoolValue(s.Required) != pointer.BoolValue(current[i].Required),
331-
pointer.StringValue(s.StringAttributeConstraints.MaxLength) != pointer.StringValue(current[i].StringAttributeConstraints.MaxLength),
332-
pointer.StringValue(s.StringAttributeConstraints.MinLength) != pointer.StringValue(current[i].StringAttributeConstraints.MinLength):
333-
return false
321+
if len(spec) == 0 {
322+
return true
323+
}
324+
325+
for _, s := range spec {
326+
for _, cur := range current {
327+
if *s.Name != strings.TrimPrefix(*cur.Name, "custom:") {
328+
continue
329+
}
330+
switch {
331+
case pointer.StringValue(s.AttributeDataType) != pointer.StringValue(cur.AttributeDataType),
332+
pointer.BoolValue(s.DeveloperOnlyAttribute) != pointer.BoolValue(cur.DeveloperOnlyAttribute),
333+
pointer.BoolValue(s.Mutable) != pointer.BoolValue(cur.Mutable),
334+
pointer.BoolValue(s.Required) != pointer.BoolValue(cur.Required),
335+
s.NumberAttributeConstraints == nil && cur.NumberAttributeConstraints != nil,
336+
s.NumberAttributeConstraints != nil && cur.NumberAttributeConstraints == nil,
337+
s.StringAttributeConstraints == nil && cur.StringAttributeConstraints != nil,
338+
s.StringAttributeConstraints != nil && cur.StringAttributeConstraints == nil:
339+
return false
340+
}
341+
if s.NumberAttributeConstraints != nil && cur.NumberAttributeConstraints != nil {
342+
if pointer.StringValue(s.NumberAttributeConstraints.MaxValue) != pointer.StringValue(cur.NumberAttributeConstraints.MaxValue) ||
343+
pointer.StringValue(s.NumberAttributeConstraints.MinValue) != pointer.StringValue(cur.NumberAttributeConstraints.MinValue) {
344+
return false
345+
}
346+
}
347+
if s.StringAttributeConstraints != nil && cur.StringAttributeConstraints != nil {
348+
if pointer.StringValue(s.StringAttributeConstraints.MaxLength) != pointer.StringValue(cur.StringAttributeConstraints.MaxLength) ||
349+
pointer.StringValue(s.StringAttributeConstraints.MinLength) != pointer.StringValue(cur.StringAttributeConstraints.MinLength) {
350+
return false
351+
}
352+
}
334353
}
335354
}
336355
}
337-
338356
return true
339357
}
340358

@@ -460,7 +478,7 @@ func (e *hooks) areMFAConfigEqual(cr *svcapitypes.UserPool) (bool, error) {
460478
return true, nil
461479
}
462480

463-
func lateInitialize(cr *svcapitypes.UserPoolParameters, resp *svcsdk.DescribeUserPoolOutput) error {
481+
func lateInitialize(cr *svcapitypes.UserPoolParameters, resp *svcsdk.DescribeUserPoolOutput) error { //nolint:gocyclo
464482
instance := resp.UserPool
465483

466484
cr.MFAConfiguration = pointer.LateInitialize(cr.MFAConfiguration, instance.MfaConfiguration)
@@ -500,6 +518,21 @@ func lateInitialize(cr *svcapitypes.UserPoolParameters, resp *svcsdk.DescribeUse
500518
cr.VerificationMessageTemplate.DefaultEmailOption = pointer.LateInitialize(cr.VerificationMessageTemplate.DefaultEmailOption, instance.VerificationMessageTemplate.DefaultEmailOption)
501519
}
502520

521+
if cr.Schema != nil || len(cr.Schema) > 0 {
522+
for i, scheme := range cr.Schema {
523+
if scheme.StringAttributeConstraints != nil {
524+
continue
525+
}
526+
for _, schemaAttribute := range instance.SchemaAttributes {
527+
if *scheme.Name == strings.TrimPrefix(*schemaAttribute.Name, "custom:") {
528+
cr.Schema[i].StringAttributeConstraints = &svcapitypes.StringAttributeConstraintsType{
529+
MaxLength: schemaAttribute.StringAttributeConstraints.MaxLength,
530+
MinLength: schemaAttribute.StringAttributeConstraints.MinLength,
531+
}
532+
}
533+
}
534+
}
535+
}
503536
// Info: to avoid redundancy+problems, do not lateInit conflicting fields
504537
// (e.g. VerificationMessageTemplate.SmsMessage & SmsVerificationMessage)
505538

@@ -549,3 +582,56 @@ func (e *hooks) setMfaConfiguration(ctx context.Context, cr *svcapitypes.UserPoo
549582

550583
return nil
551584
}
585+
586+
func (h *hooks) postUpdate(ctx context.Context, cr *svcapitypes.UserPool, _ *svcsdk.UpdateUserPoolOutput, updateExternalUpdate managed.ExternalUpdate, updateError error) (managed.ExternalUpdate, error) {
587+
if updateError != nil {
588+
return updateExternalUpdate, updateError
589+
}
590+
591+
for _, scheme := range cr.Spec.ForProvider.Schema {
592+
isNew := true
593+
for _, schemaAttribute := range cr.Status.AtProvider.SchemaAttributes {
594+
if *scheme.Name == strings.TrimPrefix(*schemaAttribute.Name, "custom:") {
595+
isNew = false
596+
}
597+
}
598+
if isNew {
599+
_, err := h.client.AddCustomAttributes(&svcsdk.AddCustomAttributesInput{
600+
CustomAttributes: []*svcsdk.SchemaAttributeType{
601+
convertSchemaAttribute(scheme),
602+
},
603+
UserPoolId: pointer.ToOrNilIfZeroValue(meta.GetExternalName(cr)),
604+
})
605+
if err != nil {
606+
return managed.ExternalUpdate{}, err
607+
}
608+
}
609+
}
610+
return managed.ExternalUpdate{}, nil
611+
}
612+
613+
func convertSchemaAttribute(in *svcapitypes.SchemaAttributeType) *svcsdk.SchemaAttributeType {
614+
out := &svcsdk.SchemaAttributeType{
615+
AttributeDataType: in.AttributeDataType,
616+
DeveloperOnlyAttribute: in.DeveloperOnlyAttribute,
617+
Mutable: in.Mutable,
618+
Name: in.Name,
619+
Required: in.Required,
620+
}
621+
622+
if in.NumberAttributeConstraints != nil {
623+
out.NumberAttributeConstraints = &svcsdk.NumberAttributeConstraintsType{
624+
MaxValue: in.NumberAttributeConstraints.MaxValue,
625+
MinValue: in.NumberAttributeConstraints.MinValue,
626+
}
627+
}
628+
629+
if in.StringAttributeConstraints != nil {
630+
out.StringAttributeConstraints = &svcsdk.StringAttributeConstraintsType{
631+
MaxLength: in.StringAttributeConstraints.MaxLength,
632+
MinLength: in.StringAttributeConstraints.MinLength,
633+
}
634+
}
635+
636+
return out
637+
}

pkg/controller/cognitoidentityprovider/userpool/setup_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/crossplane/crossplane-runtime/pkg/test"
1111
"github.com/google/go-cmp/cmp"
1212
"github.com/pkg/errors"
13+
"k8s.io/utils/ptr"
1314

1415
svcapitypes "github.com/crossplane-contrib/provider-aws/apis/cognitoidentityprovider/v1alpha1"
1516
"github.com/crossplane-contrib/provider-aws/pkg/clients/cognitoidentityprovider/fake"
@@ -313,6 +314,7 @@ func TestIsUpToDate(t *testing.T) {
313314
cr: userPool(withSpec(svcapitypes.UserPoolParameters{
314315
Schema: []*svcapitypes.SchemaAttributeType{
315316
{
317+
Name: &testString1,
316318
NumberAttributeConstraints: &svcapitypes.NumberAttributeConstraintsType{
317319
MaxValue: &testString1,
318320
},
@@ -322,6 +324,7 @@ func TestIsUpToDate(t *testing.T) {
322324
resp: &svcsdk.DescribeUserPoolOutput{UserPool: &svcsdk.UserPoolType{
323325
SchemaAttributes: []*svcsdk.SchemaAttributeType{
324326
{
327+
Name: &testString1,
325328
NumberAttributeConstraints: &svcsdk.NumberAttributeConstraintsType{
326329
MaxValue: &testString2,
327330
},
@@ -610,3 +613,62 @@ func TestPostCreate(t *testing.T) {
610613
})
611614
}
612615
}
616+
617+
func TestPostUpdate(t *testing.T) {
618+
type args struct {
619+
cr *svcapitypes.UserPool
620+
err error
621+
resp *svcsdk.AddCustomAttributesOutput
622+
}
623+
624+
type want struct {
625+
result managed.ExternalUpdate
626+
err error
627+
}
628+
629+
cases := map[string]struct {
630+
args
631+
want
632+
}{
633+
"CreateSuccessful": {
634+
args: args{
635+
cr: userPool(
636+
withSpec(svcapitypes.UserPoolParameters{
637+
Schema: []*svcapitypes.SchemaAttributeType{
638+
{
639+
Name: ptr.To("attribute1"),
640+
},
641+
},
642+
}),
643+
withExternalName(testString1),
644+
),
645+
resp: &svcsdk.AddCustomAttributesOutput{},
646+
err: nil,
647+
},
648+
want: want{
649+
result: managed.ExternalUpdate{},
650+
err: nil,
651+
},
652+
},
653+
}
654+
655+
for name, tc := range cases {
656+
t.Run(name, func(t *testing.T) {
657+
h := &hooks{
658+
client: &fake.MockCognitoIdentityProviderClient{
659+
MockAddCustomAttributes: func(in *svcsdk.AddCustomAttributesInput) (*svcsdk.AddCustomAttributesOutput, error) {
660+
return tc.resp, nil
661+
},
662+
},
663+
}
664+
// Act
665+
result, err := h.postUpdate(context.Background(), tc.args.cr, nil, managed.ExternalUpdate{}, tc.args.err)
666+
if diff := cmp.Diff(tc.want.result, result, test.EquateConditions()); diff != "" {
667+
t.Errorf("r: -want, +got:\n%s", diff)
668+
}
669+
if diff := cmp.Diff(tc.want.err, err, test.EquateErrors()); diff != "" {
670+
t.Errorf("r: -want, +got:\n%s", diff)
671+
}
672+
})
673+
}
674+
}

0 commit comments

Comments
 (0)