Skip to content

Commit 024e6cb

Browse files
committed
feat(scheduler): drive solver from generator portfolio
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent 6e28f51 commit 024e6cb

4 files changed

Lines changed: 597 additions & 54 deletions

File tree

pkg/scheduler/actions/common/solvers/job_solver.go

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"strings"
99
"time"
1010

11-
solverscenario "github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1211
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1312
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
1413
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
@@ -95,54 +94,99 @@ func (s *JobSolver) Solve(
9594
// describing why the scenario search stopped.
9695
func (s *JobSolver) SolveWithResult(
9796
ssn *framework.Session, pendingJob *podgroup_info.PodGroupInfo,
98-
) (bool, *framework.Statement, []string, *SearchResult) {
99-
state := solvingState{}
97+
) (solved bool, statement *framework.Statement, victimTaskNames []string, searchResult *SearchResult) {
10098
originalNumActiveTasks := pendingJob.GetNumActiveUsedTasks()
10199

102100
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
103101
n := len(tasksToAllocate)
104102
if n == 0 {
105-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
106-
terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
103+
return false, nil, nil, terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
107104
}
108105

109106
actionBudget := s.ensureActionBudget()
110107
jobBudget := actionBudget.BeginJob()
111108
if actionBudget.Exhausted() {
112-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
113-
terminalSearchResult(SearchResultNotAttempted, false, false)
109+
return false, nil, nil, terminalSearchResult(SearchResultNotAttempted, false, false)
114110
}
115111

112+
if s.generateVictimsQueue == nil {
113+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
114+
}
115+
registrations := applicableScenarioGeneratorRegistrations(ssn, s.actionType)
116+
if len(registrations) == 0 {
117+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
118+
}
119+
120+
enteredSearch := false
121+
var lastVictimTasks []*pod_info.PodInfo
122+
var lastResult *SearchResult
123+
for _, registration := range registrations {
124+
state := solvingState{}
125+
generatorBudget := jobBudget.BeginGenerator(registration.Name)
126+
result := s.solvePendingJobWithGenerator(
127+
ssn, &state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
128+
)
129+
enteredSearch = enteredSearch || searchResultEntered(result) || resultSolved(result)
130+
lastVictimTasks = state.recordedVictimsTasks
131+
lastResult = result
132+
133+
if resultSolved(result) {
134+
solution := result.solution
135+
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
136+
jobSolved := pendingJob.IsGangSatisfied()
137+
if originalNumActiveTasks >= numActiveTasks {
138+
jobSolved = false
139+
}
140+
141+
log.InfraLogger.V(4).Infof(
142+
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
143+
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
144+
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
145+
}
146+
147+
if shouldStopSearch(result) {
148+
preserveEnteredSearch(result, enteredSearch)
149+
return false, nil, calcVictimNames(lastVictimTasks), result
150+
}
151+
}
152+
153+
if lastResult == nil {
154+
lastResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
155+
}
156+
preserveEnteredSearch(lastResult, enteredSearch)
157+
return false, nil, calcVictimNames(lastVictimTasks), lastResult
158+
}
159+
160+
func (s *JobSolver) solvePendingJobWithGenerator(
161+
ssn *framework.Session,
162+
state *solvingState,
163+
pendingJob *podgroup_info.PodGroupInfo,
164+
tasksToAllocate []*pod_info.PodInfo,
165+
jobBudget *jobSearchBudget,
166+
registration framework.ScenarioGeneratorRegistration,
167+
generatorBudget *generatorSearchBudget,
168+
) *SearchResult {
169+
n := len(tasksToAllocate)
116170
enteredSearch := false
117171
if n > 1 {
118-
maxSolvedK, searchResult := s.searchMaxSolvableK(ssn, &state, pendingJob, tasksToAllocate, jobBudget)
172+
maxSolvedK, searchResult := s.searchMaxSolvableK(
173+
ssn, state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
174+
)
119175
enteredSearch = searchResultEntered(searchResult) || maxSolvedK > 0
120176
if maxSolvedK == 0 {
121177
if searchResult == nil {
122-
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
178+
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), false)
123179
}
124180
preserveEnteredSearch(searchResult, enteredSearch)
125-
return false, nil, calcVictimNames(state.recordedVictimsTasks), searchResult
181+
return searchResult
126182
}
127183
}
128184

129-
result := s.probeAtK(ssn, &state, pendingJob, tasksToAllocate, n, jobBudget)
185+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, n, jobBudget, registration, generatorBudget)
130186
if !resultSolved(result) {
131187
preserveEnteredSearch(result, enteredSearch)
132-
return false, nil, calcVictimNames(state.recordedVictimsTasks), result
133188
}
134-
135-
solution := result.solution
136-
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
137-
jobSolved := pendingJob.IsGangSatisfied()
138-
if originalNumActiveTasks >= numActiveTasks {
139-
jobSolved = false
140-
}
141-
142-
log.InfraLogger.V(4).Infof(
143-
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
144-
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
145-
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
189+
return result
146190
}
147191

148192
// searchMaxSolvableK returns the largest k in [0, n) for which a probe at k succeeds.
@@ -156,14 +200,18 @@ func (s *JobSolver) searchMaxSolvableK(
156200
pendingJob *podgroup_info.PodGroupInfo,
157201
tasksToAllocate []*pod_info.PodInfo,
158202
jobBudget *jobSearchBudget,
203+
registration framework.ScenarioGeneratorRegistration,
204+
generatorBudget *generatorSearchBudget,
159205
) (int, *SearchResult) {
160206
n := len(tasksToAllocate)
161207
if n <= 1 {
162208
return 0, nil
163209
}
164210

165211
return searchMaxSolvableK(n, func(k int) *SearchResult {
166-
return s.tryProbeAndDiscard(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
212+
return s.tryProbeAndDiscard(
213+
ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget,
214+
)
167215
})
168216
}
169217

@@ -220,8 +268,10 @@ func (s *JobSolver) tryProbeAndDiscard(
220268
tasksToAllocate []*pod_info.PodInfo,
221269
k int,
222270
jobBudget *jobSearchBudget,
271+
registration framework.ScenarioGeneratorRegistration,
272+
generatorBudget *generatorSearchBudget,
223273
) *SearchResult {
224-
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
274+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget)
225275
if !resultSolved(result) {
226276
log.InfraLogger.V(5).Infof("No solution found for %d tasks out of %d tasks to allocate for %s",
227277
k, len(tasksToAllocate), pendingJob.Name)
@@ -246,15 +296,18 @@ func (s *JobSolver) probeAtK(
246296
tasksToAllocate []*pod_info.PodInfo,
247297
k int,
248298
jobBudget *jobSearchBudget,
299+
registration framework.ScenarioGeneratorRegistration,
300+
generatorBudget *generatorSearchBudget,
249301
) *SearchResult {
250302
pendingTasks := tasksToAllocate[:k]
251303
partialPendingJob := getPartialJobRepresentative(pendingJob, pendingTasks)
252-
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget)
304+
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget, registration, generatorBudget, k)
253305
}
254306

255307
func (s *JobSolver) solvePartialJob(
256308
ssn *framework.Session, state *solvingState, partialPendingJob *podgroup_info.PodGroupInfo,
257-
jobBudget *jobSearchBudget,
309+
jobBudget *jobSearchBudget, registration framework.ScenarioGeneratorRegistration,
310+
generatorBudget *generatorSearchBudget, probeK int,
258311
) *SearchResult {
259312
actionBudget := s.ensureActionBudget()
260313
if jobBudget == nil {
@@ -270,40 +323,33 @@ func (s *JobSolver) solvePartialJob(
270323
feasibleNodeMap[task.NodeName] = node
271324
}
272325

273-
if s.generateVictimsQueue == nil {
274-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
275-
}
276-
victimsQueue := s.generateVictimsQueue()
277-
if victimsQueue == nil {
278-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
326+
solveCtx := &SolveContext{
327+
Session: ssn,
328+
ActionType: s.actionType,
329+
PartialPendingJob: partialPendingJob,
330+
RecordedVictimsJobs: state.recordedVictimsJobs,
331+
RecordedVictimsTasks: state.recordedVictimsTasks,
332+
GenerateVictimsQueue: s.generateVictimsQueue,
333+
FeasibleNodes: feasibleNodeMap,
334+
ProbeK: probeK,
279335
}
336+
portfolio := newSingleGeneratorScenarioPortfolio(solveCtx, jobBudget, registration, generatorBudget)
280337

281-
scenarioBuilder := NewPodAccumulatedScenarioBuilder(
282-
ssn, partialPendingJob, state.recordedVictimsJobs, victimsQueue, feasibleNodeMap)
283-
284-
enteredSearch := false
285-
firstScenario := true
286338
for {
287339
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
288-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
289-
}
290-
var scenarioToSolve *solverscenario.ByNodeScenario
291-
if firstScenario {
292-
scenarioToSolve = scenarioBuilder.GetValidScenario()
293-
firstScenario = false
294-
} else {
295-
scenarioToSolve = scenarioBuilder.GetNextScenario()
340+
return terminalSearchResult(
341+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
342+
)
296343
}
344+
scenarioToSolve := portfolio.Next()
297345
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
298-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
346+
return terminalSearchResult(
347+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
348+
)
299349
}
300350
if scenarioToSolve == nil {
301-
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
302-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
303-
}
304351
break
305352
}
306-
enteredSearch = true
307353
scenarioSolver := newByPodSolver(feasibleNodeMap, s.solutionValidator, ssn.AllowConsolidatingReclaim(),
308354
s.actionType)
309355

@@ -316,7 +362,7 @@ func (s *JobSolver) solvePartialJob(
316362
}
317363
}
318364

319-
return terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
365+
return terminalSearchResult(portfolio.StopReason(), jobBudget.ReducedBudget(), portfolio.enteredSearch)
320366
}
321367

322368
func searchResultEntered(result *SearchResult) bool {

pkg/scheduler/actions/common/solvers/job_solver_result_test.go

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package solvers
55

66
import (
7+
"fmt"
78
"testing"
89
"time"
910

@@ -13,10 +14,12 @@ import (
1314

1415
kaiv1 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/kai/v1"
1516
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1618
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1719
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
1820
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
1921
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
22+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
2023
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
2124
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/queue_info"
2225
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
@@ -104,12 +107,63 @@ func TestSearchMaxSolvableKSkipsSingleTaskFullProbe(t *testing.T) {
104107
)
105108
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
106109

107-
maxSolvedK, result := solver.searchMaxSolvableK(ssn, &solvingState{}, pendingJob, tasksToAllocate, jobBudget)
110+
maxSolvedK, result := solver.searchMaxSolvableK(
111+
ssn, &solvingState{}, pendingJob, tasksToAllocate, jobBudget,
112+
framework.ScenarioGeneratorRegistration{}, nil,
113+
)
108114

109115
require.Equal(t, 0, maxSolvedK)
110116
require.Nil(t, result)
111117
}
112118

119+
func TestSolveWithResultRunsCompletePartialSearchForOneGeneratorBeforeNext(t *testing.T) {
120+
ssn := newGeneratorTestSession(t, map[string]int{
121+
"node-1": 1,
122+
"node-2": 1,
123+
"node-3": 1,
124+
})
125+
require.NoError(t, ssn.InitNodeScoringPool())
126+
pendingJob := addGeneratorTestPendingJob(t, ssn, 3, 10, "team-pending")
127+
setGeneratorTestMinAvailable(pendingJob, 3)
128+
victimJob, victimTasks := addGeneratorTestJob(t, ssn, 3, 20, "team-victim", "node-1", "node-2", "node-3")
129+
factoryCalls := []string{}
130+
131+
ssn.AddScenarioGenerator("first", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
132+
solveCtx := ctx.(*SolveContext)
133+
factoryCalls = append(factoryCalls, fmt.Sprintf("first:%d", solveCtx.ProbeK))
134+
return &portfolioTestGenerator{name: "first"}
135+
}, framework.Reclaim)
136+
ssn.AddScenarioGenerator("second", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
137+
solveCtx := ctx.(*SolveContext)
138+
factoryCalls = append(factoryCalls, fmt.Sprintf("second:%d", solveCtx.ProbeK))
139+
pendingTasks := podgroup_info.GetTasksToAllocate(
140+
solveCtx.PartialPendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false,
141+
)
142+
sn := scenario.NewByNodeScenario(
143+
ssn, solveCtx.PartialPendingJob, pendingTasks,
144+
unrecordedVictimsForProbe(victimTasks, solveCtx.RecordedVictimsTasks, solveCtx.ProbeK),
145+
solveCtx.RecordedVictimsJobs,
146+
)
147+
return &portfolioTestGenerator{name: "second", scenarios: []api.ScenarioInfo{sn}}
148+
}, framework.Reclaim)
149+
solver := NewJobsSolver(
150+
jobSolverResultTestFeasibleNodes(ssn),
151+
nil,
152+
generatorTestVictimsQueueFactory(ssn, victimJob),
153+
framework.Reclaim,
154+
nil,
155+
)
156+
157+
solved, statement, _, result := solver.SolveWithResult(ssn, pendingJob)
158+
if statement != nil {
159+
defer statement.Discard()
160+
}
161+
162+
require.True(t, solved)
163+
require.Equal(t, SearchResultSolved, result.Reason())
164+
require.Equal(t, []string{"first:1", "second:1", "second:2", "second:3"}, factoryCalls)
165+
}
166+
113167
func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
114168
clock := &fakeClock{now: time.Unix(0, 0)}
115169
actionBudget, err := newActionSearchBudgetWithClock(
@@ -124,6 +178,7 @@ func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
124178
)
125179
require.NoError(t, err)
126180
ssn, pendingJob := newJobSolverResultTestSession(t, 1)
181+
ssn.AddScenarioGenerator("deadline-test", NewMultiNodeGangGenerator, framework.Reclaim)
127182
solver := NewJobsSolver(
128183
nil,
129184
nil,
@@ -159,6 +214,40 @@ func TestSearchMaxSolvableKPreservesEnteredSearchAfterTerminalPartialProbe(t *te
159214
require.True(t, result.EnteredSearch())
160215
}
161216

217+
func jobSolverResultTestFeasibleNodes(ssn *framework.Session) []*node_info.NodeInfo {
218+
nodes := make([]*node_info.NodeInfo, 0, len(ssn.ClusterInfo.Nodes))
219+
for _, node := range ssn.ClusterInfo.Nodes {
220+
nodes = append(nodes, node)
221+
}
222+
return nodes
223+
}
224+
225+
func unrecordedVictimsForProbe(
226+
victimTasks []*pod_info.PodInfo, recordedVictims []*pod_info.PodInfo, probeK int,
227+
) []*pod_info.PodInfo {
228+
recordedByUID := map[common_info.PodID]struct{}{}
229+
for _, task := range recordedVictims {
230+
recordedByUID[task.UID] = struct{}{}
231+
}
232+
233+
neededVictims := probeK - len(recordedVictims)
234+
if neededVictims <= 0 {
235+
return nil
236+
}
237+
238+
selectedVictims := make([]*pod_info.PodInfo, 0, neededVictims)
239+
for _, task := range victimTasks {
240+
if _, alreadyRecorded := recordedByUID[task.UID]; alreadyRecorded {
241+
continue
242+
}
243+
selectedVictims = append(selectedVictims, task)
244+
if len(selectedVictims) == neededVictims {
245+
return selectedVictims
246+
}
247+
}
248+
return selectedVictims
249+
}
250+
162251
func TestPreserveEnteredSearchMarksTerminalResult(t *testing.T) {
163252
result := terminalSearchResult(SearchResultDeadlineExhausted, false, false)
164253

0 commit comments

Comments
 (0)