diff --git a/cmd/controller/run.go b/cmd/controller/run.go index e83c456a..f4b53662 100644 --- a/cmd/controller/run.go +++ b/cmd/controller/run.go @@ -25,6 +25,7 @@ import ( "github.com/castai/cluster-controller/cmd/utils" "github.com/castai/cluster-controller/health" + "github.com/castai/cluster-controller/internal/actions" "github.com/castai/cluster-controller/internal/actions/csr" "github.com/castai/cluster-controller/internal/castai" "github.com/castai/cluster-controller/internal/config" @@ -131,6 +132,15 @@ func runController( log.Infof("running castai-cluster-controller version %v, log-level: %v", binVersion, logger.Level) + actionHandlers := actions.NewDefaultActionHandlers( + k8sVer.Full(), + cfg.SelfPod.Namespace, + log, + clientset, + dynamicClient, + helmClient, + ) + actionsConfig := controller.Config{ PollWaitInterval: 5 * time.Second, PollTimeout: maxRequestTimeout, @@ -148,11 +158,9 @@ func runController( log, actionsConfig, k8sVer.Full(), - clientset, - dynamicClient, client, - helmClient, healthzAction, + actionHandlers, ) defer func() { if err := svc.Close(); err != nil { diff --git a/internal/actions/actions.go b/internal/actions/actions.go new file mode 100644 index 00000000..761801a2 --- /dev/null +++ b/internal/actions/actions.go @@ -0,0 +1,44 @@ +package actions + +import ( + "reflect" + + "github.com/sirupsen/logrus" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/kubernetes" + + "github.com/castai/cluster-controller/internal/castai" + "github.com/castai/cluster-controller/internal/helm" +) + +type ActionHandlers map[reflect.Type]ActionHandler + +func NewDefaultActionHandlers( + k8sVersion string, + castNamespace string, + log logrus.FieldLogger, + clientset *kubernetes.Clientset, + dynamicClient dynamic.Interface, + helmClient helm.Client, +) ActionHandlers { + return ActionHandlers{ + reflect.TypeOf(&castai.ActionDeleteNode{}): NewDeleteNodeHandler(log, clientset), + reflect.TypeOf(&castai.ActionDrainNode{}): NewDrainNodeHandler(log, clientset, castNamespace), + reflect.TypeOf(&castai.ActionPatchNode{}): NewPatchNodeHandler(log, clientset), + reflect.TypeOf(&castai.ActionCreateEvent{}): NewCreateEventHandler(log, clientset), + reflect.TypeOf(&castai.ActionChartUpsert{}): NewChartUpsertHandler(log, helmClient), + reflect.TypeOf(&castai.ActionChartUninstall{}): NewChartUninstallHandler(log, helmClient), + reflect.TypeOf(&castai.ActionChartRollback{}): NewChartRollbackHandler(log, helmClient, k8sVersion), + reflect.TypeOf(&castai.ActionDisconnectCluster{}): NewDisconnectClusterHandler(log, clientset), + reflect.TypeOf(&castai.ActionCheckNodeDeleted{}): NewCheckNodeDeletedHandler(log, clientset), + reflect.TypeOf(&castai.ActionCheckNodeStatus{}): NewCheckNodeStatusHandler(log, clientset), + reflect.TypeOf(&castai.ActionEvictPod{}): NewEvictPodHandler(log, clientset), + reflect.TypeOf(&castai.ActionPatch{}): NewPatchHandler(log, dynamicClient), + reflect.TypeOf(&castai.ActionCreate{}): NewCreateHandler(log, dynamicClient), + reflect.TypeOf(&castai.ActionDelete{}): NewDeleteHandler(log, dynamicClient), + } +} + +func (h ActionHandlers) Close() error { + return h[reflect.TypeOf(&castai.ActionCreateEvent{})].(*CreateEventHandler).Close() +} diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 5906a508..d1d7bea9 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -11,13 +11,10 @@ import ( "time" "github.com/sirupsen/logrus" - "k8s.io/client-go/dynamic" - "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/health" "github.com/castai/cluster-controller/internal/actions" "github.com/castai/cluster-controller/internal/castai" - "github.com/castai/cluster-controller/internal/helm" "github.com/castai/cluster-controller/internal/metrics" "github.com/castai/cluster-controller/internal/waitext" ) @@ -38,11 +35,9 @@ func NewService( log logrus.FieldLogger, cfg Config, k8sVersion string, - clientset *kubernetes.Clientset, - dynamicClient dynamic.Interface, castaiClient castai.CastAIClient, - helmClient helm.Client, healthCheck *health.HealthzProvider, + actionHandlers actions.ActionHandlers, ) *Controller { return &Controller{ log: log, @@ -50,23 +45,8 @@ func NewService( k8sVersion: k8sVersion, castAIClient: castaiClient, startedActions: map[string]struct{}{}, - actionHandlers: map[reflect.Type]actions.ActionHandler{ - reflect.TypeOf(&castai.ActionDeleteNode{}): actions.NewDeleteNodeHandler(log, clientset), - reflect.TypeOf(&castai.ActionDrainNode{}): actions.NewDrainNodeHandler(log, clientset, cfg.Namespace), - reflect.TypeOf(&castai.ActionPatchNode{}): actions.NewPatchNodeHandler(log, clientset), - reflect.TypeOf(&castai.ActionCreateEvent{}): actions.NewCreateEventHandler(log, clientset), - reflect.TypeOf(&castai.ActionChartUpsert{}): actions.NewChartUpsertHandler(log, helmClient), - reflect.TypeOf(&castai.ActionChartUninstall{}): actions.NewChartUninstallHandler(log, helmClient), - reflect.TypeOf(&castai.ActionChartRollback{}): actions.NewChartRollbackHandler(log, helmClient, cfg.Version), - reflect.TypeOf(&castai.ActionDisconnectCluster{}): actions.NewDisconnectClusterHandler(log, clientset), - reflect.TypeOf(&castai.ActionCheckNodeDeleted{}): actions.NewCheckNodeDeletedHandler(log, clientset), - reflect.TypeOf(&castai.ActionCheckNodeStatus{}): actions.NewCheckNodeStatusHandler(log, clientset), - reflect.TypeOf(&castai.ActionEvictPod{}): actions.NewEvictPodHandler(log, clientset), - reflect.TypeOf(&castai.ActionPatch{}): actions.NewPatchHandler(log, dynamicClient), - reflect.TypeOf(&castai.ActionCreate{}): actions.NewCreateHandler(log, dynamicClient), - reflect.TypeOf(&castai.ActionDelete{}): actions.NewDeleteHandler(log, dynamicClient), - }, - healthCheck: healthCheck, + actionHandlers: actionHandlers, + healthCheck: healthCheck, } } @@ -77,7 +57,7 @@ type Controller struct { k8sVersion string - actionHandlers map[reflect.Type]actions.ActionHandler + actionHandlers actions.ActionHandlers startedActionsWg sync.WaitGroup startedActions map[string]struct{} @@ -273,5 +253,5 @@ func getHandlerError(err error) *string { } func (s *Controller) Close() error { - return s.actionHandlers[reflect.TypeOf(&castai.ActionCreateEvent{})].(*actions.CreateEventHandler).Close() + return s.actionHandlers.Close() } diff --git a/internal/controller/controller_test.go b/internal/controller/controller_test.go index 9213b920..fd6e8f0f 100644 --- a/internal/controller/controller_test.go +++ b/internal/controller/controller_test.go @@ -3,6 +3,7 @@ package controller import ( "context" "fmt" + "reflect" "testing" "time" @@ -11,12 +12,12 @@ import ( "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "go.uber.org/goleak" - "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/health" + "github.com/castai/cluster-controller/internal/actions" mock_actions "github.com/castai/cluster-controller/internal/actions/mock" "github.com/castai/cluster-controller/internal/castai" - "github.com/castai/cluster-controller/internal/castai/mock" + mock_castai "github.com/castai/cluster-controller/internal/castai/mock" ) // nolint: govet @@ -99,9 +100,9 @@ func TestController_Run(t *testing.T) { }, }, }, nil).Times(1).MinTimes(1) - m.EXPECT().AckAction(gomock.Any(), "a1", gomock.Any()).Return(nil).MinTimes(1) - m.EXPECT().AckAction(gomock.Any(), "a2", gomock.Any()).Return(nil).MinTimes(1) - m.EXPECT().AckAction(gomock.Any(), "a3", gomock.Any()).Return(nil).MinTimes(1) + m.EXPECT().AckAction(gomock.Any(), "a1", &castai.AckClusterActionRequest{}).Return(nil).MinTimes(1) + m.EXPECT().AckAction(gomock.Any(), "a2", &castai.AckClusterActionRequest{}).Return(nil).MinTimes(1) + m.EXPECT().AckAction(gomock.Any(), "a3", &castai.AckClusterActionRequest{}).Return(nil).MinTimes(1) }, }, }, @@ -240,22 +241,25 @@ func TestController_Run(t *testing.T) { if tt.fields.tuneMockCastAIClient != nil { tt.fields.tuneMockCastAIClient(client) } - s := NewService( - logrus.New(), - tt.fields.cfg, - tt.fields.k8sVersion, - kubernetes.New(nil), - nil, - client, - nil, - health.NewHealthzProvider(health.HealthzCfg{HealthyPollIntervalLimit: pollTimeout}, logrus.New())) + handler := mock_actions.NewMockActionHandler(m) if tt.fields.tuneMockHandler != nil { tt.fields.tuneMockHandler(handler) } - for k := range s.actionHandlers { - s.actionHandlers[k] = handler + testActionHandlers := map[reflect.Type]actions.ActionHandler{ + reflect.TypeOf(&castai.ActionDeleteNode{}): handler, + reflect.TypeOf(&castai.ActionDrainNode{}): handler, + reflect.TypeOf(&castai.ActionPatchNode{}): handler, } + + s := NewService( + logrus.New(), + tt.fields.cfg, + tt.fields.k8sVersion, + client, + health.NewHealthzProvider(health.HealthzCfg{HealthyPollIntervalLimit: pollTimeout}, logrus.New()), + testActionHandlers) + s.Run(tt.args.ctx()) }) }