Skip to content

Commit 875ce7b

Browse files
authored
feat(scheduler): add built-in scenario generators (#1744)
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent 65870a5 commit 875ce7b

10 files changed

Lines changed: 930 additions & 9 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
77
## [Unreleased]
88

99
### Added
10+
- Added built-in `NodeLocalGreedy` and `MultiNodeGang` scenario generator implementations for bounded reclaim, preempt, and consolidation search.
1011
- Added an opt-in `deviceaccess` admission plugin (`--block-nvidia-visible-devices`, config field `admission.blockNvidiaVisibleDevices`, default disabled) that (1) rejects pods overriding the `NVIDIA_VISIBLE_DEVICES` environment variable with values other than `void`/`none` (or via a `valueFrom` reference), and (2) injects `NVIDIA_VISIBLE_DEVICES=void` into containers that do not request a GPU, blocking their access to GPUs on the node.
1112
- Added support for configuring admission Pod Disruption Budget via Helm values (`admission.podDisruptionBudget`) [#1490](https://github.com/kai-scheduler/KAI-Scheduler/pull/1490) [dttung2905](https://github.com/dttung2905)
1213
- Added an opt-in `hamicore` binder plugin (depends on `gpusharing`) to write the HAMI-core GPU memory limit (`CUDA_DEVICE_MEMORY_LIMIT`) for fractional GPU pods.

pkg/scheduler/actions/common/solvers/pod_scenario_builder.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ func (asb *PodAccumulatedScenarioBuilder) GetNextScenario() *solverscenario.ByNo
102102
return asb.iterate(true)
103103
}
104104

105+
// GetValidAccumulatedScenario returns the current filter-valid accumulated outer
106+
// scenario without expanding it into sub-scenarios.
107+
func (asb *PodAccumulatedScenarioBuilder) GetValidAccumulatedScenario() *solverscenario.ByNodeScenario {
108+
return asb.iterateAccumulated(false)
109+
}
110+
111+
// GetNextAccumulatedScenario advances the victim queue and returns the next
112+
// filter-valid accumulated outer scenario without expanding it into sub-scenarios.
113+
func (asb *PodAccumulatedScenarioBuilder) GetNextAccumulatedScenario() *solverscenario.ByNodeScenario {
114+
return asb.iterateAccumulated(true)
115+
}
116+
105117
// iterate is the unified driver behind GetValidScenario / GetNextScenario.
106118
//
107119
// The pipeline runs as a single loop with three exit points:
@@ -143,6 +155,29 @@ func (asb *PodAccumulatedScenarioBuilder) iterate(advanceFirst bool) *solverscen
143155
}
144156
}
145157

158+
func (asb *PodAccumulatedScenarioBuilder) iterateAccumulated(advanceFirst bool) *solverscenario.ByNodeScenario {
159+
if asb.lastScenario == nil {
160+
return nil
161+
}
162+
needAdvance := advanceFirst
163+
for {
164+
if needAdvance {
165+
if asb.victimsJobsQueue.IsEmpty() {
166+
return nil
167+
}
168+
if !asb.addNextPotentialVictims() {
169+
continue
170+
}
171+
}
172+
needAdvance = true
173+
174+
if !asb.outerScenarioValid() {
175+
continue
176+
}
177+
return asb.lastScenario
178+
}
179+
}
180+
146181
// nextFromSubEmitter drains the active sub-scenario emitter (if any) by one. Returns
147182
// nil and clears the emitter when it is exhausted, so callers fall through to outer
148183
// accumulation.

pkg/scheduler/actions/common/solvers/scenario_generator.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type SolveContext struct {
1717
PartialPendingJob *podgroup_info.PodGroupInfo
1818
RecordedVictimsJobs []*podgroup_info.PodGroupInfo
1919
RecordedVictimsTasks []*pod_info.PodInfo
20+
GenerateVictimsQueue GenerateVictimsQueue
2021
VictimsQueue *utils.JobsOrderByQueues
2122
FeasibleNodes map[string]*node_info.NodeInfo
2223
ProbeK int
@@ -26,10 +27,14 @@ func (ctx *SolveContext) Action() framework.ActionType {
2627
return ctx.ActionType
2728
}
2829

29-
func NewNodeLocalGreedyGenerator(_ framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
30-
return nil
31-
}
30+
// ValidateScenarioGeneratorContext extracts the solver context required by scenario generator plugins.
31+
func ValidateScenarioGeneratorContext(ctx framework.ScenarioGeneratorContext) (*SolveContext, GenerateVictimsQueue, bool) {
32+
solveCtx, ok := ctx.(*SolveContext)
33+
if !ok || solveCtx == nil || solveCtx.Session == nil || solveCtx.Session.ClusterInfo == nil ||
34+
solveCtx.Session.ClusterInfo.Nodes == nil || solveCtx.Session.ClusterInfo.PodGroupInfos == nil ||
35+
solveCtx.PartialPendingJob == nil || solveCtx.FeasibleNodes == nil || solveCtx.GenerateVictimsQueue == nil {
36+
return nil, nil, false
37+
}
3238

33-
func NewMultiNodeGangGenerator(_ framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
34-
return nil
39+
return solveCtx, solveCtx.GenerateVictimsQueue, true
3540
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright 2025 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package multinodegang
5+
6+
import (
7+
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
8+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers"
9+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
10+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
11+
)
12+
13+
type multiNodeGangGenerator struct {
14+
builder *solvers.PodAccumulatedScenarioBuilder
15+
first bool
16+
}
17+
18+
func NewMultiNodeGangGenerator(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
19+
solveCtx, generateVictimsQueue, ok := solvers.ValidateScenarioGeneratorContext(ctx)
20+
if !ok {
21+
return nil
22+
}
23+
victimsQueue := generateVictimsQueue()
24+
if victimsQueue == nil {
25+
return nil
26+
}
27+
28+
return &multiNodeGangGenerator{
29+
builder: solvers.NewPodAccumulatedScenarioBuilder(
30+
solveCtx.Session,
31+
solveCtx.PartialPendingJob,
32+
solveCtx.RecordedVictimsJobs,
33+
victimsQueue,
34+
solveCtx.FeasibleNodes,
35+
),
36+
first: true,
37+
}
38+
}
39+
40+
func (g *multiNodeGangGenerator) Name() string {
41+
return constants.GeneratorMultiNodeGang
42+
}
43+
44+
func (g *multiNodeGangGenerator) Next() api.ScenarioInfo {
45+
if g.first {
46+
g.first = false
47+
return g.builder.GetValidScenario()
48+
}
49+
return g.builder.GetNextScenario()
50+
}

pkg/scheduler/plugins/multinodegang/multinodegang.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ package multinodegang
55

66
import (
77
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
8-
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers"
98
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
109
)
1110

@@ -22,7 +21,7 @@ func (p *multiNodeGangPlugin) Name() string {
2221
}
2322

2423
func (p *multiNodeGangPlugin) OnSessionOpen(ssn *framework.Session) {
25-
addScenarioGenerator(ssn, constants.GeneratorMultiNodeGang, solvers.NewMultiNodeGangGenerator)
24+
addScenarioGenerator(ssn, constants.GeneratorMultiNodeGang, NewMultiNodeGangGenerator)
2625
}
2726

2827
func (p *multiNodeGangPlugin) OnSessionClose(_ *framework.Session) {}

pkg/scheduler/plugins/multinodegang/multinodegang_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ func TestMultiNodeGangPluginRegistersMultiNodeGangGenerator(t *testing.T) {
2525
require.Contains(t, ssn.ScenarioGeneratorRegistrations[0].Actions, framework.Preempt)
2626
require.Contains(t, ssn.ScenarioGeneratorRegistrations[0].Actions, framework.Consolidation)
2727
}
28+
29+
func TestMultiNodeGangGeneratorConstructorLivesInPluginPackage(t *testing.T) {
30+
require.Nil(t, NewMultiNodeGangGenerator(nil))
31+
}
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
// Copyright 2025 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package nodelocalgreedy
5+
6+
import (
7+
"sort"
8+
"strings"
9+
10+
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
11+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers"
12+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
13+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
14+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
15+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
16+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
18+
)
19+
20+
type nodeLocalGreedyGenerator struct {
21+
solveCtx *solvers.SolveContext
22+
generateVictimsQueue solvers.GenerateVictimsQueue
23+
builder *solvers.PodAccumulatedScenarioBuilder
24+
scenarios []*scenario.ByNodeScenario
25+
shouldAdvanceAccumulatedScenario bool
26+
}
27+
28+
func NewNodeLocalGreedyGenerator(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
29+
solveCtx, generateVictimsQueue, ok := solvers.ValidateScenarioGeneratorContext(ctx)
30+
if !ok {
31+
return nil
32+
}
33+
return &nodeLocalGreedyGenerator{
34+
solveCtx: solveCtx,
35+
generateVictimsQueue: generateVictimsQueue,
36+
}
37+
}
38+
39+
func (g *nodeLocalGreedyGenerator) Name() string {
40+
return constants.GeneratorNodeLocalGreedy
41+
}
42+
43+
func (g *nodeLocalGreedyGenerator) Next() api.ScenarioInfo {
44+
if !g.ensureBuilder() {
45+
return nil
46+
}
47+
for {
48+
if sn := g.popScenario(); sn != nil {
49+
return sn
50+
}
51+
accumulated := g.nextValidAccumulatedScenario()
52+
if accumulated == nil {
53+
return nil
54+
}
55+
g.scenarios = nodeLocalScenarios(g.solveCtx.Session, accumulated)
56+
}
57+
}
58+
59+
func (g *nodeLocalGreedyGenerator) ensureBuilder() bool {
60+
if g.builder != nil {
61+
return true
62+
}
63+
victimsQueue := g.generateVictimsQueue()
64+
if victimsQueue == nil {
65+
return false
66+
}
67+
g.builder = solvers.NewPodAccumulatedScenarioBuilder(
68+
g.solveCtx.Session,
69+
g.solveCtx.PartialPendingJob,
70+
g.solveCtx.RecordedVictimsJobs,
71+
victimsQueue,
72+
g.solveCtx.FeasibleNodes,
73+
)
74+
return true
75+
}
76+
77+
func (g *nodeLocalGreedyGenerator) popScenario() *scenario.ByNodeScenario {
78+
if len(g.scenarios) == 0 {
79+
return nil
80+
}
81+
sn := g.scenarios[0]
82+
g.scenarios = g.scenarios[1:]
83+
return sn
84+
}
85+
86+
func (g *nodeLocalGreedyGenerator) nextValidAccumulatedScenario() *scenario.ByNodeScenario {
87+
if g.shouldAdvanceAccumulatedScenario {
88+
return g.builder.GetNextAccumulatedScenario()
89+
}
90+
g.shouldAdvanceAccumulatedScenario = true
91+
92+
return g.builder.GetValidAccumulatedScenario()
93+
}
94+
95+
func nodeLocalScenarios(session *framework.Session, base *scenario.ByNodeScenario) []*scenario.ByNodeScenario {
96+
if base == nil {
97+
return nil
98+
}
99+
if len(base.PotentialVictimsTasks()) == 0 {
100+
if len(base.RecordedVictimsTasks()) == 0 {
101+
return nil
102+
}
103+
return []*scenario.ByNodeScenario{base}
104+
}
105+
106+
var scenarios []*scenario.ByNodeScenario
107+
seen := map[string]struct{}{}
108+
for _, nodeName := range nodeNamesOfJob(base.LatestPotentialVictim()) {
109+
victimTasks := base.VictimsTasksFromNodes([]string{nodeName})
110+
if len(victimTasks) == 0 {
111+
continue
112+
}
113+
key := victimUIDSetKey(victimTasks)
114+
if _, found := seen[key]; found {
115+
continue
116+
}
117+
seen[key] = struct{}{}
118+
sn := scenario.NewByNodeScenario(
119+
session,
120+
base.GetPreemptor(),
121+
base.PendingTasks(),
122+
nil,
123+
base.RecordedVictimsJobs(),
124+
)
125+
addPotentialVictimsGroupedByJob(sn, victimTasks)
126+
scenarios = append(scenarios, sn)
127+
}
128+
return scenarios
129+
}
130+
131+
func nodeNamesOfJob(job *podgroup_info.PodGroupInfo) []string {
132+
if job == nil {
133+
return nil
134+
}
135+
seen := map[string]struct{}{}
136+
for _, task := range job.GetAllPodsMap() {
137+
if task.NodeName == "" {
138+
continue
139+
}
140+
seen[task.NodeName] = struct{}{}
141+
}
142+
nodeNames := make([]string, 0, len(seen))
143+
for nodeName := range seen {
144+
nodeNames = append(nodeNames, nodeName)
145+
}
146+
sort.Strings(nodeNames)
147+
return nodeNames
148+
}
149+
150+
func victimUIDSetKey(tasks []*pod_info.PodInfo) string {
151+
uids := make([]string, 0, len(tasks))
152+
for _, task := range tasks {
153+
uids = append(uids, string(task.UID))
154+
}
155+
sort.Strings(uids)
156+
return strings.Join(uids, "\x00")
157+
}
158+
159+
func addPotentialVictimsGroupedByJob(sn *scenario.ByNodeScenario, tasks []*pod_info.PodInfo) {
160+
groupedTasks := map[common_info.PodGroupID][]*pod_info.PodInfo{}
161+
var jobOrder []common_info.PodGroupID
162+
for _, task := range tasks {
163+
if _, found := groupedTasks[task.Job]; !found {
164+
jobOrder = append(jobOrder, task.Job)
165+
}
166+
groupedTasks[task.Job] = append(groupedTasks[task.Job], task)
167+
}
168+
for _, jobID := range jobOrder {
169+
sn.AddPotentialVictimsTasks(groupedTasks[jobID])
170+
}
171+
}

0 commit comments

Comments
 (0)