Skip to content

Commit 8e403ca

Browse files
authored
feat: execute successfully completed actions once (#206)
actions might be executed at least twice as PollActions and AckAction are asynchronous operations and can sometimes race. Introduce a store for tracking recently completed actions to avoid that. The completed actions will be garbage collected after two polls when they are completed and successfully acked.
1 parent 3ef58ed commit 8e403ca

File tree

2 files changed

+174
-18
lines changed

2 files changed

+174
-18
lines changed

internal/controller/controller.go

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import (
1919
"github.com/castai/cluster-controller/internal/waitext"
2020
)
2121

22+
// gcCompletedActionAfterTimes specifies after how many GCs to remove the completed action from the store.
23+
const gcCompletedActionAfterTimes = 2
24+
2225
type Config struct {
2326
PollWaitInterval time.Duration // How long to wait unit next long polling request.
2427
PollTimeout time.Duration // hard timeout. Normally server should return empty result before this timeout.
@@ -40,13 +43,14 @@ func NewService(
4043
actionHandlers actions.ActionHandlers,
4144
) *Controller {
4245
return &Controller{
43-
log: log,
44-
cfg: cfg,
45-
k8sVersion: k8sVersion,
46-
castAIClient: castaiClient,
47-
startedActions: map[string]struct{}{},
48-
actionHandlers: actionHandlers,
49-
healthCheck: healthCheck,
46+
log: log,
47+
cfg: cfg,
48+
k8sVersion: k8sVersion,
49+
castAIClient: castaiClient,
50+
startedActions: make(map[string]struct{}),
51+
recentlyCompletedActions: make(map[string]int8),
52+
actionHandlers: actionHandlers,
53+
healthCheck: healthCheck,
5054
}
5155
}
5256

@@ -59,10 +63,12 @@ type Controller struct {
5963

6064
actionHandlers actions.ActionHandlers
6165

62-
startedActionsWg sync.WaitGroup
63-
startedActions map[string]struct{}
64-
startedActionsMu sync.Mutex
65-
healthCheck *health.HealthzProvider
66+
startedActionsWg sync.WaitGroup
67+
actionsMu sync.Mutex
68+
startedActions map[string]struct{} // protected by actionsMu
69+
recentlyCompletedActions map[string]int8 // protected by actionsMu
70+
71+
healthCheck *health.HealthzProvider
6672
}
6773

6874
func (s *Controller) Run(ctx context.Context) {
@@ -122,6 +128,7 @@ func (s *Controller) doWork(ctx context.Context) error {
122128

123129
s.log.WithFields(logrus.Fields{"n": strconv.Itoa(len(actions))}).Infof("received in %s", pollDuration)
124130
s.handleActions(ctx, actions)
131+
s.gcCompletedActions()
125132
return nil
126133
}
127134

@@ -132,7 +139,10 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
132139
}
133140

134141
go func(action *castai.ClusterAction) {
135-
defer s.finishProcessing(action.ID)
142+
var ackErr error
143+
defer func() {
144+
s.finishProcessing(action.ID, ackErr)
145+
}()
136146

137147
var err error
138148

@@ -142,11 +152,12 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
142152
handleErr := s.handleAction(ctx, action)
143153
if errors.Is(handleErr, context.Canceled) {
144154
// Action should be handled again on context canceled errors.
155+
ackErr = ctx.Err()
145156
return
146157
}
147158

148159
handleDuration := time.Since(startTime)
149-
ackErr := s.ackAction(ctx, action, handleErr, handleDuration)
160+
ackErr = s.ackAction(ctx, action, handleErr, handleDuration)
150161
if handleErr != nil {
151162
err = handleErr
152163
}
@@ -163,29 +174,40 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
163174
}
164175
}
165176

166-
func (s *Controller) finishProcessing(actionID string) {
167-
s.startedActionsMu.Lock()
168-
defer s.startedActionsMu.Unlock()
177+
func (s *Controller) finishProcessing(actionID string, ackErr error) {
178+
s.actionsMu.Lock()
179+
defer s.actionsMu.Unlock()
169180

170181
s.startedActionsWg.Done()
171182
delete(s.startedActions, actionID)
183+
184+
if ackErr == nil {
185+
// only mark the action as completed if it was successfully acknowledged so it can be retried quickly if not and still requested.
186+
s.recentlyCompletedActions[actionID] = gcCompletedActionAfterTimes + 1
187+
}
172188
}
173189

174190
func (s *Controller) startProcessing(actionID string) bool {
175-
s.startedActionsMu.Lock()
176-
defer s.startedActionsMu.Unlock()
191+
s.actionsMu.Lock()
192+
defer s.actionsMu.Unlock()
177193

178194
if _, ok := s.startedActions[actionID]; ok {
179195
return false
180196
}
181197

198+
if _, ok := s.recentlyCompletedActions[actionID]; ok {
199+
s.log.WithField(actions.ActionIDLogField, actionID).Debug("action has been recently completed, not starting")
200+
return false
201+
}
202+
182203
if inProgress := len(s.startedActions); inProgress >= s.cfg.MaxActionsInProgress {
183204
s.log.Warnf("too many actions in progress %d/%d", inProgress, s.cfg.MaxActionsInProgress)
184205
return false
185206
}
186207

187208
s.startedActionsWg.Add(1)
188209
s.startedActions[actionID] = struct{}{}
210+
189211
return true
190212
}
191213

@@ -243,6 +265,25 @@ func (s *Controller) ackAction(ctx context.Context, action *castai.ClusterAction
243265
})
244266
}
245267

268+
// gcCompletedActions removes recently completed actions from memory after they've been visited
269+
// a certain number of times during polling cycles. This prevents completed actions from being
270+
// re-executed while allowing enough time for duplicate action requests to be filtered out.
271+
// Actions are removed after gcCompletedActionAfterTimes visits to balance memory usage and
272+
// protection against duplicate execution.
273+
func (s *Controller) gcCompletedActions() {
274+
s.actionsMu.Lock()
275+
defer s.actionsMu.Unlock()
276+
277+
for actionID, timesVisited := range s.recentlyCompletedActions {
278+
timesVisited--
279+
if timesVisited <= 0 {
280+
delete(s.recentlyCompletedActions, actionID)
281+
continue
282+
}
283+
s.recentlyCompletedActions[actionID] = timesVisited
284+
}
285+
}
286+
246287
func getHandlerError(err error) *string {
247288
if err != nil {
248289
str := err.Error()

internal/controller/controller_test.go

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@ import (
44
"context"
55
"fmt"
66
"reflect"
7+
"strconv"
8+
"sync"
79
"testing"
10+
"testing/synctest"
811
"time"
912

1013
"github.com/golang/mock/gomock"
1114
"github.com/google/uuid"
1215
"github.com/sirupsen/logrus"
16+
"github.com/stretchr/testify/assert"
1317
"github.com/stretchr/testify/require"
1418
"go.uber.org/goleak"
1519

@@ -264,6 +268,117 @@ func TestController_Run(t *testing.T) {
264268
}
265269
}
266270

271+
func TestController_ParallelExecutionTest(t *testing.T) {
272+
t.Parallel()
273+
274+
synctest.Test(t, func(t *testing.T) {
275+
cfg := Config{
276+
PollWaitInterval: time.Second,
277+
PollTimeout: 50 * time.Millisecond,
278+
AckTimeout: time.Second,
279+
AckRetriesCount: 2,
280+
AckRetryWait: time.Millisecond,
281+
ClusterID: uuid.New().String(),
282+
MaxActionsInProgress: 2,
283+
}
284+
285+
ctrl := gomock.NewController(t)
286+
defer ctrl.Finish()
287+
288+
client := mock_castai.NewMockCastAIClient(ctrl)
289+
handler := mock_actions.NewMockActionHandler(ctrl)
290+
291+
testActionHandlers := map[reflect.Type]actions.ActionHandler{
292+
reflect.TypeFor[*castai.ActionCreateEvent](): handler,
293+
}
294+
295+
const maxActions = 4
296+
actions := make([]*castai.ClusterAction, 0, maxActions)
297+
for i := range maxActions {
298+
actions = append(actions, &castai.ClusterAction{
299+
ID: "action-" + strconv.Itoa(i),
300+
CreatedAt: time.Now(),
301+
ActionCreateEvent: &castai.ActionCreateEvent{
302+
EventType: "fake",
303+
},
304+
})
305+
}
306+
actionsWithAckErr := map[string]struct{}{
307+
actions[2].ID: {},
308+
}
309+
310+
var (
311+
mu sync.Mutex
312+
currentlyExecuting int
313+
maxExecutingObserved int
314+
executionCounts = make(map[string]int)
315+
)
316+
317+
handler.EXPECT().Handle(gomock.Any(), gomock.Any()).DoAndReturn(
318+
func(ctx context.Context, action *castai.ClusterAction) error {
319+
mu.Lock()
320+
currentlyExecuting++
321+
executionCounts[action.ID]++
322+
if currentlyExecuting > maxExecutingObserved {
323+
maxExecutingObserved = currentlyExecuting
324+
}
325+
mu.Unlock()
326+
327+
time.Sleep(100 * time.Millisecond)
328+
329+
mu.Lock()
330+
currentlyExecuting--
331+
mu.Unlock()
332+
333+
return nil
334+
},
335+
).AnyTimes()
336+
337+
client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(actions, nil).Times(1)
338+
client.EXPECT().AckAction(gomock.Any(), gomock.Any(), &castai.AckClusterActionRequest{}).
339+
DoAndReturn(func(ctx context.Context, actionID string, req *castai.AckClusterActionRequest) error {
340+
if _, ok := actionsWithAckErr[actionID]; ok {
341+
return assert.AnError
342+
}
343+
return nil
344+
}).AnyTimes()
345+
346+
client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(actions, nil).Times(3)
347+
348+
client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
349+
350+
logger := logrus.New()
351+
svc := NewService(
352+
logger,
353+
cfg,
354+
"v0",
355+
client,
356+
health.NewHealthzProvider(health.HealthzCfg{HealthyPollIntervalLimit: cfg.PollTimeout}, logger),
357+
testActionHandlers,
358+
)
359+
360+
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
361+
defer cancel()
362+
363+
go svc.Run(ctx)
364+
365+
synctest.Wait()
366+
<-ctx.Done()
367+
svc.startedActionsWg.Wait()
368+
369+
require.LessOrEqual(t, maxExecutingObserved, 2, "Expected no more than 2 actions to execute concurrently, but observed %d", maxExecutingObserved)
370+
371+
for _, action := range actions {
372+
count := executionCounts[action.ID]
373+
if _, ok := actionsWithAckErr[action.ID]; ok {
374+
assert.Equal(t, 3, count, "Expected action %s to be executed three times because of ack errors, but it was executed %d times", action.ID, count)
375+
continue
376+
}
377+
assert.Equal(t, 1, count, "Expected action %s to be executed exactly once, but it was executed %d times", action.ID, count)
378+
}
379+
})
380+
}
381+
267382
func TestMain(m *testing.M) {
268383
goleak.VerifyTestMain(m, goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon"))
269384
}

0 commit comments

Comments
 (0)