Skip to content

Commit 34a39b4

Browse files
committed
feat(scheduler): drive solver from generator portfolio
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent eb3e25c commit 34a39b4

4 files changed

Lines changed: 593 additions & 53 deletions

File tree

pkg/scheduler/actions/common/solvers/job_solver.go

Lines changed: 99 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"strings"
99
"time"
1010

11-
solverscenario "github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1211
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1312
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
1413
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
@@ -95,52 +94,97 @@ func (s *JobSolver) Solve(
9594
// describing why the scenario search stopped.
9695
func (s *JobSolver) SolveWithResult(
9796
ssn *framework.Session, pendingJob *podgroup_info.PodGroupInfo,
98-
) (bool, *framework.Statement, []string, *SearchResult) {
99-
state := solvingState{}
97+
) (solved bool, statement *framework.Statement, victimTaskNames []string, searchResult *SearchResult) {
10098
originalNumActiveTasks := pendingJob.GetNumActiveUsedTasks()
10199

102100
tasksToAllocate := podgroup_info.GetTasksToAllocate(pendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false)
103101
n := len(tasksToAllocate)
104102
if n == 0 {
105-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
106-
terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
103+
return false, nil, nil, terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
107104
}
108105

109106
actionBudget := s.ensureActionBudget()
110107
jobBudget := actionBudget.BeginJob()
111108
if actionBudget.Exhausted() {
112-
return false, nil, calcVictimNames(state.recordedVictimsTasks),
113-
terminalSearchResult(SearchResultNotAttempted, false, false)
109+
return false, nil, nil, terminalSearchResult(SearchResultNotAttempted, false, false)
114110
}
115111

112+
if s.generateVictimsQueue == nil {
113+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
114+
}
115+
registrations := applicableScenarioGeneratorRegistrations(ssn, s.actionType)
116+
if len(registrations) == 0 {
117+
return false, nil, nil, terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
118+
}
119+
120+
enteredSearch := false
121+
var lastVictimTasks []*pod_info.PodInfo
122+
var lastResult *SearchResult
123+
for _, registration := range registrations {
124+
state := solvingState{}
125+
generatorBudget := jobBudget.BeginGenerator(registration.Name)
126+
result := s.solvePendingJobWithGenerator(
127+
ssn, &state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
128+
)
129+
enteredSearch = enteredSearch || searchResultEntered(result) || resultSolved(result)
130+
lastVictimTasks = state.recordedVictimsTasks
131+
lastResult = result
132+
133+
if resultSolved(result) {
134+
solution := result.solution
135+
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
136+
jobSolved := pendingJob.IsGangSatisfied()
137+
if originalNumActiveTasks >= numActiveTasks {
138+
jobSolved = false
139+
}
140+
141+
log.InfraLogger.V(4).Infof(
142+
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
143+
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
144+
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
145+
}
146+
147+
if shouldStopSearch(result) {
148+
preserveEnteredSearch(result, enteredSearch)
149+
return false, nil, calcVictimNames(lastVictimTasks), result
150+
}
151+
}
152+
153+
if lastResult == nil {
154+
lastResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
155+
}
156+
preserveEnteredSearch(lastResult, enteredSearch)
157+
return false, nil, calcVictimNames(lastVictimTasks), lastResult
158+
}
159+
160+
func (s *JobSolver) solvePendingJobWithGenerator(
161+
ssn *framework.Session,
162+
state *solvingState,
163+
pendingJob *podgroup_info.PodGroupInfo,
164+
tasksToAllocate []*pod_info.PodInfo,
165+
jobBudget *jobSearchBudget,
166+
registration framework.ScenarioGeneratorRegistration,
167+
generatorBudget *generatorSearchBudget,
168+
) *SearchResult {
169+
n := len(tasksToAllocate)
116170
enteredSearch := false
117-
maxSolvedK, searchResult := s.searchMaxSolvableK(ssn, &state, pendingJob, tasksToAllocate, jobBudget)
171+
maxSolvedK, searchResult := s.searchMaxSolvableK(
172+
ssn, state, pendingJob, tasksToAllocate, jobBudget, registration, generatorBudget,
173+
)
118174
enteredSearch = searchResultEntered(searchResult) || maxSolvedK > 0
119175
if maxSolvedK == 0 {
120176
if searchResult == nil {
121-
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, false, false)
177+
searchResult = terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), false)
122178
}
123179
preserveEnteredSearch(searchResult, enteredSearch)
124-
return false, nil, calcVictimNames(state.recordedVictimsTasks), searchResult
180+
return searchResult
125181
}
126182

127-
result := s.probeAtK(ssn, &state, pendingJob, tasksToAllocate, n, jobBudget)
183+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, n, jobBudget, registration, generatorBudget)
128184
if !resultSolved(result) {
129185
preserveEnteredSearch(result, enteredSearch)
130-
return false, nil, calcVictimNames(state.recordedVictimsTasks), result
131186
}
132-
133-
solution := result.solution
134-
numActiveTasks := pendingJob.GetNumActiveUsedTasks()
135-
jobSolved := pendingJob.IsGangSatisfied()
136-
if originalNumActiveTasks >= numActiveTasks {
137-
jobSolved = false
138-
}
139-
140-
log.InfraLogger.V(4).Infof(
141-
"Scenario solved for %d tasks to allocate for %s. Victims: %s",
142-
n, pendingJob.Name, victimPrintingStruct{solution.victimsTasks})
143-
return jobSolved, solution.statement, calcVictimNames(solution.victimsTasks), result
187+
return result
144188
}
145189

146190
// searchMaxSolvableK returns the largest k in [0, n] for which a probe at k succeeds.
@@ -154,14 +198,18 @@ func (s *JobSolver) searchMaxSolvableK(
154198
pendingJob *podgroup_info.PodGroupInfo,
155199
tasksToAllocate []*pod_info.PodInfo,
156200
jobBudget *jobSearchBudget,
201+
registration framework.ScenarioGeneratorRegistration,
202+
generatorBudget *generatorSearchBudget,
157203
) (int, *SearchResult) {
158204
n := len(tasksToAllocate)
159205
if n == 0 {
160206
return 0, nil
161207
}
162208

163209
return searchMaxSolvableK(n, func(k int) *SearchResult {
164-
return s.tryProbeAndDiscard(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
210+
return s.tryProbeAndDiscard(
211+
ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget,
212+
)
165213
})
166214
}
167215

@@ -224,8 +272,10 @@ func (s *JobSolver) tryProbeAndDiscard(
224272
tasksToAllocate []*pod_info.PodInfo,
225273
k int,
226274
jobBudget *jobSearchBudget,
275+
registration framework.ScenarioGeneratorRegistration,
276+
generatorBudget *generatorSearchBudget,
227277
) *SearchResult {
228-
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget)
278+
result := s.probeAtK(ssn, state, pendingJob, tasksToAllocate, k, jobBudget, registration, generatorBudget)
229279
if !resultSolved(result) {
230280
log.InfraLogger.V(5).Infof("No solution found for %d tasks out of %d tasks to allocate for %s",
231281
k, len(tasksToAllocate), pendingJob.Name)
@@ -250,15 +300,18 @@ func (s *JobSolver) probeAtK(
250300
tasksToAllocate []*pod_info.PodInfo,
251301
k int,
252302
jobBudget *jobSearchBudget,
303+
registration framework.ScenarioGeneratorRegistration,
304+
generatorBudget *generatorSearchBudget,
253305
) *SearchResult {
254306
pendingTasks := tasksToAllocate[:k]
255307
partialPendingJob := getPartialJobRepresentative(pendingJob, pendingTasks)
256-
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget)
308+
return s.solvePartialJob(ssn, state, partialPendingJob, jobBudget, registration, generatorBudget, k)
257309
}
258310

259311
func (s *JobSolver) solvePartialJob(
260312
ssn *framework.Session, state *solvingState, partialPendingJob *podgroup_info.PodGroupInfo,
261-
jobBudget *jobSearchBudget,
313+
jobBudget *jobSearchBudget, registration framework.ScenarioGeneratorRegistration,
314+
generatorBudget *generatorSearchBudget, probeK int,
262315
) *SearchResult {
263316
actionBudget := s.ensureActionBudget()
264317
if jobBudget == nil {
@@ -274,40 +327,33 @@ func (s *JobSolver) solvePartialJob(
274327
feasibleNodeMap[task.NodeName] = node
275328
}
276329

277-
if s.generateVictimsQueue == nil {
278-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
279-
}
280-
victimsQueue := s.generateVictimsQueue()
281-
if victimsQueue == nil {
282-
return terminalSearchResult(SearchResultNoGenerator, jobBudget.ReducedBudget(), false)
330+
solveCtx := &SolveContext{
331+
Session: ssn,
332+
ActionType: s.actionType,
333+
PartialPendingJob: partialPendingJob,
334+
RecordedVictimsJobs: state.recordedVictimsJobs,
335+
RecordedVictimsTasks: state.recordedVictimsTasks,
336+
GenerateVictimsQueue: s.generateVictimsQueue,
337+
FeasibleNodes: feasibleNodeMap,
338+
ProbeK: probeK,
283339
}
340+
portfolio := newSingleGeneratorScenarioPortfolio(solveCtx, jobBudget, registration, generatorBudget)
284341

285-
scenarioBuilder := NewPodAccumulatedScenarioBuilder(
286-
ssn, partialPendingJob, state.recordedVictimsJobs, victimsQueue, feasibleNodeMap)
287-
288-
enteredSearch := false
289-
firstScenario := true
290342
for {
291343
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
292-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
293-
}
294-
var scenarioToSolve *solverscenario.ByNodeScenario
295-
if firstScenario {
296-
scenarioToSolve = scenarioBuilder.GetValidScenario()
297-
firstScenario = false
298-
} else {
299-
scenarioToSolve = scenarioBuilder.GetNextScenario()
344+
return terminalSearchResult(
345+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
346+
)
300347
}
348+
scenarioToSolve := portfolio.Next()
301349
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
302-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
350+
return terminalSearchResult(
351+
SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), portfolio.enteredSearch,
352+
)
303353
}
304354
if scenarioToSolve == nil {
305-
if actionBudget.Exhausted() || jobBudget.Remaining() <= 0 {
306-
return terminalSearchResult(SearchResultDeadlineExhausted, jobBudget.ReducedBudget(), enteredSearch)
307-
}
308355
break
309356
}
310-
enteredSearch = true
311357
scenarioSolver := newByPodSolver(feasibleNodeMap, s.solutionValidator, ssn.AllowConsolidatingReclaim(),
312358
s.actionType)
313359

@@ -320,7 +366,7 @@ func (s *JobSolver) solvePartialJob(
320366
}
321367
}
322368

323-
return terminalSearchResult(SearchResultGeneratorsExhausted, jobBudget.ReducedBudget(), enteredSearch)
369+
return terminalSearchResult(portfolio.StopReason(), jobBudget.ReducedBudget(), portfolio.enteredSearch)
324370
}
325371

326372
func searchResultEntered(result *SearchResult) bool {

pkg/scheduler/actions/common/solvers/job_solver_result_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package solvers
55

66
import (
7+
"fmt"
78
"testing"
89
"time"
910

@@ -13,10 +14,12 @@ import (
1314

1415
kaiv1 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/kai/v1"
1516
"github.com/kai-scheduler/KAI-scheduler/pkg/common/constants"
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/common/solvers/scenario"
1618
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/actions/utils"
1719
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
1820
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
1921
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/node_info"
22+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/pod_info"
2023
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/podgroup_info"
2124
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/queue_info"
2225
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/framework"
@@ -88,6 +91,54 @@ func TestSolveWithResultReturnsNoGeneratorWhenGeneratorReturnsNil(t *testing.T)
8891
require.False(t, result.EnteredSearch())
8992
}
9093

94+
func TestSolveWithResultRunsCompletePartialSearchForOneGeneratorBeforeNext(t *testing.T) {
95+
ssn := newGeneratorTestSession(t, map[string]int{
96+
"node-1": 1,
97+
"node-2": 1,
98+
"node-3": 1,
99+
})
100+
require.NoError(t, ssn.InitNodeScoringPool())
101+
pendingJob := addGeneratorTestPendingJob(t, ssn, 3, 10, "team-pending")
102+
setGeneratorTestMinAvailable(pendingJob, 3)
103+
victimJob, victimTasks := addGeneratorTestJob(t, ssn, 3, 20, "team-victim", "node-1", "node-2", "node-3")
104+
factoryCalls := []string{}
105+
106+
ssn.AddScenarioGenerator("first", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
107+
solveCtx := ctx.(*SolveContext)
108+
factoryCalls = append(factoryCalls, fmt.Sprintf("first:%d", solveCtx.ProbeK))
109+
return &portfolioTestGenerator{name: "first"}
110+
}, framework.Reclaim)
111+
ssn.AddScenarioGenerator("second", func(ctx framework.ScenarioGeneratorContext) framework.ScenarioGenerator {
112+
solveCtx := ctx.(*SolveContext)
113+
factoryCalls = append(factoryCalls, fmt.Sprintf("second:%d", solveCtx.ProbeK))
114+
pendingTasks := podgroup_info.GetTasksToAllocate(
115+
solveCtx.PartialPendingJob, ssn.SubGroupOrderFn, ssn.TaskOrderFn, false,
116+
)
117+
sn := scenario.NewByNodeScenario(
118+
ssn, solveCtx.PartialPendingJob, pendingTasks,
119+
unrecordedVictimsForProbe(victimTasks, solveCtx.RecordedVictimsTasks, solveCtx.ProbeK),
120+
solveCtx.RecordedVictimsJobs,
121+
)
122+
return &portfolioTestGenerator{name: "second", scenarios: []api.ScenarioInfo{sn}}
123+
}, framework.Reclaim)
124+
solver := NewJobsSolver(
125+
jobSolverResultTestFeasibleNodes(ssn),
126+
nil,
127+
generatorTestVictimsQueueFactory(ssn, victimJob),
128+
framework.Reclaim,
129+
nil,
130+
)
131+
132+
solved, statement, _, result := solver.SolveWithResult(ssn, pendingJob)
133+
if statement != nil {
134+
defer statement.Discard()
135+
}
136+
137+
require.True(t, solved)
138+
require.Equal(t, SearchResultSolved, result.Reason())
139+
require.Equal(t, []string{"first:1", "second:1", "second:2", "second:3", "second:3"}, factoryCalls)
140+
}
141+
91142
func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
92143
clock := &fakeClock{now: time.Unix(0, 0)}
93144
actionBudget, err := newActionSearchBudgetWithClock(
@@ -102,6 +153,7 @@ func TestSolveWithResultReportsDeadlineBeforeScenarioSimulation(t *testing.T) {
102153
)
103154
require.NoError(t, err)
104155
ssn, pendingJob := newJobSolverResultTestSession(t, 1)
156+
ssn.AddScenarioGenerator("deadline-test", NewMultiNodeGangGenerator, framework.Reclaim)
105157
solver := NewJobsSolver(
106158
nil,
107159
nil,
@@ -137,6 +189,40 @@ func TestSearchMaxSolvableKPreservesEnteredSearchAfterTerminalPartialProbe(t *te
137189
require.True(t, result.EnteredSearch())
138190
}
139191

192+
func jobSolverResultTestFeasibleNodes(ssn *framework.Session) []*node_info.NodeInfo {
193+
nodes := make([]*node_info.NodeInfo, 0, len(ssn.ClusterInfo.Nodes))
194+
for _, node := range ssn.ClusterInfo.Nodes {
195+
nodes = append(nodes, node)
196+
}
197+
return nodes
198+
}
199+
200+
func unrecordedVictimsForProbe(
201+
victimTasks []*pod_info.PodInfo, recordedVictims []*pod_info.PodInfo, probeK int,
202+
) []*pod_info.PodInfo {
203+
recordedByUID := map[common_info.PodID]struct{}{}
204+
for _, task := range recordedVictims {
205+
recordedByUID[task.UID] = struct{}{}
206+
}
207+
208+
neededVictims := probeK - len(recordedVictims)
209+
if neededVictims <= 0 {
210+
return nil
211+
}
212+
213+
selectedVictims := make([]*pod_info.PodInfo, 0, neededVictims)
214+
for _, task := range victimTasks {
215+
if _, alreadyRecorded := recordedByUID[task.UID]; alreadyRecorded {
216+
continue
217+
}
218+
selectedVictims = append(selectedVictims, task)
219+
if len(selectedVictims) == neededVictims {
220+
return selectedVictims
221+
}
222+
}
223+
return selectedVictims
224+
}
225+
140226
func TestPreserveEnteredSearchMarksTerminalResult(t *testing.T) {
141227
result := terminalSearchResult(SearchResultDeadlineExhausted, false, false)
142228

0 commit comments

Comments
 (0)