Skip to content

Commit 6b23c22

Browse files
authored
test: add TAS e2e test infrastructure and basic tests (#348)
* test: add TAS e2e test infrastructure and basic tests - Add 4-level topology hierarchy setup (zone/block/rack/host) - Add KAI Topology verification utilities - Add topology constraint verification helpers - Include 2 foundational tests: * Topology infrastructure verification * Multiple cliques with different constraints - Update dependencies to KAI Scheduler v0.13.0-rc1 - Add Makefile target for selective test execution - Add topology-test skaffold profile Signed-off-by: Ron Kahn <rkahn@nvidia.com>
1 parent a5b8e0d commit 6b23c22

File tree

14 files changed

+893
-26
lines changed

14 files changed

+893
-26
lines changed

.github/workflows/e2e-test.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ jobs:
4545
test_pattern: "^Test_RU"
4646
- test_name: startup_ordering
4747
test_pattern: "^Test_SO"
48+
- test_name: Topology_Aware_Scheduling
49+
test_pattern: "^Test_TAS"
4850
name: E2E - ${{ matrix.test_name }}
4951
steps:
5052
# print runner specs so we have a record incase of failures

operator/Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,18 @@ cover-html: test-cover
8787
@echo "Coverage report generated at coverage.html"
8888

8989
# Run e2e tests
90+
# Usage: make test-e2e [TEST_PATTERN=<pattern>]
91+
# Examples:
92+
# make test-e2e # Run all tests
93+
# make test-e2e TEST_PATTERN=Test_GS # Run all gang scheduling tests
94+
# make test-e2e TEST_PATTERN=Test_GS1 # Run specific test
95+
# make test-e2e TEST_PATTERN=Test_TAS # Run all topology tests
9096
.PHONY: test-e2e
9197
test-e2e:
9298
@echo "> Preparing charts (copying CRDs)..."
9399
@$(MODULE_HACK_DIR)/prepare-charts.sh
94100
@echo "> Running e2e tests..."
95-
@cd e2e && go test -count=1 -tags=e2e ./tests/... -v -timeout 45m
101+
@cd e2e && go test -count=1 -tags=e2e ./tests/... -v -timeout 45m $(if $(TEST_PATTERN),-run $(TEST_PATTERN))
96102

97103
# Make targets for local development and testing
98104
# -------------------------------------------------------------

operator/e2e/dependencies.yaml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,25 @@ images:
1818

1919
# Kai Scheduler components
2020
- name: ghcr.io/nvidia/kai-scheduler/admission
21-
version: v0.12.0
21+
version: v0.13.0-rc1
2222
- name: ghcr.io/nvidia/kai-scheduler/binder
23-
version: v0.12.0
23+
version: v0.13.0-rc1
2424
- name: ghcr.io/nvidia/kai-scheduler/operator
25-
version: v0.12.0
25+
version: v0.13.0-rc1
2626
- name: ghcr.io/nvidia/kai-scheduler/podgroupcontroller
27-
version: v0.12.0
27+
version: v0.13.0-rc1
2828
- name: ghcr.io/nvidia/kai-scheduler/podgrouper
29-
version: v0.12.0
29+
version: v0.13.0-rc1
3030
- name: ghcr.io/nvidia/kai-scheduler/queuecontroller
31-
version: v0.12.0
31+
version: v0.13.0-rc1
3232
- name: ghcr.io/nvidia/kai-scheduler/scheduler
33-
version: v0.12.0
33+
version: v0.13.0-rc1
3434

3535
# Helm charts used in E2E tests
3636
helmCharts:
3737
# Kai Scheduler - gang scheduling for Kubernetes
3838
kaiScheduler:
3939
releaseName: kai-scheduler
4040
chartRef: oci://ghcr.io/nvidia/kai-scheduler/kai-scheduler
41-
version: v0.12.0
41+
version: v0.13.0-rc1
4242
namespace: kai-scheduler

operator/e2e/setup/k8s_clusters.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ func InstallCoreComponents(ctx context.Context, restConfig *rest.Config, kaiConf
536536
skaffoldConfig := &SkaffoldInstallConfig{
537537
SkaffoldYAMLPath: absoluteSkaffoldYAMLPath,
538538
RestConfig: restConfig,
539-
Profiles: []string{"debug"},
539+
Profiles: []string{"topology-test"},
540540
PushRepo: fmt.Sprintf("localhost:%s", registryPort),
541541
PullRepo: fmt.Sprintf("registry:%s", registryPort),
542542
Namespace: OperatorNamespace,
@@ -570,6 +570,11 @@ func InstallCoreComponents(ctx context.Context, restConfig *rest.Config, kaiConf
570570
return err // Return the first error encountered
571571
}
572572

573+
// Apply hierarchical topology labels to worker nodes
574+
if err := applyTopologyLabels(ctx, restConfig, logger); err != nil {
575+
return fmt.Errorf("failed to apply topology labels: %w", err)
576+
}
577+
573578
logger.Debug("✅ All component installations completed successfully")
574579
return nil
575580
}

operator/e2e/setup/skaffold.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"time"
2929

3030
"github.com/ai-dynamo/grove/operator/e2e/utils"
31+
"github.com/samber/lo"
3132
"k8s.io/client-go/rest"
3233
"k8s.io/client-go/tools/clientcmd"
3334
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
@@ -153,7 +154,11 @@ func runSkaffoldBuild(ctx context.Context, absSkaffoldPath, skaffoldDir, kubecon
153154
cmd.Dir = skaffoldDir
154155

155156
// Set up environment variables
156-
cmd.Env = os.Environ()
157+
// To allow running the tests from the IDE
158+
cmd.Env = filterEnv(os.Environ(), "GOOS", "GOARCH")
159+
config.Logger.Debugf("Filtered environment variables (removed GOOS, GOARCH), kept %d vars", len(cmd.Env))
160+
cmd.Env = append(cmd.Env, "CGO_ENABLED=0")
161+
157162
cmd.Env = append(cmd.Env, fmt.Sprintf("KUBECONFIG=%s", kubeconfigPath))
158163

159164
// Add build-specific environment variables
@@ -315,3 +320,14 @@ func writeTemporaryKubeconfig(restConfig *rest.Config, logger *utils.Logger) (st
315320
logger.Debugf("📄 Wrote temporary kubeconfig to: %s", tmpPath)
316321
return tmpPath, cleanup, nil
317322
}
323+
324+
// filterEnv filters out specified environment variables from the environment
325+
func filterEnv(env []string, keysToRemove ...string) []string {
326+
filtered := lo.Filter(env, func(e string, _ int) bool {
327+
_, found := lo.Find(keysToRemove, func(key string) bool {
328+
return strings.HasPrefix(e, key+"=")
329+
})
330+
return !found
331+
})
332+
return filtered
333+
}

operator/e2e/setup/topology.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// /*
2+
// Copyright 2025 The Grove Authors.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
// */
16+
17+
package setup
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"sort"
23+
24+
"github.com/ai-dynamo/grove/operator/e2e/utils"
25+
v1 "k8s.io/api/core/v1"
26+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
27+
k8stypes "k8s.io/apimachinery/pkg/types"
28+
"k8s.io/client-go/kubernetes"
29+
"k8s.io/client-go/rest"
30+
)
31+
32+
const (
33+
// WorkerNodeLabelKey is the label key used to identify worker nodes in e2e tests.
34+
// This can be changed if infrastructure changes.
35+
WorkerNodeLabelKey = "node_role.e2e.grove.nvidia.com"
36+
// WorkerNodeLabelValue is the label value for worker node identification in e2e tests.
37+
WorkerNodeLabelValue = "agent"
38+
39+
// TopologyLabelZone is the Kubernetes label key for zone topology domain.
40+
TopologyLabelZone = "kubernetes.io/zone"
41+
// TopologyLabelBlock is the Kubernetes label key for the block topology domain.
42+
TopologyLabelBlock = "kubernetes.io/block"
43+
// TopologyLabelRack is the Kubernetes label key for the rack topology domain.
44+
TopologyLabelRack = "kubernetes.io/rack"
45+
// TopologyLabelHostname is the Kubernetes label key for the hostname topology domain.
46+
TopologyLabelHostname = "kubernetes.io/hostname"
47+
48+
// NodesPerZone is the number of nodes per zone.
49+
NodesPerZone = 28
50+
// NodesPerBlock is the number of nodes per block (28 / 2 blocks).
51+
NodesPerBlock = 14
52+
// NodesPerRack is the number of nodes per rack (28 / 4 racks).
53+
NodesPerRack = 7
54+
)
55+
56+
// GetZoneForNodeIndex returns the zone label for a given node index.
57+
// Both the index parameter and the returned zone number are 0-based.
58+
// e.g., nodes 0-27 → zone-0, nodes 28-55 → zone-1, etc.
59+
func GetZoneForNodeIndex(idx int) string {
60+
zoneNum := idx / NodesPerZone
61+
return fmt.Sprintf("zone-%d", zoneNum)
62+
}
63+
64+
// GetBlockForNodeIndex returns the block label for a given node index.
65+
// Both the index parameter and the returned block number are 0-based.
66+
// e.g., nodes 0-13 → block-0, nodes 14-27 → block-1
67+
func GetBlockForNodeIndex(idx int) string {
68+
blockNum := idx / NodesPerBlock
69+
return fmt.Sprintf("block-%d", blockNum)
70+
}
71+
72+
// GetRackForNodeIndex returns the rack label for a given node index.
73+
// Both the index parameter and the returned rack number are 0-based.
74+
// e.g., nodes 0-6 → rack-0, nodes 7-13 → rack-1, etc.
75+
func GetRackForNodeIndex(idx int) string {
76+
rackNum := idx / NodesPerRack
77+
return fmt.Sprintf("rack-%d", rackNum)
78+
}
79+
80+
// GetWorkerNodeLabelSelector returns the label selector for worker nodes in e2e tests.
81+
// Returns a formatted string "key=value" for use with Kubernetes label selectors.
82+
func GetWorkerNodeLabelSelector() string {
83+
return fmt.Sprintf("%s=%s", WorkerNodeLabelKey, WorkerNodeLabelValue)
84+
}
85+
86+
// applyTopologyLabels applies hierarchical topology labels to worker nodes in the k3d cluster.
87+
// Creates a 4-level topology hierarchy: zone -> block -> rack -> host (kubernetes.io/hostname already exists)
88+
// Distribution strategy for 28 worker nodes:
89+
// - Zone: all nodes in "zone-0"
90+
// - Block: nodes 0-13 in "block-0", nodes 14-27 in "block-1"
91+
// - Rack: 4 racks total (2 per block), 7 hosts per rack
92+
func applyTopologyLabels(ctx context.Context, restConfig *rest.Config, logger *utils.Logger) error {
93+
logger.Info("🏷️ Applying hierarchical topology labels to worker nodes...")
94+
95+
// Create clientset
96+
clientset, err := kubernetes.NewForConfig(restConfig)
97+
if err != nil {
98+
return fmt.Errorf("failed to create clientset: %w", err)
99+
}
100+
101+
// Get all worker nodes (filter by label set during cluster creation)
102+
workerLabelSelector := GetWorkerNodeLabelSelector()
103+
nodes, err := clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{
104+
LabelSelector: workerLabelSelector,
105+
})
106+
if err != nil {
107+
return fmt.Errorf("failed to list worker nodes: %w", err)
108+
}
109+
110+
if len(nodes.Items) == 0 {
111+
logger.Warn("⚠️ No worker nodes found for topology labeling")
112+
return nil
113+
}
114+
115+
sortedNodes := make([]v1.Node, len(nodes.Items))
116+
copy(sortedNodes, nodes.Items)
117+
sort.Slice(sortedNodes, func(i, j int) bool { return sortedNodes[i].Name < sortedNodes[j].Name })
118+
119+
for idx, node := range sortedNodes {
120+
topologyLabels := fmt.Sprintf(`{"metadata":{"labels":{"%s":"%s","%s":"%s","%s":"%s"}}}`,
121+
TopologyLabelZone, GetZoneForNodeIndex(idx), TopologyLabelBlock, GetBlockForNodeIndex(idx), TopologyLabelRack, GetRackForNodeIndex(idx))
122+
123+
_, err := clientset.CoreV1().Nodes().Patch(
124+
ctx,
125+
node.Name,
126+
k8stypes.StrategicMergePatchType,
127+
[]byte(topologyLabels),
128+
metav1.PatchOptions{},
129+
)
130+
if err != nil {
131+
return fmt.Errorf("failed to patch node %s with topology labels: %w", node.Name, err)
132+
}
133+
}
134+
logger.Infof("✅ Applied topology labels to %d worker nodes", len(sortedNodes))
135+
return nil
136+
}

operator/e2e/tests/rolling_update_utils.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"time"
2929

3030
grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
31+
"github.com/ai-dynamo/grove/operator/e2e/utils"
3132
corev1 "k8s.io/api/core/v1"
3233
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3334
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -85,7 +86,7 @@ func triggerPodCliqueRollingUpdate(tc TestContext, cliqueName string) error {
8586

8687
// Convert unstructured to typed PodCliqueSet
8788
var pcs grovev1alpha1.PodCliqueSet
88-
err = convertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
89+
err = utils.ConvertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
8990
if err != nil {
9091
return fmt.Errorf("failed to convert to PodCliqueSet: %w", err)
9192
}
@@ -155,7 +156,7 @@ func patchPCSWithSIGTERMIgnoringCommand(tc TestContext) error {
155156
}
156157

157158
var pcs grovev1alpha1.PodCliqueSet
158-
err = convertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
159+
err = utils.ConvertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
159160
if err != nil {
160161
return fmt.Errorf("failed to convert to PodCliqueSet: %w", err)
161162
}
@@ -210,7 +211,7 @@ func waitForRollingUpdateComplete(tc TestContext, expectedReplicas int32) error
210211
}
211212

212213
var pcs grovev1alpha1.PodCliqueSet
213-
err = convertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
214+
err = utils.ConvertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
214215
if err != nil {
215216
return false, err
216217
}
@@ -256,7 +257,7 @@ func waitForOrdinalUpdating(tc TestContext, ordinal int32) error {
256257
}
257258

258259
var pcs grovev1alpha1.PodCliqueSet
259-
err = convertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
260+
err = utils.ConvertUnstructuredToTyped(unstructuredPCS.Object, &pcs)
260261
if err != nil {
261262
return false, err
262263
}
@@ -956,7 +957,7 @@ func scalePodCliqueInPCS(tc TestContext, cliqueName string, replicas int32) erro
956957
}
957958

958959
var pcs grovev1alpha1.PodCliqueSet
959-
if err := convertUnstructuredToTyped(unstructuredPCS.Object, &pcs); err != nil {
960+
if err := utils.ConvertUnstructuredToTyped(unstructuredPCS.Object, &pcs); err != nil {
960961
return fmt.Errorf("failed to convert to PodCliqueSet: %w", err)
961962
}
962963

operator/e2e/tests/setup.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -640,15 +640,6 @@ func scalePCSGAcrossAllReplicas(tc TestContext, pcsName, pcsgName string, pcsRep
640640
return errCh
641641
}
642642

643-
// convertUnstructuredToTyped converts an unstructured map to a typed object
644-
func convertUnstructuredToTyped(u map[string]interface{}, typed interface{}) error {
645-
data, err := json.Marshal(u)
646-
if err != nil {
647-
return err
648-
}
649-
return json.Unmarshal(data, typed)
650-
}
651-
652643
// convertTypedToUnstructured converts a typed object to an unstructured object
653644
func convertTypedToUnstructured(typed interface{}) (*unstructured.Unstructured, error) {
654645
data, err := json.Marshal(typed)

0 commit comments

Comments
 (0)