Skip to content

Commit 696eb4a

Browse files
committed
feat(scheduler): drive solver from generator portfolio
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent 924e997 commit 696eb4a

4 files changed

Lines changed: 592 additions & 54 deletions

File tree

pkg/scheduler/actions/common/solvers/job_solver.go

Lines changed: 99 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@ import (
88
"strings"
99
"time"
1010

11-
solverscenario "github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1211
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1312
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
1413
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
@@ -95,54 +94,99 @@ func (s *JobSolver) Solve(
9594
// describing why the scenario search stopped.
9695
func (s *JobSolver) SolveWithResult(
9796
ssn *framework.Session, pendingJob *podgroup_info.PodGroupInfo,
98-
) (bool, *framework.Statement, []string, *SearchResult) {
99-
state := solvingState{}
97+
) (solved bool, statement *framework.Statement, victimTaskNames []string, searchResult *SearchResult) {
10098
originalNumActiveTasks := pendingJob.GetNumActiveUsedTasks()
10199

102100
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
103101
n := len(tasksToAllocate)
104102
if n == 0 {
105-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
106-
terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
103+
return false, nil, nil, terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
107104
}
108105

109106
actionBudget := s.ensureActionBudget()
110107
jobBudget := actionBudget.BeginJob()
111108
if actionBudget.Exhausted() {
112-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
113-
terminalSearchResult(SearchResultNotAttempted, false, false)
109+
return false, nil, nil, terminalSearchResult(SearchResultNotAttempted, false, false)
114110
}
115111

112+
if s.generateVictimsQueue == nil {
113+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
114+
}
115+
registrations := applicableScenarioGeneratorRegistrations(ssn, s.actionType)
116+
if len(registrations) == 0 {
117+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
118+
}
119+
120+
enteredSearch := false
121+
var lastVictimTasks []*pod_info.PodInfo
122+
var lastResult *SearchResult
123+
for _, registration := range registrations {
124+
state := solvingState{}
125+
generatorBudget := jobBudget.BeginGenerator(registration.Name)
126+
result := s.solvePendingJobWithGenerator(
127+
ssn, &state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
128+
)
129+
enteredSearch = enteredSearch || searchResultEntered(result) || resultSolved(result)
130+
lastVictimTasks = state.recordedVictimsTasks
131+
lastResult = result
132+
133+
if resultSolved(result) {
134+
solution := result.solution
135+
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
136+
jobSolved := pendingJob.IsGangSatisfied()
137+
if originalNumActiveTasks >= numActiveTasks {
138+
jobSolved = false
139+
}
140+
141+
log.InfraLogger.V(4).Infof(
142+
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
143+
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
144+
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
145+
}
146+
147+
if shouldStopSearch(result) {
148+
preserveEnteredSearch(result, enteredSearch)
149+
return false, nil, calcVictimNames(lastVictimTasks), result
150+
}
151+
}
152+
153+
if lastResult == nil {
154+
lastResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
155+
}
156+
preserveEnteredSearch(lastResult, enteredSearch)
157+
return false, nil, calcVictimNames(lastVictimTasks), lastResult
158+
}
159+
160+
func (s *JobSolver) solvePendingJobWithGenerator(
161+
ssn *framework.Session,
162+
state *solvingState,
163+
pendingJob *podgroup_info.PodGroupInfo,
164+
tasksToAllocate []*pod_info.PodInfo,
165+
jobBudget *jobSearchBudget,
166+
registration framework.ScenarioGeneratorRegistration,
167+
generatorBudget *generatorSearchBudget,
168+
) *SearchResult {
169+
n := len(tasksToAllocate)
116170
enteredSearch := false
117171
if n > 1 {
118-
maxSolvedK, searchResult := s.searchMaxSolvableK(ssn, &state, pendingJob, tasksToAllocate, jobBudget)
172+
maxSolvedK, searchResult := s.searchMaxSolvableK(
173+
ssn, state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
174+
)
119175
enteredSearch = searchResultEntered(searchResult) || maxSolvedK > 0
120176
if maxSolvedK == 0 {
121177
if searchResult == nil {
122-
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
178+
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), false)
123179
}
124180
preserveEnteredSearch(searchResult, enteredSearch)
125-
return false, nil, calcVictimNames(state.recordedVictimsTasks), searchResult
181+
return searchResult
126182
}
127183
}
128184

129-
result := s.probeAtK(ssn, &state, pendingJob, tasksToAllocate, n, jobBudget)
185+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, n, jobBudget, registration, generatorBudget)
130186
if !resultSolved(result) {
131187
preserveEnteredSearch(result, enteredSearch)
132-
return false, nil, calcVictimNames(state.recordedVictimsTasks), result
133188
}
134-
135-
solution := result.solution
136-
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
137-
jobSolved := pendingJob.IsGangSatisfied()
138-
if originalNumActiveTasks >= numActiveTasks {
139-
jobSolved = false
140-
}
141-
142-
log.InfraLogger.V(4).Infof(
143-
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
144-
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
145-
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
189+
return result
146190
}
147191

148192
// searchMaxSolvableK returns the largest k in [0, n) for which a probe at k succeeds.
@@ -156,14 +200,18 @@ func (s *JobSolver) searchMaxSolvableK(
156200
pendingJob *podgroup_info.PodGroupInfo,
157201
tasksToAllocate []*pod_info.PodInfo,
158202
jobBudget *jobSearchBudget,
203+
registration framework.ScenarioGeneratorRegistration,
204+
generatorBudget *generatorSearchBudget,
159205
) (int, *SearchResult) {
160206
n := len(tasksToAllocate)
161207
if n <= 1 {
162208
return 0, nil
163209
}
164210

165211
return searchMaxSolvableK(n, func(k int) *SearchResult {
166-
return s.tryProbeAndDiscard(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
212+
return s.tryProbeAndDiscard(
213+
ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget,
214+
)
167215
})
168216
}
169217

@@ -220,8 +268,10 @@ func (s *JobSolver) tryProbeAndDiscard(
220268
tasksToAllocate []*pod_info.PodInfo,
221269
k int,
222270
jobBudget *jobSearchBudget,
271+
registration framework.ScenarioGeneratorRegistration,
272+
generatorBudget *generatorSearchBudget,
223273
) *SearchResult {
224-
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
274+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget)
225275
if !resultSolved(result) {
226276
log.InfraLogger.V(5).Infof("No solution found for %d tasks out of %d tasks to allocate for %s",
227277
k, len(tasksToAllocate), pendingJob.Name)
@@ -246,15 +296,18 @@ func (s *JobSolver) probeAtK(
246296
tasksToAllocate []*pod_info.PodInfo,
247297
k int,
248298
jobBudget *jobSearchBudget,
299+
registration framework.ScenarioGeneratorRegistration,
300+
generatorBudget *generatorSearchBudget,
249301
) *SearchResult {
250302
pendingTasks := tasksToAllocate[:k]
251303
partialPendingJob := getPartialJobRepresentative(pendingJob, pendingTasks)
252-
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget)
304+
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget, registration, generatorBudget, k)
253305
}
254306

255307
func (s *JobSolver) solvePartialJob(
256308
ssn *framework.Session, state *solvingState, partialPendingJob *podgroup_info.PodGroupInfo,
257-
jobBudget *jobSearchBudget,
309+
jobBudget *jobSearchBudget, registration framework.ScenarioGeneratorRegistration,
310+
generatorBudget *generatorSearchBudget, probeK int,
258311
) *SearchResult {
259312
actionBudget := s.ensureActionBudget()
260313
if jobBudget == nil {
@@ -270,40 +323,33 @@ func (s *JobSolver) solvePartialJob(
270323
feasibleNodeMap[task.NodeName] = node
271324
}
272325

273-
if s.generateVictimsQueue == nil {
274-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
275-
}
276-
victimsQueue := s.generateVictimsQueue()
277-
if victimsQueue == nil {
278-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
326+
solveCtx := &SolveContext{
327+
Session: ssn,
328+
ActionType: s.actionType,
329+
PartialPendingJob: partialPendingJob,
330+
RecordedVictimsJobs: state.recordedVictimsJobs,
331+
RecordedVictimsTasks: state.recordedVictimsTasks,
332+
GenerateVictimsQueue: s.generateVictimsQueue,
333+
FeasibleNodes: feasibleNodeMap,
334+
ProbeK: probeK,
279335
}
336+
portfolio := newSingleGeneratorScenarioPortfolio(solveCtx, jobBudget, registration, generatorBudget)
280337

281-
scenarioBuilder := NewPodAccumulatedScenarioBuilder(
282-
ssn, partialPendingJob, state.recordedVictimsJobs, victimsQueue, feasibleNodeMap)
283-
284-
enteredSearch := false
285-
firstScenario := true
286338
for {
287339
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
288-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
289-
}
290-
var scenarioToSolve *solverscenario.ByNodeScenario
291-
if firstScenario {
292-
scenarioToSolve = scenarioBuilder.GetValidScenario()
293-
firstScenario = false
294-
} else {
295-
scenarioToSolve = scenarioBuilder.GetNextScenario()
340+
return terminalSearchResult(
341+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
342+
)
296343
}
344+
scenarioToSolve := portfolio.Next()
297345
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
298-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
346+
return terminalSearchResult(
347+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
348+
)
299349
}
300350
if scenarioToSolve == nil {
301-
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
302-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
303-
}
304351
break
305352
}
306-
enteredSearch = true
307353
scenarioSolver := newByPodSolver(feasibleNodeMap, s.solutionValidator, ssn.AllowConsolidatingReclaim(),
308354
s.actionType)
309355

@@ -316,7 +362,7 @@ func (s *JobSolver) solvePartialJob(
316362
}
317363
}
318364

319-
return terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
365+
return terminalSearchResult(portfolio.StopReason(), jobBudget.ReducedBudget(), portfolio.enteredSearch)
320366
}
321367

322368
func searchResultEntered(result *SearchResult) bool {

pkg/scheduler/actions/common/solvers/job_solver_result_test.go

Lines changed: 90 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,17 +4,20 @@
44
package solvers
55

66
import (
7+
"fmt"
78
"testing"
89
"time"
910

1011
"github.com/stretchr/testify/require"
1112
v1 "k8s.io/api/core/v1"
1213

1314
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
15+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1416
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1517
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
1618
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
1719
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
20+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
1821
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
1922
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/queue_info"
2023
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/conf"
@@ -103,12 +106,63 @@ func TestSearchMaxSolvableKSkipsSingleTaskFullProbe(t *testing.T) {
103106
)
104107
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
105108

106-
maxSolvedK, result := solver.searchMaxSolvableK(ssn, &solvingState{}, pendingJob, tasksToAllocate, jobBudget)
109+
maxSolvedK, result := solver.searchMaxSolvableK(
110+
ssn, &solvingState{}, pendingJob, tasksToAllocate, jobBudget,
111+
framework.ScenarioGeneratorRegistration{}, nil,
112+
)
107113

108114
require.Equal(t, 0, maxSolvedK)
109115
require.Nil(t, result)
110116
}
111117

118+
func TestSolveWithResultRunsCompletePartialSearchForOneGeneratorBeforeNext(t *testing.T) {
119+
ssn := newGeneratorTestSession(t, map[string]int{
120+
"node-1": 1,
121+
"node-2": 1,
122+
"node-3": 1,
123+
})
124+
require.NoError(t, ssn.InitNodeScoringPool())
125+
pendingJob := addGeneratorTestPendingJob(t, ssn, 3, 10, "team-pending")
126+
setGeneratorTestMinAvailable(pendingJob, 3)
127+
victimJob, victimTasks := addGeneratorTestJob(t, ssn, 3, 20, "team-victim", "node-1", "node-2", "node-3")
128+
factoryCalls := []string{}
129+
130+
ssn.AddScenarioGenerator("first", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
131+
solveCtx := ctx.(*SolveContext)
132+
factoryCalls = append(factoryCalls, fmt.Sprintf("first:%d", solveCtx.ProbeK))
133+
return &portfolioTestGenerator{name: "first"}
134+
}, framework.Reclaim)
135+
ssn.AddScenarioGenerator("second", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
136+
solveCtx := ctx.(*SolveContext)
137+
factoryCalls = append(factoryCalls, fmt.Sprintf("second:%d", solveCtx.ProbeK))
138+
pendingTasks := podgroup_info.GetTasksToAllocate(
139+
solveCtx.PartialPendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false,
140+
)
141+
sn := scenario.NewByNodeScenario(
142+
ssn, solveCtx.PartialPendingJob, pendingTasks,
143+
unrecordedVictimsForProbe(victimTasks, solveCtx.RecordedVictimsTasks, solveCtx.ProbeK),
144+
solveCtx.RecordedVictimsJobs,
145+
)
146+
return &portfolioTestGenerator{name: "second", scenarios: []api.ScenarioInfo{sn}}
147+
}, framework.Reclaim)
148+
solver := NewJobsSolver(
149+
jobSolverResultTestFeasibleNodes(ssn),
150+
nil,
151+
generatorTestVictimsQueueFactory(ssn, victimJob),
152+
framework.Reclaim,
153+
nil,
154+
)
155+
156+
solved, statement, _, result := solver.SolveWithResult(ssn, pendingJob)
157+
if statement != nil {
158+
defer statement.Discard()
159+
}
160+
161+
require.True(t, solved)
162+
require.Equal(t, SearchResultSolved, result.Reason())
163+
require.Equal(t, []string{"first:1", "second:1", "second:2", "second:3"}, factoryCalls)
164+
}
165+
112166
func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
113167
clock := &fakeClock{now: time.Unix(0, 0)}
114168
actionBudget, err := newActionSearchBudgetWithClock(
@@ -123,6 +177,7 @@ func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
123177
)
124178
require.NoError(t, err)
125179
ssn, pendingJob := newJobSolverResultTestSession(t, 1)
180+
ssn.AddScenarioGenerator("deadline-test", NewMultiNodeGangGenerator, framework.Reclaim)
126181
solver := NewJobsSolver(
127182
nil,
128183
nil,
@@ -158,6 +213,40 @@ func TestSearchMaxSolvableKPreservesEnteredSearchAfterTerminalPartialProbe(t *te
158213
require.True(t, result.EnteredSearch())
159214
}
160215

216+
func jobSolverResultTestFeasibleNodes(ssn *framework.Session) []*node_info.NodeInfo {
217+
nodes := make([]*node_info.NodeInfo, 0, len(ssn.ClusterInfo.Nodes))
218+
for _, node := range ssn.ClusterInfo.Nodes {
219+
nodes = append(nodes, node)
220+
}
221+
return nodes
222+
}
223+
224+
func unrecordedVictimsForProbe(
225+
victimTasks []*pod_info.PodInfo, recordedVictims []*pod_info.PodInfo, probeK int,
226+
) []*pod_info.PodInfo {
227+
recordedByUID := map[common_info.PodID]struct{}{}
228+
for _, task := range recordedVictims {
229+
recordedByUID[task.UID] = struct{}{}
230+
}
231+
232+
neededVictims := probeK - len(recordedVictims)
233+
if neededVictims <= 0 {
234+
return nil
235+
}
236+
237+
selectedVictims := make([]*pod_info.PodInfo, 0, neededVictims)
238+
for _, task := range victimTasks {
239+
if _, alreadyRecorded := recordedByUID[task.UID]; alreadyRecorded {
240+
continue
241+
}
242+
selectedVictims = append(selectedVictims, task)
243+
if len(selectedVictims) == neededVictims {
244+
return selectedVictims
245+
}
246+
}
247+
return selectedVictims
248+
}
249+
161250
func TestPreserveEnteredSearchMarksTerminalResult(t *testing.T) {
162251
result := terminalSearchResult(SearchResultDeadlineExhausted, false, false)
163252

0 commit comments

Comments
 (0)