Skip to content

Commit cd5bac1

Browse files
committed
feat(scheduler): drive solver from generator portfolio
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent 23e3426 commit cd5bac1

4 files changed

Lines changed: 593 additions & 53 deletions

File tree

pkg/scheduler/actions/common/solvers/job_solver.go

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"strings"
99
"time"
1010

11-
solverscenario "github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1211
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1312
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
1413
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
@@ -94,52 +93,97 @@ func (s *JobSolver) Solve(
9493
// describing why the scenario search stopped.
9594
func (s *JobSolver) SolveWithResult(
9695
ssn *framework.Session, pendingJob *podgroup_info.PodGroupInfo,
97-
) (bool, *framework.Statement, []string, *SearchResult) {
98-
state := solvingState{}
96+
) (solved bool, statement *framework.Statement, victimTaskNames []string, searchResult *SearchResult) {
9997
originalNumActiveTasks := pendingJob.GetNumActiveUsedTasks()
10098

10199
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
102100
n := len(tasksToAllocate)
103101
if n == 0 {
104-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
105-
terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
102+
return false, nil, nil, terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
106103
}
107104

108105
actionBudget := s.ensureActionBudget()
109106
jobBudget := actionBudget.BeginJob()
110107
if jobBudget.Remaining() <= 0 {
111-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
112-
terminalSearchResult(SearchResultNotAttempted, false, false)
108+
return false, nil, nil, terminalSearchResult(SearchResultNotAttempted, false, false)
113109
}
114110

111+
if s.generateVictimsQueue == nil {
112+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
113+
}
114+
registrations := applicableScenarioGeneratorRegistrations(ssn, s.actionType)
115+
if len(registrations) == 0 {
116+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
117+
}
118+
119+
enteredSearch := false
120+
var lastVictimTasks []*pod_info.PodInfo
121+
var lastResult *SearchResult
122+
for _, registration := range registrations {
123+
state := solvingState{}
124+
generatorBudget := jobBudget.BeginGenerator(registration.Name)
125+
result := s.solvePendingJobWithGenerator(
126+
ssn, &state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
127+
)
128+
enteredSearch = enteredSearch || searchResultEntered(result) || resultSolved(result)
129+
lastVictimTasks = state.recordedVictimsTasks
130+
lastResult = result
131+
132+
if resultSolved(result) {
133+
solution := result.solution
134+
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
135+
jobSolved := pendingJob.IsGangSatisfied()
136+
if originalNumActiveTasks >= numActiveTasks {
137+
jobSolved = false
138+
}
139+
140+
log.InfraLogger.V(4).Infof(
141+
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
142+
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
143+
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
144+
}
145+
146+
if shouldStopSearch(result) {
147+
preserveEnteredSearch(result, enteredSearch)
148+
return false, nil, calcVictimNames(lastVictimTasks), result
149+
}
150+
}
151+
152+
if lastResult == nil {
153+
lastResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
154+
}
155+
preserveEnteredSearch(lastResult, enteredSearch)
156+
return false, nil, calcVictimNames(lastVictimTasks), lastResult
157+
}
158+
159+
func (s *JobSolver) solvePendingJobWithGenerator(
160+
ssn *framework.Session,
161+
state *solvingState,
162+
pendingJob *podgroup_info.PodGroupInfo,
163+
tasksToAllocate []*pod_info.PodInfo,
164+
jobBudget *jobSearchBudget,
165+
registration framework.ScenarioGeneratorRegistration,
166+
generatorBudget *generatorSearchBudget,
167+
) *SearchResult {
168+
n := len(tasksToAllocate)
115169
enteredSearch := false
116-
maxSolvedK, searchResult := s.searchMaxSolvableK(ssn, &state, pendingJob, tasksToAllocate, jobBudget)
170+
maxSolvedK, searchResult := s.searchMaxSolvableK(
171+
ssn, state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
172+
)
117173
enteredSearch = searchResultEntered(searchResult) || maxSolvedK > 0
118174
if maxSolvedK == 0 {
119175
if searchResult == nil {
120-
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
176+
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), false)
121177
}
122178
preserveEnteredSearch(searchResult, enteredSearch)
123-
return false, nil, calcVictimNames(state.recordedVictimsTasks), searchResult
179+
return searchResult
124180
}
125181

126-
result := s.probeAtK(ssn, &state, pendingJob, tasksToAllocate, n, jobBudget)
182+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, n, jobBudget, registration, generatorBudget)
127183
if !resultSolved(result) {
128184
preserveEnteredSearch(result, enteredSearch)
129-
return false, nil, calcVictimNames(state.recordedVictimsTasks), result
130185
}
131-
132-
solution := result.solution
133-
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
134-
jobSolved := pendingJob.IsGangSatisfied()
135-
if originalNumActiveTasks >= numActiveTasks {
136-
jobSolved = false
137-
}
138-
139-
log.InfraLogger.V(4).Infof(
140-
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
141-
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
142-
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
186+
return result
143187
}
144188

145189
// searchMaxSolvableK returns the largest k in [0, n] for which a probe at k succeeds.
@@ -153,14 +197,18 @@ func (s *JobSolver) searchMaxSolvableK(
153197
pendingJob *podgroup_info.PodGroupInfo,
154198
tasksToAllocate []*pod_info.PodInfo,
155199
jobBudget *jobSearchBudget,
200+
registration framework.ScenarioGeneratorRegistration,
201+
generatorBudget *generatorSearchBudget,
156202
) (int, *SearchResult) {
157203
n := len(tasksToAllocate)
158204
if n == 0 {
159205
return 0, nil
160206
}
161207

162208
return searchMaxSolvableK(n, func(k int) *SearchResult {
163-
return s.tryProbeAndDiscard(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
209+
return s.tryProbeAndDiscard(
210+
ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget,
211+
)
164212
})
165213
}
166214

@@ -223,8 +271,10 @@ func (s *JobSolver) tryProbeAndDiscard(
223271
tasksToAllocate []*pod_info.PodInfo,
224272
k int,
225273
jobBudget *jobSearchBudget,
274+
registration framework.ScenarioGeneratorRegistration,
275+
generatorBudget *generatorSearchBudget,
226276
) *SearchResult {
227-
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
277+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget)
228278
if !resultSolved(result) {
229279
log.InfraLogger.V(5).Infof("No solution found for %d tasks out of %d tasks to allocate for %s",
230280
k, len(tasksToAllocate), pendingJob.Name)
@@ -249,15 +299,18 @@ func (s *JobSolver) probeAtK(
249299
tasksToAllocate []*pod_info.PodInfo,
250300
k int,
251301
jobBudget *jobSearchBudget,
302+
registration framework.ScenarioGeneratorRegistration,
303+
generatorBudget *generatorSearchBudget,
252304
) *SearchResult {
253305
pendingTasks := tasksToAllocate[:k]
254306
partialPendingJob := getPartialJobRepresentative(pendingJob, pendingTasks)
255-
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget)
307+
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget, registration, generatorBudget, k)
256308
}
257309

258310
func (s *JobSolver) solvePartialJob(
259311
ssn *framework.Session, state *solvingState, partialPendingJob *podgroup_info.PodGroupInfo,
260-
jobBudget *jobSearchBudget,
312+
jobBudget *jobSearchBudget, registration framework.ScenarioGeneratorRegistration,
313+
generatorBudget *generatorSearchBudget, probeK int,
261314
) *SearchResult {
262315
if jobBudget == nil {
263316
jobBudget = s.ensureActionBudget().BeginJob()
@@ -272,40 +325,33 @@ func (s *JobSolver) solvePartialJob(
272325
feasibleNodeMap[task.NodeName] = node
273326
}
274327

275-
if s.generateVictimsQueue == nil {
276-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
277-
}
278-
victimsQueue := s.generateVictimsQueue()
279-
if victimsQueue == nil {
280-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
328+
solveCtx := &SolveContext{
329+
Session: ssn,
330+
ActionType: s.actionType,
331+
PartialPendingJob: partialPendingJob,
332+
RecordedVictimsJobs: state.recordedVictimsJobs,
333+
RecordedVictimsTasks: state.recordedVictimsTasks,
334+
GenerateVictimsQueue: s.generateVictimsQueue,
335+
FeasibleNodes: feasibleNodeMap,
336+
ProbeK: probeK,
281337
}
338+
portfolio := newSingleGeneratorScenarioPortfolio(solveCtx, jobBudget, registration, generatorBudget)
282339

283-
scenarioBuilder := NewPodAccumulatedScenarioBuilder(
284-
ssn, partialPendingJob, state.recordedVictimsJobs, victimsQueue, feasibleNodeMap)
285-
286-
enteredSearch := false
287-
firstScenario := true
288340
for {
289341
if jobBudget.Remaining() <= 0 {
290-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
291-
}
292-
var scenarioToSolve *solverscenario.ByNodeScenario
293-
if firstScenario {
294-
scenarioToSolve = scenarioBuilder.GetValidScenario()
295-
firstScenario = false
296-
} else {
297-
scenarioToSolve = scenarioBuilder.GetNextScenario()
342+
return terminalSearchResult(
343+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
344+
)
298345
}
346+
scenarioToSolve := portfolio.Next()
299347
if jobBudget.Remaining() <= 0 {
300-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
348+
return terminalSearchResult(
349+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
350+
)
301351
}
302352
if scenarioToSolve == nil {
303-
if jobBudget.Remaining() <= 0 {
304-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
305-
}
306353
break
307354
}
308-
enteredSearch = true
309355
scenarioSolver := newByPodSolver(feasibleNodeMap, s.solutionValidator, ssn.AllowConsolidatingReclaim(),
310356
s.actionType)
311357

@@ -318,7 +364,7 @@ func (s *JobSolver) solvePartialJob(
318364
}
319365
}
320366

321-
return terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
367+
return terminalSearchResult(portfolio.StopReason(), jobBudget.ReducedBudget(), portfolio.enteredSearch)
322368
}
323369

324370
func searchResultEntered(result *SearchResult) bool {

pkg/scheduler/actions/common/solvers/job_solver_result_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package solvers
55

66
import (
7+
"fmt"
78
"testing"
89
"time"
910

@@ -13,10 +14,12 @@ import (
1314

1415
kaiv1 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/kai/v1"
1516
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1618
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1719
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
1820
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
1921
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
22+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
2023
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
2124
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/queue_info"
2225
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
@@ -116,6 +119,54 @@ func TestSolveWithResultUsesMinJobBudgetAfterActionBudgetExpired(t *testing.T) {
116119
require.False(t, result.EnteredSearch())
117120
}
118121

122+
func TestSolveWithResultRunsCompletePartialSearchForOneGeneratorBeforeNext(t *testing.T) {
123+
ssn := newGeneratorTestSession(t, map[string]int{
124+
"node-1": 1,
125+
"node-2": 1,
126+
"node-3": 1,
127+
})
128+
require.NoError(t, ssn.InitNodeScoringPool())
129+
pendingJob := addGeneratorTestPendingJob(t, ssn, 3, 10, "team-pending")
130+
setGeneratorTestMinAvailable(pendingJob, 3)
131+
victimJob, victimTasks := addGeneratorTestJob(t, ssn, 3, 20, "team-victim", "node-1", "node-2", "node-3")
132+
factoryCalls := []string{}
133+
134+
ssn.AddScenarioGenerator("first", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
135+
solveCtx := ctx.(*SolveContext)
136+
factoryCalls = append(factoryCalls, fmt.Sprintf("first:%d", solveCtx.ProbeK))
137+
return &portfolioTestGenerator{name: "first"}
138+
}, framework.Reclaim)
139+
ssn.AddScenarioGenerator("second", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
140+
solveCtx := ctx.(*SolveContext)
141+
factoryCalls = append(factoryCalls, fmt.Sprintf("second:%d", solveCtx.ProbeK))
142+
pendingTasks := podgroup_info.GetTasksToAllocate(
143+
solveCtx.PartialPendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false,
144+
)
145+
sn := scenario.NewByNodeScenario(
146+
ssn, solveCtx.PartialPendingJob, pendingTasks,
147+
unrecordedVictimsForProbe(victimTasks, solveCtx.RecordedVictimsTasks, solveCtx.ProbeK),
148+
solveCtx.RecordedVictimsJobs,
149+
)
150+
return &portfolioTestGenerator{name: "second", scenarios: []api.ScenarioInfo{sn}}
151+
}, framework.Reclaim)
152+
solver := NewJobsSolver(
153+
jobSolverResultTestFeasibleNodes(ssn),
154+
nil,
155+
generatorTestVictimsQueueFactory(ssn, victimJob),
156+
framework.Reclaim,
157+
nil,
158+
)
159+
160+
solved, statement, _, result := solver.SolveWithResult(ssn, pendingJob)
161+
if statement != nil {
162+
defer statement.Discard()
163+
}
164+
165+
require.True(t, solved)
166+
require.Equal(t, SearchResultSolved, result.Reason())
167+
require.Equal(t, []string{"first:1", "second:1", "second:2", "second:3", "second:3"}, factoryCalls)
168+
}
169+
119170
func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
120171
clock := &fakeClock{now: time.Unix(0, 0)}
121172
actionBudget, err := newActionSearchBudgetWithClock(
@@ -130,6 +181,7 @@ func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
130181
)
131182
require.NoError(t, err)
132183
ssn, pendingJob := newJobSolverResultTestSession(t, 1)
184+
ssn.AddScenarioGenerator("deadline-test", NewMultiNodeGangGenerator, framework.Reclaim)
133185
solver := NewJobsSolver(
134186
nil,
135187
nil,
@@ -165,6 +217,40 @@ func TestSearchMaxSolvableKPreservesEnteredSearchAfterTerminalPartialProbe(t *te
165217
require.True(t, result.EnteredSearch())
166218
}
167219

220+
func jobSolverResultTestFeasibleNodes(ssn *framework.Session) []*node_info.NodeInfo {
221+
nodes := make([]*node_info.NodeInfo, 0, len(ssn.ClusterInfo.Nodes))
222+
for _, node := range ssn.ClusterInfo.Nodes {
223+
nodes = append(nodes, node)
224+
}
225+
return nodes
226+
}
227+
228+
func unrecordedVictimsForProbe(
229+
victimTasks []*pod_info.PodInfo, recordedVictims []*pod_info.PodInfo, probeK int,
230+
) []*pod_info.PodInfo {
231+
recordedByUID := map[common_info.PodID]struct{}{}
232+
for _, task := range recordedVictims {
233+
recordedByUID[task.UID] = struct{}{}
234+
}
235+
236+
neededVictims := probeK - len(recordedVictims)
237+
if neededVictims <= 0 {
238+
return nil
239+
}
240+
241+
selectedVictims := make([]*pod_info.PodInfo, 0, neededVictims)
242+
for _, task := range victimTasks {
243+
if _, alreadyRecorded := recordedByUID[task.UID]; alreadyRecorded {
244+
continue
245+
}
246+
selectedVictims = append(selectedVictims, task)
247+
if len(selectedVictims) == neededVictims {
248+
return selectedVictims
249+
}
250+
}
251+
return selectedVictims
252+
}
253+
168254
func TestPreserveEnteredSearchMarksTerminalResult(t *testing.T) {
169255
result := terminalSearchResult(SearchResultDeadlineExhausted, false, false)
170256

0 commit comments

Comments
 (0)