Skip to content

Commit 4cbd3ea

Browse files
authored
fix(scheduler): clear completed session snapshots (#1726)
Signed-off-by: Erez Freiberger <enoodle@gmail.com>
1 parent 580d302 commit 4cbd3ea

3 files changed

Lines changed: 209 additions & 13 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
2323
- **Breaking:** The podgroup produced for JobSet is now produces as a single PodGroup per JobSet with a two-level SubGroup hierarchy (one parent SubGroup per `replicatedJob`, one leaf SubGroup per replica) regardless of `startupPolicyOrder`. The `kai.scheduler/batch-min-member` annotation on the JobSet now overrides the root `minSubGroup`; the same annotation on `replicatedJobs[].template.metadata.annotations` overrides the leaf `minMember` (defaulting to `template.spec.parallelism`). [#1617](https://github.com/kai-scheduler/KAI-Scheduler/pull/1617) [davidLif](https://github.com/davidLif)
2424

2525
### Fixed
26+
- Reduced scheduler heap retention after scheduling cycles by clearing completed session snapshots and callback references, and by releasing the node scoring pool without waiting for finalizers.
2627
- Fixed Helm chart prometheus RBAC always being installed when `prometheus.enabled` is false, and the `kai-prometheus` ClusterRoleBinding referencing the `prometheus` ServiceAccount in hardcoded `kai-scheduler` namespace instead of the Helm release namespace [#1684](https://github.com/kai-scheduler/KAI-Scheduler/pull/1684) [dttung2905](https://github.com/dttung2905)
2728
- Fixed post-delete cleanup hook hardcoding `kai-scheduler` namespace instead of Helm release namespace on `helm uninstall` [#1619](https://github.com/kai-scheduler/KAI-Scheduler/pull/1619) [dttung2905](https://github.com/dttung2905)
2829
- Improved solver performance in some large reclaim scenarios [#1627](https://github.com/kai-scheduler/KAI-Scheduler/pull/1627) [itsomri](https://github.com/itsomri)

pkg/scheduler/framework/session.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -368,13 +368,45 @@ func (ssn *Session) updatePodOnSession(pod *pod_info.PodInfo, status pod_status.
368368
}
369369

370370
func (ssn *Session) clear() {
371-
ssn.ClusterInfo.PodGroupInfos = nil
372-
ssn.ClusterInfo.Nodes = nil
371+
ssn.ClusterInfo = nil
373372
ssn.plugins = nil
374373
ssn.eventHandlers = nil
375-
ssn.TaskOrderFns = nil
376-
ssn.SubGroupOrderFns = nil
374+
ssn.GpuOrderFns = nil
375+
ssn.NodePreOrderFns = nil
376+
ssn.NodeOrderFns = nil
377377
ssn.JobOrderFns = nil
378+
ssn.SubGroupOrderFns = nil
379+
ssn.TaskOrderFns = nil
380+
ssn.QueueOrderFns = nil
381+
ssn.CanReclaimResourcesFns = nil
382+
ssn.ReclaimVictimFilterFns = nil
383+
ssn.PreemptVictimFilterFns = nil
384+
ssn.ReclaimScenarioValidatorFns = nil
385+
ssn.PreemptScenarioValidatorFns = nil
386+
ssn.OnJobSolutionStartFns = nil
387+
ssn.GetQueueAllocatedResourcesFns = nil
388+
ssn.GetQueueDeservedResourcesFns = nil
389+
ssn.GetQueueFairShareFns = nil
390+
ssn.IsNonPreemptibleJobOverQueueQuotaFns = nil
391+
ssn.IsJobOverCapacityFns = nil
392+
ssn.IsTaskAllocationOnNodeOverCapacityFns = nil
393+
ssn.SubsetNodesFns = nil
394+
ssn.PrePredicateFns = nil
395+
ssn.VictimInvariantPrePredicateFns = nil
396+
ssn.PredicateFns = nil
397+
ssn.BindRequestMutateFns = nil
398+
ssn.NumaPlacementFn = nil
399+
ssn.PreJobAllocationFns = nil
400+
ssn.Config = nil
401+
ssn.k8sResourceStateCache = sync.Map{}
402+
}
403+
404+
func (ssn *Session) releaseNodeScoringPool() {
405+
if ssn.nodeScoringPool != nil {
406+
ssn.nodeScoringPool.Release()
407+
ssn.nodeScoringPool = nil
408+
}
409+
ssn.scoringPoolWorkerCount = 0
378410
}
379411

380412
func (ssn *Session) InitNodeScoringPool() error {
@@ -385,12 +417,6 @@ func (ssn *Session) InitNodeScoringPool() error {
385417
}
386418
ssn.nodeScoringPool = pool
387419
ssn.scoringPoolWorkerCount = numWorkers
388-
runtime.SetFinalizer(ssn, func(s *Session) {
389-
if s.nodeScoringPool != nil {
390-
s.nodeScoringPool.Release()
391-
s.nodeScoringPool = nil
392-
}
393-
})
394420
return nil
395421
}
396422

@@ -414,6 +440,7 @@ func openSession(cache cache.Cache, sessionId string, schedulerParams conf.Sched
414440
log.InfraLogger.V(2).Infof("Taking cluster snapshot ...")
415441
snapshot, err := cache.Snapshot()
416442
if err != nil {
443+
ssn.releaseNodeScoringPool()
417444
return nil, err
418445
}
419446

@@ -436,9 +463,7 @@ func closeSession(ssn *Session) {
436463
}
437464
}
438465

439-
ssn.nodeScoringPool.Release()
440-
ssn.nodeScoringPool = nil
441-
ssn.scoringPoolWorkerCount = 0
466+
ssn.releaseNodeScoringPool()
442467
ssn.clear()
443468
stopCh := make(chan struct{})
444469
ssn.Cache.WaitForWorkers(stopCh)
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright 2026 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package framework
5+
6+
import (
7+
"runtime"
8+
"strconv"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
"go.uber.org/mock/gomock"
14+
15+
schedulingv1alpha2 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v1alpha2"
16+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api"
17+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/bindrequest_info"
18+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/api/common_info"
19+
scheduler_cache "github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/cache"
20+
"github.com/kai-scheduler/KAI-scheduler/pkg/scheduler/conf"
21+
)
22+
23+
const benchmarkBindRequestCount = 5000
24+
25+
func TestSessionClearDropsRetainedReferences(t *testing.T) {
26+
ssn := &Session{
27+
ClusterInfo: &api.ClusterInfo{
28+
BindRequests: bindrequest_info.BindRequestMap{
29+
bindrequest_info.NewKey("namespace", "pod"): &bindrequest_info.BindRequestInfo{},
30+
},
31+
BindRequestsForDeletedNodes: []*bindrequest_info.BindRequestInfo{{}},
32+
},
33+
Config: &conf.SchedulerConfiguration{},
34+
plugins: map[string]Plugin{"plugin": nil},
35+
eventHandlers: []*EventHandler{{}},
36+
TaskOrderFns: []common_info.CompareFn{nil},
37+
PrePredicateFns: []api.PrePredicateFn{nil},
38+
BindRequestMutateFns: []api.BindRequestMutateFn{nil},
39+
}
40+
ssn.k8sResourceStateCache.Store("resource", "state")
41+
42+
ssn.clear()
43+
44+
assert.Nil(t, ssn.ClusterInfo)
45+
assert.Nil(t, ssn.Config)
46+
assert.Nil(t, ssn.plugins)
47+
assert.Nil(t, ssn.eventHandlers)
48+
assert.Nil(t, ssn.TaskOrderFns)
49+
assert.Nil(t, ssn.PrePredicateFns)
50+
assert.Nil(t, ssn.BindRequestMutateFns)
51+
_, found := ssn.k8sResourceStateCache.Load("resource")
52+
assert.False(t, found)
53+
}
54+
55+
func TestCloseSessionReleasesSnapshotReferencesWhileSessionIsLive(t *testing.T) {
56+
finalized := make(chan struct{})
57+
cacheMock := scheduler_cache.NewMockCache(gomock.NewController(t))
58+
cacheMock.EXPECT().WaitForWorkers(gomock.Any()).Times(1)
59+
ssn := newSessionWithFinalizedBindRequest(t, cacheMock, finalized)
60+
61+
closeSession(ssn)
62+
63+
requireFinalized(t, finalized, ssn)
64+
}
65+
66+
func BenchmarkOpenCloseSessionWithLargeSnapshot(b *testing.B) {
67+
cacheMock := scheduler_cache.NewMockCache(gomock.NewController(b))
68+
cacheMock.EXPECT().Snapshot().AnyTimes().DoAndReturn(func() (*api.ClusterInfo, error) {
69+
return newClusterInfoWithBindRequests(benchmarkBindRequestCount), nil
70+
})
71+
cacheMock.EXPECT().WaitForWorkers(gomock.Any()).AnyTimes()
72+
73+
runtime.GC()
74+
before := heapAlloc()
75+
76+
b.ReportAllocs()
77+
b.ResetTimer()
78+
for i := 0; i < b.N; i++ {
79+
ssn, err := openSession(cacheMock, "benchmark", conf.SchedulerParams{}, nil)
80+
if err != nil {
81+
b.Fatal(err)
82+
}
83+
closeSession(ssn)
84+
}
85+
b.StopTimer()
86+
87+
runtime.GC()
88+
after := heapAlloc()
89+
retained := int64(after) - int64(before)
90+
if retained < 0 {
91+
retained = 0
92+
}
93+
b.ReportMetric(benchmarkBindRequestCount, "bind_requests/op")
94+
b.ReportMetric(float64(retained)/float64(b.N), "retained_after_1_gc_B/op")
95+
}
96+
97+
func newSessionWithFinalizedBindRequest(
98+
t *testing.T, cache scheduler_cache.Cache, finalized chan<- struct{},
99+
) *Session {
100+
t.Helper()
101+
102+
bindRequest := &schedulingv1alpha2.BindRequest{
103+
Spec: schedulingv1alpha2.BindRequestSpec{
104+
PodName: "pod",
105+
},
106+
}
107+
runtime.SetFinalizer(bindRequest, func(*schedulingv1alpha2.BindRequest) {
108+
close(finalized)
109+
})
110+
111+
ssn := &Session{
112+
Cache: cache,
113+
ClusterInfo: &api.ClusterInfo{
114+
BindRequests: bindrequest_info.BindRequestMap{
115+
bindrequest_info.NewKey("namespace", "pod"): bindrequest_info.NewBindRequestInfo(bindRequest),
116+
},
117+
BindRequestsForDeletedNodes: []*bindrequest_info.BindRequestInfo{
118+
bindrequest_info.NewBindRequestInfo(bindRequest),
119+
},
120+
},
121+
}
122+
if err := ssn.InitNodeScoringPool(); err != nil {
123+
t.Fatalf("failed to initialize node scoring pool: %v", err)
124+
}
125+
return ssn
126+
}
127+
128+
func requireFinalized(t *testing.T, finalized <-chan struct{}, keepAlive *Session) {
129+
t.Helper()
130+
131+
deadline := time.NewTimer(3 * time.Second)
132+
defer deadline.Stop()
133+
134+
for {
135+
runtime.GC()
136+
select {
137+
case <-finalized:
138+
runtime.KeepAlive(keepAlive)
139+
return
140+
case <-deadline.C:
141+
runtime.KeepAlive(keepAlive)
142+
t.Fatal("snapshot bind request was still reachable after closeSession")
143+
default:
144+
time.Sleep(10 * time.Millisecond)
145+
}
146+
}
147+
}
148+
149+
func heapAlloc() uint64 {
150+
var stats runtime.MemStats
151+
runtime.ReadMemStats(&stats)
152+
return stats.HeapAlloc
153+
}
154+
155+
func newClusterInfoWithBindRequests(count int) *api.ClusterInfo {
156+
clusterInfo := &api.ClusterInfo{
157+
BindRequests: make(bindrequest_info.BindRequestMap, count),
158+
}
159+
for i := 0; i < count; i++ {
160+
podName := "pod-" + strconv.Itoa(i)
161+
bindRequest := &schedulingv1alpha2.BindRequest{
162+
Spec: schedulingv1alpha2.BindRequestSpec{
163+
PodName: podName,
164+
},
165+
}
166+
clusterInfo.BindRequests[bindrequest_info.NewKey("namespace", podName)] =
167+
bindrequest_info.NewBindRequestInfo(bindRequest)
168+
}
169+
return clusterInfo
170+
}

0 commit comments

Comments
 (0)