Skip to content

Commit ec28bdf

Browse files
committed
add tests
Signed-off-by: jose.vazquez <[email protected]>
1 parent 189657c commit ec28bdf

File tree

12 files changed

+1498
-43
lines changed

12 files changed

+1498
-43
lines changed

pkg/controller/state/reapply_test.go

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
// Copyright 2025 MongoDB Inc
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package state
16+
17+
import (
18+
"context"
19+
"errors"
20+
"strconv"
21+
"testing"
22+
"time"
23+
24+
"github.com/stretchr/testify/assert"
25+
corev1 "k8s.io/api/core/v1"
26+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
27+
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
28+
"k8s.io/apimachinery/pkg/runtime"
29+
"sigs.k8s.io/controller-runtime/pkg/client"
30+
"sigs.k8s.io/controller-runtime/pkg/client/fake"
31+
"sigs.k8s.io/controller-runtime/pkg/client/interceptor"
32+
)
33+
34+
func TestReapplyPeriod(t *testing.T) {
35+
tests := []struct {
36+
name string
37+
annotations map[string]string
38+
want time.Duration
39+
wantOk bool
40+
wantErr string
41+
}{
42+
{
43+
name: "valid period",
44+
annotations: map[string]string{"mongodb.com/reapply-period": "2h"},
45+
want: 2 * time.Hour,
46+
wantOk: true,
47+
},
48+
{
49+
name: "period missing",
50+
annotations: map[string]string{},
51+
want: 0,
52+
wantOk: false,
53+
},
54+
{
55+
name: "period invalid format",
56+
annotations: map[string]string{"mongodb.com/reapply-period": "not-a-period"},
57+
want: 0,
58+
wantOk: false,
59+
wantErr: "invalid duration",
60+
},
61+
{
62+
name: "period too short",
63+
annotations: map[string]string{"mongodb.com/reapply-period": "30s"},
64+
want: 0,
65+
wantOk: false,
66+
wantErr: "must be greater than 60m",
67+
},
68+
}
69+
70+
for _, tc := range tests {
71+
t.Run(tc.name, func(t *testing.T) {
72+
obj := newUnstructuredObj(tc.annotations)
73+
got, ok, err := ReapplyPeriod(obj)
74+
if tc.wantErr != "" {
75+
assert.ErrorContains(t, err, tc.wantErr)
76+
}
77+
assert.Equal(t, tc.wantOk, ok)
78+
assert.Equal(t, tc.want, got)
79+
})
80+
}
81+
}
82+
83+
func TestReapplyTimestamp(t *testing.T) {
84+
now := time.Now().UnixMilli()
85+
tests := []struct {
86+
name string
87+
annotations map[string]string
88+
want int64
89+
wantOk bool
90+
wantErr string
91+
}{
92+
{
93+
name: "valid timestamp",
94+
annotations: map[string]string{AnnotationReapplyTimestamp: strconv.FormatInt(now, 10)},
95+
want: now,
96+
wantOk: true,
97+
},
98+
{
99+
name: "timestamp missing",
100+
annotations: map[string]string{},
101+
want: 0,
102+
wantOk: false,
103+
},
104+
{
105+
name: "invalid timestamp",
106+
annotations: map[string]string{AnnotationReapplyTimestamp: "not-a-number"},
107+
want: 0,
108+
wantOk: false,
109+
wantErr: "parsing \"not-a-number\": invalid syntax",
110+
},
111+
}
112+
113+
for _, tc := range tests {
114+
t.Run(tc.name, func(t *testing.T) {
115+
obj := newUnstructuredObj(tc.annotations)
116+
got, ok, err := ReapplyTimestamp(obj)
117+
assertErrContains(t, tc.wantErr, err)
118+
assert.Equal(t, tc.wantOk, ok)
119+
if tc.wantOk {
120+
assert.Equal(t, tc.want, got.UnixMilli())
121+
}
122+
})
123+
}
124+
}
125+
126+
func TestShouldReapply(t *testing.T) {
127+
past := time.Now().Add(-2 * time.Hour).UnixMilli()
128+
future := time.Now().Add(2 * time.Hour).UnixMilli()
129+
130+
tests := []struct {
131+
name string
132+
annotations map[string]string
133+
want bool
134+
wantErr string
135+
}{
136+
{
137+
name: "should reapply (past+1h < now)",
138+
annotations: map[string]string{
139+
AnnotationReapplyTimestamp: strconv.FormatInt(past, 10),
140+
"mongodb.com/reapply-period": "1h",
141+
},
142+
want: true,
143+
wantErr: "",
144+
},
145+
{
146+
name: "should not reapply (future+1h > now)",
147+
annotations: map[string]string{
148+
AnnotationReapplyTimestamp: strconv.FormatInt(future, 10),
149+
"mongodb.com/reapply-period": "1h",
150+
},
151+
want: false,
152+
wantErr: "",
153+
},
154+
{
155+
name: "missing period",
156+
annotations: map[string]string{
157+
AnnotationReapplyTimestamp: strconv.FormatInt(past, 10),
158+
},
159+
want: false,
160+
wantErr: "",
161+
},
162+
{
163+
name: "missing timestamp",
164+
annotations: map[string]string{"mongodb.com/reapply-period": "1h"},
165+
want: false,
166+
wantErr: "",
167+
},
168+
{
169+
name: "invalid period",
170+
annotations: map[string]string{
171+
AnnotationReapplyTimestamp: strconv.FormatInt(past, 10),
172+
"mongodb.com/reapply-period": "bad",
173+
},
174+
want: false,
175+
wantErr: "invalid duration",
176+
},
177+
{
178+
name: "invalid timestamp",
179+
annotations: map[string]string{
180+
AnnotationReapplyTimestamp: "bad",
181+
"mongodb.com/reapply-period": "1h",
182+
},
183+
want: false,
184+
wantErr: "invalid syntax",
185+
},
186+
}
187+
188+
for _, tc := range tests {
189+
t.Run(tc.name, func(t *testing.T) {
190+
obj := newUnstructuredObj(tc.annotations)
191+
got, err := ShouldReapply(obj)
192+
assertErrContains(t, tc.wantErr, err)
193+
assert.Equal(t, tc.want, got)
194+
})
195+
}
196+
}
197+
198+
func TestPatchReapplyTimestamp(t *testing.T) {
199+
now := time.Now()
200+
pastMillis := strconv.FormatInt(now.Add(-2*time.Hour).UnixMilli(), 10)
201+
202+
tests := []struct {
203+
name string
204+
annotations map[string]string
205+
patchErr error
206+
want time.Duration
207+
wantErr string // substring to match in the error message
208+
wantPatched bool // true if we expect the annotation to be updated
209+
}{
210+
{
211+
name: "patch performed",
212+
annotations: map[string]string{
213+
AnnotationReapplyTimestamp: pastMillis,
214+
"mongodb.com/reapply-period": "1h",
215+
},
216+
want: time.Hour,
217+
wantErr: "",
218+
wantPatched: true,
219+
},
220+
{
221+
name: "patch not needed (no period)",
222+
annotations: map[string]string{},
223+
want: 0,
224+
wantErr: "",
225+
wantPatched: false,
226+
},
227+
{
228+
name: "patch error",
229+
annotations: map[string]string{
230+
AnnotationReapplyTimestamp: pastMillis,
231+
"mongodb.com/reapply-period": "1h",
232+
},
233+
patchErr: errors.New("fail"),
234+
want: 0,
235+
wantErr: "fail",
236+
wantPatched: true,
237+
},
238+
}
239+
for _, tc := range tests {
240+
t.Run(tc.name, func(t *testing.T) {
241+
obj := &corev1.Pod{
242+
ObjectMeta: metav1.ObjectMeta{
243+
Name: "dummy",
244+
Namespace: "default",
245+
Annotations: tc.annotations,
246+
},
247+
}
248+
scheme := runtime.NewScheme()
249+
_ = corev1.AddToScheme(scheme)
250+
patchFn := func(_ context.Context, _ client.WithWatch, _ client.Object, _ client.Patch, _ ...client.PatchOption) error {
251+
return tc.patchErr
252+
}
253+
if tc.patchErr == nil {
254+
patchFn = nil
255+
}
256+
c := fake.NewClientBuilder().
257+
WithScheme(scheme).
258+
WithObjects(obj.DeepCopy()).
259+
WithInterceptorFuncs(interceptor.Funcs{Patch: patchFn}).
260+
Build()
261+
ctx := context.Background()
262+
263+
period, err := PatchReapplyTimestamp(ctx, c, obj)
264+
assertErrContains(t, tc.wantErr, err)
265+
assert.Equal(t, tc.want, period)
266+
267+
fetched := &corev1.Pod{}
268+
_ = c.Get(ctx, client.ObjectKeyFromObject(obj), fetched)
269+
270+
annot := fetched.GetAnnotations()
271+
_, patched := annot[AnnotationReapplyTimestamp]
272+
273+
assert.Equal(t, tc.wantPatched, patched, "Annotation patched?")
274+
})
275+
}
276+
}
277+
278+
// Helper to create an Unstructured object with annotations.
279+
func newUnstructuredObj(annotations map[string]string) *unstructured.Unstructured {
280+
obj := &unstructured.Unstructured{}
281+
obj.SetAnnotations(annotations)
282+
return obj
283+
}
284+
285+
func assertErrContains(t *testing.T, wantErr string, err error) {
286+
if wantErr == "" {
287+
assert.NoError(t, err)
288+
} else {
289+
assert.ErrorContains(t, err, wantErr)
290+
}
291+
}

pkg/controller/state/reconciler.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import (
3434

3535
"github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/finalizer"
3636
"github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/state"
37-
"github.com/mongodb/mongodb-atlas-kubernetes/v2/pkg/status"
3837
)
3938

4039
type Result struct {
@@ -109,8 +108,11 @@ func (r *Reconciler[T]) Reconcile(ctx context.Context, req ctrl.Request) (reconc
109108
return ctrl.Result{}, fmt.Errorf("unable to get object: %w", err)
110109
}
111110

112-
currentStatus := status.GetStatus(obj)
113-
currentState := state.GetState(currentStatus.Status.Conditions)
111+
currentStatus, err := newStatusObject(obj)
112+
if err != nil {
113+
return ctrl.Result{}, fmt.Errorf("failed to get status object: %w", err)
114+
}
115+
currentState := state.GetState(currentStatus.GetConditions())
114116

115117
logger.Info("reconcile started", "currentState", currentState)
116118
if err := finalizer.EnsureFinalizers(ctx, r.cluster.GetClient(), obj, "mongodb.com/finalizer"); err != nil {
@@ -123,9 +125,13 @@ func (r *Reconciler[T]) Reconcile(ctx context.Context, req ctrl.Request) (reconc
123125
// error message will be displayed in Ready state.
124126
stateStatus = false
125127
}
126-
newStatus := status.GetStatus(obj)
128+
newStatus, err := newStatusObject(obj)
129+
if err != nil {
130+
return ctrl.Result{}, fmt.Errorf("failed to get status object: %w", err)
131+
}
127132
observedGeneration := getObservedGeneration(obj, currentStatus, result.NextState)
128-
state.EnsureState(&newStatus.Status.Conditions, observedGeneration, result.NextState, result.StateMsg, stateStatus)
133+
newStatusConditions := newStatus.GetConditions()
134+
state.EnsureState(&newStatusConditions, observedGeneration, result.NextState, result.StateMsg, stateStatus)
129135

130136
logger.Info("reconcile finished", "nextState", result.NextState)
131137

@@ -146,9 +152,9 @@ func (r *Reconciler[T]) Reconcile(ctx context.Context, req ctrl.Request) (reconc
146152
ready.Message = reconcileErr.Error()
147153
}
148154

149-
meta.SetStatusCondition(&newStatus.Status.Conditions, ready)
155+
meta.SetStatusCondition(&newStatusConditions, ready)
150156

151-
if err := status.PatchStatus(ctx, r.cluster.GetClient(), obj, newStatus); err != nil {
157+
if err := patchStatus(ctx, r.cluster.GetClient(), obj, newStatus); err != nil {
152158
return ctrl.Result{}, fmt.Errorf("failed to patch status: %w", err)
153159
}
154160

@@ -226,15 +232,18 @@ func (r *Reconciler[T]) ReconcileState(ctx context.Context, t *T) (Result, error
226232
obj := any(t).(client.Object)
227233

228234
var (
229-
currentState = state.GetState(status.GetStatus(obj).Status.Conditions)
230-
231235
result = Result{
232236
Result: reconcile.Result{},
233237
NextState: state.StateInitial,
234238
}
235239

236240
err error
237241
)
242+
statusObj, err := newStatusObject(obj)
243+
if err != nil {
244+
return Result{}, fmt.Errorf("failed to get status object: %w", err)
245+
}
246+
currentState := state.GetState(statusObj.GetConditions())
238247

239248
if currentState == state.StateInitial {
240249
for key := range obj.GetAnnotations() {
@@ -267,6 +276,8 @@ func (r *Reconciler[T]) ReconcileState(ctx context.Context, t *T) (Result, error
267276
result, err = r.reconciler.HandleDeletionRequested(ctx, t)
268277
case state.StateDeleting:
269278
result, err = r.reconciler.HandleDeleting(ctx, t)
279+
default:
280+
return Result{}, fmt.Errorf("unsupported state %q", currentState)
270281
}
271282

272283
if result.NextState == "" {
@@ -289,11 +300,11 @@ func (r *Reconciler[T]) ReconcileState(ctx context.Context, t *T) (Result, error
289300
return result, err
290301
}
291302

292-
func getObservedGeneration(obj client.Object, prevStatus *status.Resource, nextState state.ResourceState) int64 {
303+
func getObservedGeneration(obj client.Object, prevStatus StatusObject, nextState state.ResourceState) int64 {
293304
observedGeneration := obj.GetGeneration()
294-
prevState := state.GetState(prevStatus.Status.Conditions)
305+
prevState := state.GetState(prevStatus.GetConditions())
295306

296-
if prevCondition := meta.FindStatusCondition(prevStatus.Status.Conditions, state.StateCondition); prevCondition != nil {
307+
if prevCondition := meta.FindStatusCondition(prevStatus.GetConditions(), state.StateCondition); prevCondition != nil {
297308
from := prevState
298309
to := nextState
299310

0 commit comments

Comments
 (0)