Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions internal/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ import (
"github.com/castai/cluster-controller/internal/waitext"
)

// gcCompletedActionAfterTimes is the number of gcCompletedActions passes that a completed
// action survives in recentlyCompletedActions before it is removed from the store.
// finishProcessing seeds each entry's counter with gcCompletedActionAfterTimes + 1 and
// gcCompletedActions decrements the counter once per pass, deleting the entry when the
// counter reaches zero; the entry is therefore deleted on the pass after the last one it survives.
const gcCompletedActionAfterTimes = 2

type Config struct {
PollWaitInterval time.Duration // How long to wait until the next long polling request.
PollTimeout time.Duration // hard timeout. Normally server should return empty result before this timeout.
Expand All @@ -40,13 +43,14 @@ func NewService(
actionHandlers actions.ActionHandlers,
) *Controller {
return &Controller{
log: log,
cfg: cfg,
k8sVersion: k8sVersion,
castAIClient: castaiClient,
startedActions: map[string]struct{}{},
actionHandlers: actionHandlers,
healthCheck: healthCheck,
log: log,
cfg: cfg,
k8sVersion: k8sVersion,
castAIClient: castaiClient,
startedActions: make(map[string]struct{}),
recentlyCompletedActions: make(map[string]int8),
actionHandlers: actionHandlers,
healthCheck: healthCheck,
}
}

Expand All @@ -59,10 +63,12 @@ type Controller struct {

actionHandlers actions.ActionHandlers

startedActionsWg sync.WaitGroup
startedActions map[string]struct{}
startedActionsMu sync.Mutex
healthCheck *health.HealthzProvider
startedActionsWg sync.WaitGroup
actionsMu sync.Mutex
startedActions map[string]struct{} // protected by actionsMu
recentlyCompletedActions map[string]int8 // protected by actionsMu

healthCheck *health.HealthzProvider
}

func (s *Controller) Run(ctx context.Context) {
Expand Down Expand Up @@ -122,6 +128,7 @@ func (s *Controller) doWork(ctx context.Context) error {

s.log.WithFields(logrus.Fields{"n": strconv.Itoa(len(actions))}).Infof("received in %s", pollDuration)
s.handleActions(ctx, actions)
s.gcCompletedActions()
return nil
}

Expand All @@ -132,7 +139,10 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
}

go func(action *castai.ClusterAction) {
defer s.finishProcessing(action.ID)
var ackErr error
defer func() {
s.finishProcessing(action.ID, ackErr)
}()

var err error

Expand All @@ -142,11 +152,12 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
handleErr := s.handleAction(ctx, action)
if errors.Is(handleErr, context.Canceled) {
// Action should be handled again on context canceled errors.
ackErr = ctx.Err()
return
}

handleDuration := time.Since(startTime)
ackErr := s.ackAction(ctx, action, handleErr, handleDuration)
ackErr = s.ackAction(ctx, action, handleErr, handleDuration)
if handleErr != nil {
err = handleErr
}
Expand All @@ -163,29 +174,40 @@ func (s *Controller) handleActions(ctx context.Context, clusterActions []*castai
}
}

func (s *Controller) finishProcessing(actionID string) {
s.startedActionsMu.Lock()
defer s.startedActionsMu.Unlock()
// finishProcessing releases the bookkeeping held for actionID once its handling goroutine ends.
//
// Under actionsMu it signals startedActionsWg and removes actionID from startedActions.
// When ackErr is nil, the action is additionally stored in recentlyCompletedActions with a
// counter of gcCompletedActionAfterTimes + 1, which makes startProcessing refuse to start the
// same ID again until gcCompletedActions has aged the entry out.
func (s *Controller) finishProcessing(actionID string, ackErr error) {
	s.actionsMu.Lock()
	defer s.actionsMu.Unlock()

	s.startedActionsWg.Done()
	delete(s.startedActions, actionID)

	if ackErr != nil {
		// The acknowledgement did not go through: do not record the action as completed,
		// so that it can be started again right away if it is still being requested.
		return
	}

	s.recentlyCompletedActions[actionID] = gcCompletedActionAfterTimes + 1
}

// startProcessing tries to reserve a processing slot for actionID and reports whether the
// caller may start handling it.
//
// It returns false, without changing any state, when:
//   - the action is already present in startedActions,
//   - the action is present in recentlyCompletedActions (a debug message is logged), or
//   - len(startedActions) has reached cfg.MaxActionsInProgress (a warning is logged).
//
// Otherwise it increments startedActionsWg, records actionID in startedActions and returns
// true; finishProcessing undoes both once handling ends.
//
// Both maps are guarded by actionsMu only, matching finishProcessing and gcCompletedActions.
func (s *Controller) startProcessing(actionID string) bool {
	s.actionsMu.Lock()
	defer s.actionsMu.Unlock()

	if _, ok := s.startedActions[actionID]; ok {
		return false
	}

	if _, ok := s.recentlyCompletedActions[actionID]; ok {
		s.log.WithField(actions.ActionIDLogField, actionID).Debug("action has been recently completed, not starting")
		return false
	}

	if inProgress := len(s.startedActions); inProgress >= s.cfg.MaxActionsInProgress {
		s.log.Warnf("too many actions in progress %d/%d", inProgress, s.cfg.MaxActionsInProgress)
		return false
	}

	// The WaitGroup is incremented under the same lock that publishes the action,
	// so the counter and the startedActions map change together.
	s.startedActionsWg.Add(1)
	s.startedActions[actionID] = struct{}{}

	return true
}

Expand Down Expand Up @@ -243,6 +265,25 @@ func (s *Controller) ackAction(ctx context.Context, action *castai.ClusterAction
})
}

// gcCompletedActions ages the entries of recentlyCompletedActions by one step.
//
// Every entry holds a countdown (seeded by finishProcessing). Each call, made while holding
// actionsMu, lowers every countdown by one; entries whose countdown drops to zero or below
// are deleted, the rest are written back with the lowered value. doWork calls this after
// handleActions on each polling cycle, so an entry blocks re-execution of its action for a
// bounded number of cycles and the map does not grow without limit.
func (s *Controller) gcCompletedActions() {
	s.actionsMu.Lock()
	defer s.actionsMu.Unlock()

	for id, count := range s.recentlyCompletedActions {
		remaining := count - 1
		if remaining > 0 {
			s.recentlyCompletedActions[id] = remaining
			continue
		}
		// Countdown exhausted: forget the action so it may be started again if requested.
		delete(s.recentlyCompletedActions, id)
	}
}

func getHandlerError(err error) *string {
if err != nil {
str := err.Error()
Expand Down
115 changes: 115 additions & 0 deletions internal/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@ import (
"context"
"fmt"
"reflect"
"strconv"
"sync"
"testing"
"testing/synctest"
"time"

"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"

Expand Down Expand Up @@ -264,6 +268,117 @@ func TestController_Run(t *testing.T) {
}
}

// TestController_ParallelExecutionTest runs the controller against mocked client and handler
// inside a synctest bubble and checks three things:
//   - at most two handlers run at the same time (MaxActionsInProgress is 2),
//   - an action whose acknowledgement succeeds is handled exactly once even though the
//     mocked client returns it on several polls,
//   - the action whose acknowledgement always fails is handled three times.
func TestController_ParallelExecutionTest(t *testing.T) {
t.Parallel()

// synctest.Test runs the body in an isolated bubble; per the testing/synctest docs,
// time inside the bubble is a fake clock, so the sleeps and timeouts below do not cost wall time.
synctest.Test(t, func(t *testing.T) {
cfg := Config{
PollWaitInterval: time.Second,
PollTimeout: 50 * time.Millisecond,
AckTimeout: time.Second,
AckRetriesCount: 2,
AckRetryWait: time.Millisecond,
ClusterID: uuid.New().String(),
MaxActionsInProgress: 2,
}

ctrl := gomock.NewController(t)
defer ctrl.Finish()

client := mock_castai.NewMockCastAIClient(ctrl)
handler := mock_actions.NewMockActionHandler(ctrl)

// Only ActionCreateEvent actions are produced below, so a single handler is registered.
testActionHandlers := map[reflect.Type]actions.ActionHandler{
reflect.TypeFor[*castai.ActionCreateEvent](): handler,
}

// Build four actions with IDs "action-0" .. "action-3"; this is twice MaxActionsInProgress.
const maxActions = 4
actions := make([]*castai.ClusterAction, 0, maxActions)
for i := range maxActions {
actions = append(actions, &castai.ClusterAction{
ID: "action-" + strconv.Itoa(i),
CreatedAt: time.Now(),
ActionCreateEvent: &castai.ActionCreateEvent{
EventType: "fake",
},
})
}
// IDs of the actions whose AckAction call is made to fail (only "action-2").
actionsWithAckErr := map[string]struct{}{
actions[2].ID: {},
}

// Counters shared between the handler goroutines; mu guards all three.
var (
mu sync.Mutex
currentlyExecuting int
maxExecutingObserved int
executionCounts = make(map[string]int)
)

// The handler records per-action execution counts and the peak number of concurrent
// executions, holds its slot for 100ms, and always succeeds.
handler.EXPECT().Handle(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, action *castai.ClusterAction) error {
mu.Lock()
currentlyExecuting++
executionCounts[action.ID]++
if currentlyExecuting > maxExecutingObserved {
maxExecutingObserved = currentlyExecuting
}
mu.Unlock()

time.Sleep(100 * time.Millisecond)

mu.Lock()
currentlyExecuting--
mu.Unlock()

return nil
},
).AnyTimes()

// The GetActions expectations below return the full action list 1 + 3 = 4 times and nil
// afterwards.
// NOTE(review): this relies on gomock consuming same-signature expectations in declaration
// order and skipping exhausted ones; keep the order of these three EXPECT calls.
client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(actions, nil).Times(1)
// AckAction is only expected with a zero-value request; it fails with assert.AnError for
// the IDs in actionsWithAckErr and succeeds for all others.
client.EXPECT().AckAction(gomock.Any(), gomock.Any(), &castai.AckClusterActionRequest{}).
DoAndReturn(func(ctx context.Context, actionID string, req *castai.AckClusterActionRequest) error {
if _, ok := actionsWithAckErr[actionID]; ok {
return assert.AnError
}
return nil
}).AnyTimes()

client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(actions, nil).Times(3)

client.EXPECT().GetActions(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()

logger := logrus.New()
svc := NewService(
logger,
cfg,
"v0",
client,
health.NewHealthzProvider(health.HealthzCfg{HealthyPollIntervalLimit: cfg.PollTimeout}, logger),
testActionHandlers,
)

// The controller runs until this context times out.
ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second)
defer cancel()

go svc.Run(ctx)

// Let the bubble's goroutines run until they block, then wait for the context deadline and
// for every started action goroutine to finish before reading the shared counters.
synctest.Wait()
<-ctx.Done()
svc.startedActionsWg.Wait()

require.LessOrEqual(t, maxExecutingObserved, 2, "Expected no more than 2 actions to execute concurrently, but observed %d", maxExecutingObserved)

// NOTE(review): the expected count of 3 for the ack-failing action depends on how many of the
// four action-returning polls can start it given MaxActionsInProgress; re-check it against
// Run's polling loop if the config or the GetActions expectations change.
for _, action := range actions {
count := executionCounts[action.ID]
if _, ok := actionsWithAckErr[action.ID]; ok {
assert.Equal(t, 3, count, "Expected action %s to be executed three times because of ack errors, but it was executed %d times", action.ID, count)
continue
}
assert.Equal(t, 1, count, "Expected action %s to be executed exactly once, but it was executed %d times", action.ID, count)
}
})
}

// TestMain hands the package's test run to goleak.VerifyTestMain, passing an option that
// ignores goroutines whose top function is klog's flushDaemon.
func TestMain(m *testing.M) {
	ignoreKlogFlushDaemon := goleak.IgnoreTopFunction("k8s.io/klog/v2.(*loggingT).flushDaemon")
	goleak.VerifyTestMain(m, ignoreKlogFlushDaemon)
}
Loading