Skip to content

Commit 79a0fa4

Browse files
committed
test(scheduler): keep portfolio solver tests local
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent eb8f991 commit 79a0fa4

2 files changed

Lines changed: 147 additions & 1 deletion

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright 2026 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package solvers
5+
6+
import (
7+
"strconv"
8+
"testing"
9+
10+
"github.com/stretchr/testify/require"
11+
"go.uber.org/mock/gomock"
12+
v1 "k8s.io/api/core/v1"
13+
"k8s.io/apimachinery/pkg/api/resource"
14+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
15+
"k8s.io/utils/ptr"
16+
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
18+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
19+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
20+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
21+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_affinity"
22+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
23+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
24+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/queue_info"
25+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/resource_info"
26+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
27+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/scheduler_util"
28+
)
29+
30+
func newGeneratorTestSession(t *testing.T, nodeGPUs map[string]int) *framework.Session {
31+
t.Helper()
32+
33+
defaultQueue := createQueue("default")
34+
defaultQueue.ParentQueue = ""
35+
36+
return &framework.Session{
37+
ClusterInfo: &api.ClusterInfo{
38+
PodGroupInfos: map[common_info.PodGroupID]*podgroup_info.PodGroupInfo{},
39+
Queues: map[common_info.QueueID]*queue_info.QueueInfo{
40+
defaultQueue.UID: defaultQueue,
41+
},
42+
Nodes: newGeneratorTestNodes(t, nodeGPUs),
43+
},
44+
}
45+
}
46+
47+
func newGeneratorTestNodes(t *testing.T, nodeGPUs map[string]int) map[string]*node_info.NodeInfo {
48+
t.Helper()
49+
50+
resourceLists := make([]v1.ResourceList, 0, len(nodeGPUs))
51+
for _, gpus := range nodeGPUs {
52+
resourceLists = append(resourceLists, generatorTestNodeResources(gpus))
53+
}
54+
vectorMap := resource_info.BuildResourceVectorMap(resourceLists)
55+
56+
nodes := map[string]*node_info.NodeInfo{}
57+
for name, gpus := range nodeGPUs {
58+
controller := gomock.NewController(t)
59+
nodePodAffinityInfo := pod_affinity.NewMockNodePodAffinityInfo(controller)
60+
nodePodAffinityInfo.EXPECT().AddPod(gomock.Any()).AnyTimes()
61+
nodePodAffinityInfo.EXPECT().RemovePod(gomock.Any()).AnyTimes()
62+
63+
node := &v1.Node{
64+
ObjectMeta: metav1.ObjectMeta{Name: name},
65+
Status: v1.NodeStatus{
66+
Allocatable: generatorTestNodeResources(gpus),
67+
Capacity: generatorTestNodeResources(gpus),
68+
},
69+
}
70+
nodes[name] = node_info.NewNodeInfo(node, nodePodAffinityInfo, vectorMap)
71+
}
72+
return nodes
73+
}
74+
75+
func generatorTestNodeResources(gpus int) v1.ResourceList {
76+
return v1.ResourceList{
77+
resource_info.GPUResourceName: resource.MustParse(strconv.Itoa(gpus)),
78+
v1.ResourcePods: resource.MustParse("100"),
79+
}
80+
}
81+
82+
func addGeneratorTestPendingJob(
83+
t *testing.T, ssn *framework.Session, tasksPerJob int, jobID int, queueName string,
84+
) *podgroup_info.PodGroupInfo {
85+
t.Helper()
86+
87+
job, _ := createJobWithTasks(tasksPerJob, jobID, queueName, v1.PodPending, []v1.ResourceRequirements{requireOneGPU()})
88+
addGeneratorTestQueue(ssn, queueName)
89+
ssn.ClusterInfo.PodGroupInfos[job.UID] = job
90+
return job
91+
}
92+
93+
func addGeneratorTestJob(
94+
t *testing.T, ssn *framework.Session, tasksPerJob int, jobID int, queueName string, nodeNames ...string,
95+
) (*podgroup_info.PodGroupInfo, []*pod_info.PodInfo) {
96+
t.Helper()
97+
98+
job, tasks := createJobWithTasks(tasksPerJob, jobID, queueName, v1.PodRunning, []v1.ResourceRequirements{requireOneGPU()})
99+
addGeneratorTestQueue(ssn, queueName)
100+
ssn.ClusterInfo.PodGroupInfos[job.UID] = job
101+
102+
for index, task := range tasks {
103+
nodeName := nodeNames[index%len(nodeNames)]
104+
task.NodeName = nodeName
105+
task.Pod.Spec.NodeName = nodeName
106+
require.NoError(t, ssn.ClusterInfo.Nodes[nodeName].AddTask(task))
107+
}
108+
return job, tasks
109+
}
110+
111+
func addGeneratorTestQueue(ssn *framework.Session, queueName string) {
112+
queue := createQueue(queueName)
113+
ssn.ClusterInfo.Queues[queue.UID] = queue
114+
}
115+
116+
func setGeneratorTestMinAvailable(job *podgroup_info.PodGroupInfo, minAvailable int) {
117+
for _, podSet := range job.GetAllPodSets() {
118+
podSet.SetMinAvailable(int32(minAvailable))
119+
}
120+
job.PodGroup.Spec.MinMember = ptr.To(int32(minAvailable))
121+
}
122+
123+
func generatorTestVictimsQueue(
124+
ssn *framework.Session, jobs ...*podgroup_info.PodGroupInfo,
125+
) *utils.JobsOrderByQueues {
126+
victimsQueue := utils.NewJobsOrderByQueues(ssn, utils.JobsOrderInitOptions{
127+
VictimQueue: true,
128+
MaxJobsQueueDepth: scheduler_util.QueueCapacityInfinite,
129+
})
130+
for _, job := range jobs {
131+
victimsQueue.PushJob(job)
132+
}
133+
return &victimsQueue
134+
}
135+
136+
func generatorTestVictimsQueueFactory(
137+
ssn *framework.Session, jobs ...*podgroup_info.PodGroupInfo,
138+
) GenerateVictimsQueue {
139+
return func() *utils.JobsOrderByQueues {
140+
return generatorTestVictimsQueue(ssn, jobs...)
141+
}
142+
}

pkg/scheduler/actions/common/solvers/job_solver_result_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,11 @@ func TestSolveWithResultReportsDeadlineWhenBudgetExhaustsDuringScenarioSearch(t
135135
nil, resource_info.NewResourceVectorMap(),
136136
)
137137
ssn.ClusterInfo.Nodes[node.Name] = node
138-
ssn.AddScenarioGenerator("deadline-test", NewMultiNodeGangGenerator, framework.Reclaim)
138+
ssn.AddScenarioGenerator("deadline-test", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
139+
solveCtx := ctx.(*SolveContext)
140+
solveCtx.GenerateVictimsQueue()
141+
return &portfolioTestGenerator{name: "deadline-test"}
142+
}, framework.Reclaim)
139143
solver := NewJobsSolver(
140144
[]*node_info.NodeInfo{node},
141145
nil,

0 commit comments

Comments
 (0)