diff --git a/Makefile b/Makefile index 060a26497..1ac3415b8 100644 --- a/Makefile +++ b/Makefile @@ -100,21 +100,15 @@ cover-html: @echo "> Generating HTML coverage report for operator" @make --directory=operator cover-html -# Runs envtest tests for the operator -.PHONY: test-envtest -test-envtest: - @echo "> Running envtest for operator" - @make --directory=operator test-envtest - # Runs e2e tests for the operator .PHONY: test-e2e test-e2e: @echo "> Running e2e tests for operator" @make --directory=operator test-e2e -# Runs all tests (unit + envtest) +# Runs all tests .PHONY: test -test: test-unit test-envtest +test: test-unit @echo "> All tests passed" # Updates the docs/proposals table of contents diff --git a/operator/internal/constants/gpu.go b/operator/internal/constants/gpu.go new file mode 100644 index 000000000..6c6400e54 --- /dev/null +++ b/operator/internal/constants/gpu.go @@ -0,0 +1,28 @@ +// /* +// Copyright 2026 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package constants + +import ( + corev1 "k8s.io/api/core/v1" +) + +// GPU resource constants +const ( + // GPUResourceName is the resource name for NVIDIA GPUs. It is used to determine + // if a pod requests NVIDIA GPUs. + GPUResourceName corev1.ResourceName = "nvidia.com/gpu" +) diff --git a/operator/internal/controller/manager.go b/operator/internal/controller/manager.go index 34ff624da..9483481d2 100644 --- a/operator/internal/controller/manager.go +++ b/operator/internal/controller/manager.go @@ -62,7 +62,7 @@ func RegisterControllersAndWebhooks(mgr ctrl.Manager, logger logr.Logger, operat if err := registerControllersWithMgr(mgr, operatorCfg.Controllers, operatorCfg.TopologyAwareScheduling); err != nil { return err } - if err := registerWebhooksWithMgr(mgr, operatorCfg.Authorizer, operatorCfg.TopologyAwareScheduling); err != nil { + if err := registerWebhooksWithMgr(mgr, operatorCfg.Authorizer, operatorCfg.TopologyAwareScheduling, operatorCfg.Network); err != nil { return err } return nil diff --git a/operator/internal/controller/manager_test.go b/operator/internal/controller/manager_test.go index 169a56b11..505ccae3a 100644 --- a/operator/internal/controller/manager_test.go +++ b/operator/internal/controller/manager_test.go @@ -556,7 +556,7 @@ func TestRegisterControllersAndWebhooks(t *testing.T) { controllersCalled = true return tc.controllerErr } - registerWebhooksWithMgr = func(_ ctrl.Manager, _ configv1alpha1.AuthorizerConfig, _ configv1alpha1.TopologyAwareSchedulingConfiguration) error { + registerWebhooksWithMgr = func(_ ctrl.Manager, _ configv1alpha1.AuthorizerConfig, _ configv1alpha1.TopologyAwareSchedulingConfiguration, _ configv1alpha1.NetworkAcceleration) error { webhooksCalled = true return tc.webhookErr } diff --git a/operator/internal/mnnvl/helpers.go b/operator/internal/mnnvl/helpers.go new file mode 100644 index 000000000..035fbce12 --- /dev/null +++ b/operator/internal/mnnvl/helpers.go @@ -0,0 +1,68 @@ +// /* +// Copyright 2025 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package mnnvl + +import ( + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/constants" + + corev1 "k8s.io/api/core/v1" +) + +// hasGPURequirement checks if any container in any clique of the PCS requests nvidia.com/gpu. +func hasGPURequirement(pcs *grovecorev1alpha1.PodCliqueSet) bool { + for _, clique := range pcs.Spec.Template.Cliques { + if clique == nil { + continue + } + if hasGPUInContainers(clique.Spec.PodSpec.Containers) { + return true + } + if hasGPUInContainers(clique.Spec.PodSpec.InitContainers) { + return true + } + } + return false +} + +// hasGPUInContainers checks if any container in the slice requests GPU resources. +func hasGPUInContainers(containers []corev1.Container) bool { + for _, container := range containers { + // Check limits + if quantity, exists := container.Resources.Limits[constants.GPUResourceName]; exists { + if !quantity.IsZero() { + return true + } + } + // Check requests + if quantity, exists := container.Resources.Requests[constants.GPUResourceName]; exists { + if !quantity.IsZero() { + return true + } + } + } + return false +} + +// getAnnotationValue safely retrieves an annotation value from a PCS. +func getAnnotationValue(pcs *grovecorev1alpha1.PodCliqueSet, key string) (string, bool) { + if pcs.Annotations == nil { + return "", false + } + value, exists := pcs.Annotations[key] + return value, exists +} diff --git a/operator/internal/mnnvl/helpers_test.go b/operator/internal/mnnvl/helpers_test.go new file mode 100644 index 000000000..6150ee02c --- /dev/null +++ b/operator/internal/mnnvl/helpers_test.go @@ -0,0 +1,154 @@ +// /* +// Copyright 2025 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package mnnvl + +import ( + "testing" + + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/constants" + testutils "github.com/ai-dynamo/grove/operator/test/utils" + + "github.com/stretchr/testify/assert" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" +) + +func Test_hasGPURequirement(t *testing.T) { + tests := []struct { + name string + pcs *grovecorev1alpha1.PodCliqueSet + expected bool + }{ + { + name: "container with GPU limits", + pcs: createPCSWithGPU(nil), + expected: true, + }, + { + name: "container without GPU", + pcs: createPCSWithoutGPU(nil), + expected: false, + }, + { + name: "empty cliques", + pcs: &grovecorev1alpha1.PodCliqueSet{}, + expected: false, + }, + { + name: "GPU in init container", + pcs: testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithInitContainer(corev1.Container{ + Name: "init", + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + constants.GPUResourceName: resource.MustParse("1"), + }, + }, + }). + Build(), + ). + Build(), + expected: true, + }, + { + name: "GPU in requests not limits", + pcs: testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(corev1.Container{ + Name: "train", + Resources: corev1.ResourceRequirements{ + Requests: corev1.ResourceList{ + constants.GPUResourceName: resource.MustParse("2"), + }, + }, + }). + Build(), + ). + Build(), + expected: true, + }, + { + name: "GPU with zero quantity", + pcs: testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(corev1.Container{ + Name: "train", + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + constants.GPUResourceName: resource.MustParse("0"), + }, + }, + }). + Build(), + ). + Build(), + expected: false, + }, + { + name: "multiple cliques - one with GPU", + pcs: testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("controller"). + WithContainer(testutils.NewContainer("ctrl", "busybox")). + Build(), + ). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(testutils.NewGPUContainer("train", "nvidia/cuda:latest", 8)). + Build(), + ). + Build(), + expected: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result := hasGPURequirement(test.pcs) + assert.Equal(t, test.expected, result) + }) + } +} + +// createPCSWithGPU creates a PCS with GPU using the builder for tests in this package. +func createPCSWithGPU(annotations map[string]string) *grovecorev1alpha1.PodCliqueSet { + return testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithAnnotations(annotations). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(testutils.NewGPUContainer("train", "nvidia/cuda:latest", 8)). + Build(), + ). + Build() +} + +// createPCSWithoutGPU creates a PCS without GPU using the builder for tests in this package. +func createPCSWithoutGPU(annotations map[string]string) *grovecorev1alpha1.PodCliqueSet { + return testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithAnnotations(annotations). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(testutils.NewContainer("app", "nginx:latest")). + Build(), + ). + Build() +} diff --git a/operator/internal/mnnvl/webhook.go b/operator/internal/mnnvl/webhook.go new file mode 100644 index 000000000..a6244fcb7 --- /dev/null +++ b/operator/internal/mnnvl/webhook.go @@ -0,0 +1,99 @@ +// /* +// Copyright 2026 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package mnnvl + +import ( + "fmt" + + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" +) + +// MutateAutoMNNVL adds the grove.io/auto-mnnvl annotation to a PodCliqueSet +// if all conditions are met: +// 1. Annotation does not already exist +// 2. MNNVL feature is enabled globally (autoMNNVLEnabled) +// 3. PCS has at least one container requesting GPU +// +// Returns true if the annotation was added, false otherwise. +func MutateAutoMNNVL(pcs *grovecorev1alpha1.PodCliqueSet, autoMNNVLEnabled bool) bool { + // If feature is disabled, don't add annotation + if !autoMNNVLEnabled { + return false + } + + // If annotation already exists (user explicitly set it), don't override + if pcs.Annotations != nil { + if _, exists := pcs.Annotations[AnnotationAutoMNNVL]; exists { + return false + } + } + + // Check if PCS has GPU requirements + if !hasGPURequirement(pcs) { + return false + } + + // All conditions met - add the annotation + if pcs.Annotations == nil { + pcs.Annotations = make(map[string]string) + } + pcs.Annotations[AnnotationAutoMNNVL] = "true" + return true +} + +// ValidateAutoMNNVLOnCreate validates the MNNVL annotation on PCS creation. +// Returns an error if the annotation is set to "true" but the MNNVL feature is disabled. +// This prevents users from explicitly requesting MNNVL when the cluster doesn't support it. +func ValidateAutoMNNVLOnCreate(pcs *grovecorev1alpha1.PodCliqueSet, autoMNNVLEnabled bool) error { + value, exists := pcs.Annotations[AnnotationAutoMNNVL] + if !exists { + return nil + } + + // If annotation is "true" but feature is disabled, reject + if value == "true" && !autoMNNVLEnabled { + return fmt.Errorf("MNNVL is not enabled in the operator configuration. "+ + "Either enable MNNVL globally or remove the %s annotation", AnnotationAutoMNNVL) + } + + return nil +} + +// ValidateAutoMNNVLOnUpdate ensures the grove.io/auto-mnnvl annotation is immutable. +// Returns an error if the annotation was added, removed, or its value was changed. +func ValidateAutoMNNVLOnUpdate(oldPCS, newPCS *grovecorev1alpha1.PodCliqueSet) error { + oldValue, oldExists := getAnnotationValue(oldPCS, AnnotationAutoMNNVL) + newValue, newExists := getAnnotationValue(newPCS, AnnotationAutoMNNVL) + + // Check if annotation was added + if !oldExists && newExists { + return fmt.Errorf("annotation %s cannot be added after PodCliqueSet creation", AnnotationAutoMNNVL) + } + + // Check if annotation was removed + if oldExists && !newExists { + return fmt.Errorf("annotation %s cannot be removed after PodCliqueSet creation", AnnotationAutoMNNVL) + } + + // Check if annotation value was changed + if newExists && oldValue != newValue { + return fmt.Errorf("annotation %s is immutable and cannot be changed from %q to %q", + AnnotationAutoMNNVL, oldValue, newValue) + } + + return nil +} diff --git a/operator/internal/mnnvl/webhook_test.go b/operator/internal/mnnvl/webhook_test.go new file mode 100644 index 000000000..152116984 --- /dev/null +++ b/operator/internal/mnnvl/webhook_test.go @@ -0,0 +1,236 @@ +// /* +// Copyright 2026 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package mnnvl + +import ( + "testing" + + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + + "github.com/stretchr/testify/assert" +) + +func TestMutateAutoMNNVL(t *testing.T) { + tests := []struct { + description string + pcs *grovecorev1alpha1.PodCliqueSet + autoMNNVLEnabled bool + expectMutation bool + expectedAnnotation string + }{ + { + description: "feature enabled + GPU + no annotation -> add true", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: true, + expectMutation: true, + expectedAnnotation: "true", + }, + { + description: "feature enabled + GPU + existing false -> no change", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + autoMNNVLEnabled: true, + expectMutation: false, + expectedAnnotation: "false", + }, + { + description: "feature enabled + GPU + existing true -> no change", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: true, + expectMutation: false, + expectedAnnotation: "true", + }, + { + description: "feature enabled + no GPU -> no annotation", + pcs: createPCSWithoutGPU(nil), + autoMNNVLEnabled: true, + expectMutation: false, + expectedAnnotation: "", + }, + { + description: "feature disabled + GPU -> no annotation", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: false, + expectMutation: false, + expectedAnnotation: "", + }, + { + description: "feature disabled + no GPU -> no annotation", + pcs: createPCSWithoutGPU(nil), + autoMNNVLEnabled: false, + expectMutation: false, + expectedAnnotation: "", + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + mutated := MutateAutoMNNVL(test.pcs, test.autoMNNVLEnabled) + + assert.Equal(t, test.expectMutation, mutated) + + if test.expectedAnnotation == "" { + if test.pcs.Annotations != nil { + _, exists := test.pcs.Annotations[AnnotationAutoMNNVL] + assert.False(t, exists, "annotation should not exist") + } + } else { + assert.Equal(t, test.expectedAnnotation, test.pcs.Annotations[AnnotationAutoMNNVL]) + } + }) + } +} + +func TestValidateAutoMNNVLOnCreate(t *testing.T) { + tests := []struct { + description string + pcs *grovecorev1alpha1.PodCliqueSet + autoMNNVLEnabled bool + expectError bool + }{ + { + description: "annotation true + feature enabled -> no error", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: true, + expectError: false, + }, + { + description: "annotation true + feature disabled -> error", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: false, + expectError: true, + }, + { + description: "annotation false + feature disabled -> no error", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + autoMNNVLEnabled: false, + expectError: false, + }, + { + description: "annotation false + feature enabled -> no error", + pcs: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + autoMNNVLEnabled: true, + expectError: false, + }, + { + description: "no annotation + feature disabled -> no error", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: false, + expectError: false, + }, + { + description: "no annotation + feature enabled -> no error", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: true, + expectError: false, + }, + { + description: "nil annotations map -> no error", + pcs: &grovecorev1alpha1.PodCliqueSet{}, + autoMNNVLEnabled: false, + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + err := ValidateAutoMNNVLOnCreate(test.pcs, test.autoMNNVLEnabled) + + if test.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), "MNNVL is not enabled") + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestValidateAutoMNNVLOnUpdate(t *testing.T) { + tests := []struct { + description string + oldPCS *grovecorev1alpha1.PodCliqueSet + newPCS *grovecorev1alpha1.PodCliqueSet + expectError bool + errorMsg string + }{ + { + description: "no annotation on both -> no error", + oldPCS: createPCSWithGPU(nil), + newPCS: createPCSWithGPU(nil), + expectError: false, + }, + { + description: "annotation unchanged true -> no error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + expectError: false, + }, + { + description: "annotation unchanged false -> no error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + expectError: false, + }, + { + description: "annotation added -> error", + oldPCS: createPCSWithGPU(nil), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + expectError: true, + errorMsg: "cannot be added", + }, + { + description: "annotation removed -> error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + newPCS: createPCSWithGPU(nil), + expectError: true, + errorMsg: "cannot be removed", + }, + { + description: "annotation changed true to false -> error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + expectError: true, + errorMsg: "immutable", + }, + { + description: "annotation changed false to true -> error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "false"}), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true"}), + expectError: true, + errorMsg: "immutable", + }, + { + description: "other annotations changed but mnnvl unchanged -> no error", + oldPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true", "other": "old"}), + newPCS: createPCSWithGPU(map[string]string{AnnotationAutoMNNVL: "true", "other": "new"}), + expectError: false, + }, + } + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + err := ValidateAutoMNNVLOnUpdate(test.oldPCS, test.newPCS) + + if test.expectError { + assert.Error(t, err) + assert.Contains(t, err.Error(), test.errorMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/operator/internal/webhook/admission/pcs/defaulting/handler.go b/operator/internal/webhook/admission/pcs/defaulting/handler.go index bd092a160..8b38680d3 100644 --- a/operator/internal/webhook/admission/pcs/defaulting/handler.go +++ b/operator/internal/webhook/admission/pcs/defaulting/handler.go @@ -20,7 +20,9 @@ import ( "context" "fmt" + configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/mnnvl" k8sutils "github.com/ai-dynamo/grove/operator/internal/utils/kubernetes" "github.com/go-logr/logr" @@ -31,13 +33,15 @@ import ( // Handler sets default values on PodCliqueSet resources. type Handler struct { - logger logr.Logger + logger logr.Logger + networkConfig configv1alpha1.NetworkAcceleration } // NewHandler returns a new instance of defaulting webhook handler. -func NewHandler(mgr manager.Manager) *Handler { +func NewHandler(mgr manager.Manager, networkConfig configv1alpha1.NetworkAcceleration) *Handler { return &Handler{ - logger: mgr.GetLogger().WithName("webhook").WithName(Name), + logger: mgr.GetLogger().WithName("webhook").WithName(Name), + networkConfig: networkConfig, } } @@ -54,5 +58,11 @@ func (h *Handler) Default(ctx context.Context, obj runtime.Object) error { } h.logger.Info("Applying defaults", "PodCliqueSet", k8sutils.CreateObjectKeyForCreateWebhooks(pcs, req)) defaultPodCliqueSet(pcs) + + // Apply MNNVL auto-annotation if conditions are met + if mnnvl.MutateAutoMNNVL(pcs, h.networkConfig.AutoMNNVLEnabled) { + h.logger.Info("Added auto-mnnvl annotation", "PodCliqueSet", k8sutils.CreateObjectKeyForCreateWebhooks(pcs, req)) + } + return nil } diff --git a/operator/internal/webhook/admission/pcs/defaulting/handler_mnnvl_test.go b/operator/internal/webhook/admission/pcs/defaulting/handler_mnnvl_test.go new file mode 100644 index 000000000..6ced1881a --- /dev/null +++ b/operator/internal/webhook/admission/pcs/defaulting/handler_mnnvl_test.go @@ -0,0 +1,139 @@ +// /* +// Copyright 2026 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package defaulting + +import ( + "context" + "testing" + + configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/mnnvl" + testutils "github.com/ai-dynamo/grove/operator/test/utils" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + admissionv1 "k8s.io/api/admission/v1" + authenticationv1 "k8s.io/api/authentication/v1" + "sigs.k8s.io/controller-runtime/pkg/webhook/admission" +) + +// TestDefault_MNNVL tests the MNNVL auto-annotation behavior in the defaulting webhook. +func TestDefault_MNNVL(t *testing.T) { + tests := []struct { + description string + pcs *grovecorev1alpha1.PodCliqueSet + autoMNNVLEnabled bool + expectedAnnotation string // empty string means annotation should not exist + }{ + { + description: "feature enabled + GPU + no annotation -> adds annotation", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: true, + expectedAnnotation: "true", + }, + { + description: "feature disabled + GPU -> no annotation added", + pcs: createPCSWithGPU(nil), + autoMNNVLEnabled: false, + expectedAnnotation: "", + }, + { + description: "feature enabled + no GPU -> no annotation added", + pcs: createPCSWithoutGPU(nil), + autoMNNVLEnabled: true, + expectedAnnotation: "", + }, + { + description: "feature enabled + GPU + existing false annotation -> unchanged", + pcs: createPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "false"}), + autoMNNVLEnabled: true, + expectedAnnotation: "false", + }, + { + description: "feature enabled + GPU + existing true annotation -> unchanged", + pcs: createPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: true, + expectedAnnotation: "true", + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + cl := testutils.NewTestClientBuilder().Build() + mgr := &testutils.FakeManager{ + Client: cl, + Scheme: cl.Scheme(), + Logger: logr.Discard(), + } + + networkConfig := configv1alpha1.NetworkAcceleration{ + AutoMNNVLEnabled: tt.autoMNNVLEnabled, + } + handler := NewHandler(mgr, networkConfig) + + ctx := admission.NewContextWithRequest(context.Background(), admission.Request{ + AdmissionRequest: admissionv1.AdmissionRequest{ + Name: "test-pcs", + Namespace: "default", + Operation: admissionv1.Create, + UserInfo: authenticationv1.UserInfo{ + Username: "test-user", + }, + }, + }) + + err := handler.Default(ctx, tt.pcs) + require.NoError(t, err) + + if tt.expectedAnnotation == "" { + if tt.pcs.Annotations != nil { + _, exists := tt.pcs.Annotations[mnnvl.AnnotationAutoMNNVL] + assert.False(t, exists, "annotation should not exist") + } + } else { + require.NotNil(t, tt.pcs.Annotations) + assert.Equal(t, tt.expectedAnnotation, tt.pcs.Annotations[mnnvl.AnnotationAutoMNNVL]) + } + }) + } +} + +// createPCSWithGPU creates a PCS with GPU using the builder. +func createPCSWithGPU(annotations map[string]string) *grovecorev1alpha1.PodCliqueSet { + return testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithAnnotations(annotations). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(testutils.NewGPUContainer("train", "nvidia/cuda:latest", 8)). + Build(), + ). + Build() +} + +// createPCSWithoutGPU creates a PCS without GPU using the builder. +func createPCSWithoutGPU(annotations map[string]string) *grovecorev1alpha1.PodCliqueSet { + return testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithAnnotations(annotations). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithContainer(testutils.NewContainer("app", "nginx:latest")). + Build(), + ). + Build() +} diff --git a/operator/internal/webhook/admission/pcs/defaulting/handler_test.go b/operator/internal/webhook/admission/pcs/defaulting/handler_test.go index b5b7a8ab3..c6ca5636e 100644 --- a/operator/internal/webhook/admission/pcs/defaulting/handler_test.go +++ b/operator/internal/webhook/admission/pcs/defaulting/handler_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" testutils "github.com/ai-dynamo/grove/operator/test/utils" @@ -46,7 +47,7 @@ func TestNewHandler(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr) + handler := NewHandler(mgr, configv1alpha1.NetworkAcceleration{}) require.NotNil(t, handler) assert.NotNil(t, handler.logger) } @@ -197,7 +198,7 @@ func TestDefault(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr) + handler := NewHandler(mgr, configv1alpha1.NetworkAcceleration{}) ctx := context.Background() if tt.setupContext != nil { diff --git a/operator/internal/webhook/admission/pcs/defaulting/register_test.go b/operator/internal/webhook/admission/pcs/defaulting/register_test.go index eee9588c5..aec5efcab 100644 --- a/operator/internal/webhook/admission/pcs/defaulting/register_test.go +++ b/operator/internal/webhook/admission/pcs/defaulting/register_test.go @@ -19,6 +19,7 @@ package defaulting import ( "testing" + configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" testutils "github.com/ai-dynamo/grove/operator/test/utils" "github.com/go-logr/logr" @@ -41,7 +42,7 @@ func TestRegisterWithManager(t *testing.T) { }) mgr.WebhookServer = server - handler := NewHandler(mgr) + handler := NewHandler(mgr, configv1alpha1.NetworkAcceleration{}) err := handler.RegisterWithManager(mgr) require.NoError(t, err) } diff --git a/operator/internal/webhook/admission/pcs/validation/handler.go b/operator/internal/webhook/admission/pcs/validation/handler.go index 91317a186..7f290c117 100644 --- a/operator/internal/webhook/admission/pcs/validation/handler.go +++ b/operator/internal/webhook/admission/pcs/validation/handler.go @@ -23,6 +23,7 @@ import ( configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" "github.com/ai-dynamo/grove/operator/internal/errors" + "github.com/ai-dynamo/grove/operator/internal/mnnvl" "github.com/go-logr/logr" admissionv1 "k8s.io/api/admission/v1" @@ -41,15 +42,17 @@ const ( // Handler is a handler for validating PodCliqueSet resources. type Handler struct { - logger logr.Logger - tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration + logger logr.Logger + tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration + networkConfig configv1alpha1.NetworkAcceleration } // NewHandler creates a new handler for PodCliqueSet Webhook. -func NewHandler(mgr manager.Manager, tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration) *Handler { +func NewHandler(mgr manager.Manager, tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration, networkConfig configv1alpha1.NetworkAcceleration) *Handler { return &Handler{ - logger: mgr.GetLogger().WithName("webhook").WithName(Name), - tasConfig: tasConfig, + logger: mgr.GetLogger().WithName("webhook").WithName(Name), + tasConfig: tasConfig, + networkConfig: networkConfig, } } @@ -66,6 +69,15 @@ func (h *Handler) ValidateCreate(ctx context.Context, obj runtime.Object) (admis allErrs = append(allErrs, v.validateTopologyConstraintsOnCreate()...) warnings, errs := v.validate() allErrs = append(allErrs, errs...) + + // Validate MNNVL annotation: reject if annotation="true" but feature is disabled + if err := mnnvl.ValidateAutoMNNVLOnCreate(pcs, h.networkConfig.AutoMNNVLEnabled); err != nil { + allErrs = append(allErrs, field.Invalid( + field.NewPath("metadata", "annotations", mnnvl.AnnotationAutoMNNVL), + pcs.Annotations[mnnvl.AnnotationAutoMNNVL], + err.Error())) + } + return warnings, allErrs.ToAggregate() } @@ -80,6 +92,12 @@ func (h *Handler) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Obj if err != nil { return nil, errors.WrapError(err, ErrValidateUpdatePodCliqueSet, string(admissionv1.Update), "failed to cast old object to PodCliqueSet") } + + // Validate MNNVL annotation immutability + if err := mnnvl.ValidateAutoMNNVLOnUpdate(oldPCS, newPCS); err != nil { + return nil, errors.WrapError(err, ErrValidateUpdatePodCliqueSet, string(admissionv1.Update), err.Error()) + } + v := newPCSValidator(newPCS, admissionv1.Update, h.tasConfig) warnings, errs := v.validate() if len(errs) > 0 { diff --git a/operator/internal/webhook/admission/pcs/validation/handler_mnnvl_test.go b/operator/internal/webhook/admission/pcs/validation/handler_mnnvl_test.go new file mode 100644 index 000000000..f99d6e8c0 --- /dev/null +++ b/operator/internal/webhook/admission/pcs/validation/handler_mnnvl_test.go @@ -0,0 +1,194 @@ +// /* +// Copyright 2026 The Grove Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// */ + +package validation + +import ( + "context" + "testing" + "time" + + configv1alpha1 "github.com/ai-dynamo/grove/operator/api/config/v1alpha1" + grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/mnnvl" + testutils "github.com/ai-dynamo/grove/operator/test/utils" + + "github.com/go-logr/logr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "k8s.io/utils/ptr" +) + +// TestValidateCreate_MNNVL tests the MNNVL annotation validation on create. +func TestValidateCreate_MNNVL(t *testing.T) { + tests := []struct { + description string + pcs *grovecorev1alpha1.PodCliqueSet + autoMNNVLEnabled bool + expectError bool + errorContains string + }{ + { + description: "annotation true + feature enabled -> no error", + pcs: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: true, + expectError: false, + }, + { + description: "annotation true + feature disabled -> error", + pcs: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + autoMNNVLEnabled: false, + expectError: true, + errorContains: "MNNVL is not enabled", + }, + { + description: "annotation false + feature disabled -> no error", + pcs: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "false"}), + autoMNNVLEnabled: false, + expectError: false, + }, + { + description: "no annotation + feature disabled -> no error", + pcs: createValidPCSWithGPU(nil), + autoMNNVLEnabled: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + cl := testutils.NewTestClientBuilder().Build() + mgr := &testutils.FakeManager{ + Client: cl, + Scheme: cl.Scheme(), + Logger: logr.Discard(), + } + + networkConfig := configv1alpha1.NetworkAcceleration{ + AutoMNNVLEnabled: tt.autoMNNVLEnabled, + } + handler := NewHandler(mgr, getDefaultTASConfig(), networkConfig) + + ctx := context.Background() + warnings, err := handler.ValidateCreate(ctx, tt.pcs) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + assert.Empty(t, warnings) + } + }) + } +} + +// TestValidateUpdate_MNNVL tests the MNNVL annotation immutability on update. +func TestValidateUpdate_MNNVL(t *testing.T) { + tests := []struct { + description string + oldPCS *grovecorev1alpha1.PodCliqueSet + newPCS *grovecorev1alpha1.PodCliqueSet + expectError bool + errorContains string + }{ + { + description: "annotation unchanged -> no error", + oldPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + newPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + expectError: false, + }, + { + description: "annotation added -> error", + oldPCS: createValidPCSWithGPU(nil), + newPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + expectError: true, + errorContains: "cannot be added", + }, + { + description: "annotation removed -> error", + oldPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + newPCS: createValidPCSWithGPU(nil), + expectError: true, + errorContains: "cannot be removed", + }, + { + description: "annotation changed true to false -> error", + oldPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + newPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "false"}), + expectError: true, + errorContains: "immutable", + }, + { + description: "annotation changed false to true -> error", + oldPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "false"}), + newPCS: createValidPCSWithGPU(map[string]string{mnnvl.AnnotationAutoMNNVL: "true"}), + expectError: true, + errorContains: "immutable", + }, + { + description: "no annotation on both -> no error", + oldPCS: createValidPCSWithGPU(nil), + newPCS: createValidPCSWithGPU(nil), + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + cl := testutils.NewTestClientBuilder().Build() + mgr := &testutils.FakeManager{ + Client: cl, + Scheme: cl.Scheme(), + Logger: logr.Discard(), + } + + // MNNVL validation on update doesn't depend on feature flag + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) + + ctx := context.Background() + warnings, err := handler.ValidateUpdate(ctx, tt.oldPCS, tt.newPCS) + + if tt.expectError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + assert.Empty(t, warnings) + } + }) + } +} + +// createValidPCSWithGPU creates a fully valid PCS with GPU for validation tests. +func createValidPCSWithGPU(annotations map[string]string) *grovecorev1alpha1.PodCliqueSet { + return testutils.NewPodCliqueSetBuilder("test-pcs", "default", ""). + WithAnnotations(annotations). + WithCliqueStartupType(ptr.To(grovecorev1alpha1.CliqueStartupTypeAnyOrder)). + WithTerminationDelay(4 * time.Hour). + WithPodCliqueTemplateSpec( + testutils.NewPodCliqueTemplateSpecBuilder("worker"). + WithRoleName("worker"). + WithMinAvailable(1). + WithContainer(testutils.NewGPUContainer("train", "nvidia/cuda:latest", 8)). + Build(), + ). + Build() +} diff --git a/operator/internal/webhook/admission/pcs/validation/handler_test.go b/operator/internal/webhook/admission/pcs/validation/handler_test.go index 49866c273..ccf15f11b 100644 --- a/operator/internal/webhook/admission/pcs/validation/handler_test.go +++ b/operator/internal/webhook/admission/pcs/validation/handler_test.go @@ -47,7 +47,7 @@ func TestNewHandler(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr, getDefaultTASConfig()) + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) require.NotNil(t, handler) assert.NotNil(t, handler.logger) } @@ -113,7 +113,7 @@ func TestValidateCreate(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr, getDefaultTASConfig()) + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) ctx := context.Background() warnings, err := handler.ValidateCreate(ctx, tt.obj) @@ -244,7 +244,7 @@ func TestValidateUpdate(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr, getDefaultTASConfig()) + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) ctx := context.Background() warnings, err := handler.ValidateUpdate(ctx, tt.newObj, tt.oldObj) @@ -271,7 +271,7 @@ func TestValidateDelete(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr, getDefaultTASConfig()) + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) // Deletion validation always succeeds ctx := context.Background() @@ -382,7 +382,7 @@ func TestLogValidatorFunctionInvocation(t *testing.T) { Logger: logr.Discard(), } - handler := NewHandler(mgr, getDefaultTASConfig()) + handler := NewHandler(mgr, getDefaultTASConfig(), getDefaultNetworkConfig()) // This function doesn't return an error, but we can verify it doesn't panic assert.NotPanics(t, func() { @@ -400,3 +400,10 @@ func getDefaultTASConfig() groveconfigv1alpha1.TopologyAwareSchedulingConfigurat Enabled: false, } } + +// getDefaultNetworkConfig returns a default network configuration with MNNVL disabled. +func getDefaultNetworkConfig() groveconfigv1alpha1.NetworkAcceleration { + return groveconfigv1alpha1.NetworkAcceleration{ + AutoMNNVLEnabled: false, + } +} diff --git a/operator/internal/webhook/admission/pcs/validation/register_test.go b/operator/internal/webhook/admission/pcs/validation/register_test.go index 12f3be59f..7926b8874 100644 --- a/operator/internal/webhook/admission/pcs/validation/register_test.go +++ b/operator/internal/webhook/admission/pcs/validation/register_test.go @@ -42,7 +42,7 @@ func TestRegisterWithManager(t *testing.T) { }) mgr.WebhookServer = server - handler := NewHandler(mgr, configv1alpha1.TopologyAwareSchedulingConfiguration{}) + handler := NewHandler(mgr, configv1alpha1.TopologyAwareSchedulingConfiguration{}, configv1alpha1.NetworkAcceleration{}) err := handler.RegisterWithManager(mgr) require.NoError(t, err) } diff --git a/operator/internal/webhook/register.go b/operator/internal/webhook/register.go index 0ba84ff77..9cf378160 100644 --- a/operator/internal/webhook/register.go +++ b/operator/internal/webhook/register.go @@ -31,13 +31,13 @@ import ( ) // Register registers the webhooks with the controller manager. -func Register(mgr manager.Manager, authorizerConfig configv1alpha1.AuthorizerConfig, tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration) error { - defaultingWebhook := defaulting.NewHandler(mgr) +func Register(mgr manager.Manager, authorizerConfig configv1alpha1.AuthorizerConfig, tasConfig configv1alpha1.TopologyAwareSchedulingConfiguration, networkConfig configv1alpha1.NetworkAcceleration) error { + defaultingWebhook := defaulting.NewHandler(mgr, networkConfig) slog.Info("Registering webhook with manager", "handler", defaulting.Name) if err := defaultingWebhook.RegisterWithManager(mgr); err != nil { return fmt.Errorf("failed adding %s webhook handler: %v", defaulting.Name, err) } - pcsValidatingWebhook := pcsvalidation.NewHandler(mgr, tasConfig) + pcsValidatingWebhook := pcsvalidation.NewHandler(mgr, tasConfig, networkConfig) slog.Info("Registering webhook with manager", "handler", pcsvalidation.Name) if err := pcsValidatingWebhook.RegisterWithManager(mgr); err != nil { return fmt.Errorf("failed adding %s webhook handler: %v", pcsvalidation.Name, err) diff --git a/operator/internal/webhook/register_test.go b/operator/internal/webhook/register_test.go index 449aa31e4..9c560e507 100644 --- a/operator/internal/webhook/register_test.go +++ b/operator/internal/webhook/register_test.go @@ -91,7 +91,7 @@ func TestRegisterWebhooks_WithoutAuthorizer(t *testing.T) { Enabled: false, } - err := Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}) + err := Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}, configv1alpha1.NetworkAcceleration{}) require.NoError(t, err) } @@ -120,7 +120,7 @@ func TestRegisterWebhooks_WithAuthorizerMissingEnvVar(t *testing.T) { Enabled: true, } - err = Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}) + err = Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}, configv1alpha1.NetworkAcceleration{}) require.Error(t, err) assert.Contains(t, err.Error(), constants.EnvVarServiceAccountName) } @@ -149,7 +149,7 @@ func TestRegisterWebhooks_WithAuthorizerMissingNamespaceFile(t *testing.T) { Enabled: true, } - err := Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}) + err := Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}, configv1alpha1.NetworkAcceleration{}) require.Error(t, err) assert.Contains(t, err.Error(), "error reading namespace file") } @@ -194,7 +194,7 @@ func TestRegisterWebhooks_WithAuthorizerSuccess(t *testing.T) { Enabled: true, } - err = Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}) + err = Register(mgr, authorizerConfig, configv1alpha1.TopologyAwareSchedulingConfiguration{}, configv1alpha1.NetworkAcceleration{}) // Will error because it tries to read the hardcoded namespace file path require.Error(t, err) } diff --git a/operator/test/utils/pclqTemplate.go b/operator/test/utils/pclqTemplate.go index a5b76d5be..fd22848b5 100644 --- a/operator/test/utils/pclqTemplate.go +++ b/operator/test/utils/pclqTemplate.go @@ -18,8 +18,10 @@ package utils import ( grovecorev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" + "github.com/ai-dynamo/grove/operator/internal/constants" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" ) // PodCliqueTemplateSpecBuilder is a builder for creating PodCliqueTemplateSpec objects. @@ -48,7 +50,10 @@ func NewBasicPodCliqueTemplateSpec(name string) *grovecorev1alpha1.PodCliqueTemp // Build creates a PodCliqueTemplateSpec object. func (b *PodCliqueTemplateSpecBuilder) Build() *grovecorev1alpha1.PodCliqueTemplateSpec { - b.withDefaultPodSpec() + // Only apply default PodSpec if no containers were configured + if len(b.pclqTemplateSpec.Spec.PodSpec.Containers) == 0 && len(b.pclqTemplateSpec.Spec.PodSpec.InitContainers) == 0 { + b.withDefaultPodSpec() + } return b.pclqTemplateSpec } @@ -121,6 +126,39 @@ func (b *PodCliqueTemplateSpecBuilder) WithTopologyConstraint(constraint *grovec return b } +// WithContainer adds a container to the PodSpec. +func (b *PodCliqueTemplateSpecBuilder) WithContainer(container corev1.Container) *PodCliqueTemplateSpecBuilder { + b.pclqTemplateSpec.Spec.PodSpec.Containers = append(b.pclqTemplateSpec.Spec.PodSpec.Containers, container) + return b +} + +// WithInitContainer adds an init container to the PodSpec. +func (b *PodCliqueTemplateSpecBuilder) WithInitContainer(container corev1.Container) *PodCliqueTemplateSpecBuilder { + b.pclqTemplateSpec.Spec.PodSpec.InitContainers = append(b.pclqTemplateSpec.Spec.PodSpec.InitContainers, container) + return b +} + +// NewGPUContainer creates a container with GPU resources. +func NewGPUContainer(name, image string, gpuCount int64) corev1.Container { + return corev1.Container{ + Name: name, + Image: image, + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + constants.GPUResourceName: *resource.NewQuantity(gpuCount, resource.DecimalSI), + }, + }, + } +} + +// NewContainer creates a simple container without GPU resources. +func NewContainer(name, image string) corev1.Container { + return corev1.Container{ + Name: name, + Image: image, + } +} + func (b *PodCliqueTemplateSpecBuilder) withDefaultPodSpec() *PodCliqueTemplateSpecBuilder { b.pclqTemplateSpec.Spec.PodSpec = NewPodWithBuilderWithDefaultSpec("test-name", "test-ns").Build().Spec return b diff --git a/operator/test/utils/pcs.go b/operator/test/utils/pcs.go index 70495fe36..5a87b1e68 100644 --- a/operator/test/utils/pcs.go +++ b/operator/test/utils/pcs.go @@ -148,6 +148,12 @@ func (b *PodCliqueSetBuilder) WithTopologyConstraint(constraint *grovecorev1alph return b } +// WithAnnotations sets the annotations for the PodCliqueSet. +func (b *PodCliqueSetBuilder) WithAnnotations(annotations map[string]string) *PodCliqueSetBuilder { + b.pcs.Annotations = annotations + return b +} + // Build creates a PodCliqueSet object. func (b *PodCliqueSetBuilder) Build() *grovecorev1alpha1.PodCliqueSet { return b.pcs