Skip to content

Commit 09e6137

Browse files
authored
refactor(scheduler): patch pod labels concurrently (#147)
1 parent 1b5cc09 commit 09e6137

File tree

5 files changed

+63
-53
lines changed

5 files changed

+63
-53
lines changed

pkg/scheduler/actions/allocate/allocate_test.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,16 @@ import (
88
"testing"
99

1010
. "go.uber.org/mock/gomock"
11-
v1 "k8s.io/api/core/v1"
12-
"k8s.io/client-go/kubernetes"
1311
"k8s.io/utils/pointer"
1412

1513
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/actions/allocate"
1614
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/actions/integration_tests/integration_tests_utils"
1715
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
16+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1817
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_status"
18+
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/cache"
1919
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/conf"
2020
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/constants"
21-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/k8s_utils"
2221
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils"
2322
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils/jobs_fake"
2423
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/test_utils/nodes_fake"
@@ -1787,18 +1786,16 @@ func TestHandleElasticJobCommitFailure(t *testing.T) {
17871786
},
17881787
controller,
17891788
)
1790-
k8s_utils.Helpers = &failingK8sUtils{k8s_utils.Helpers}
1789+
ssn.Cache = &failingBindCache{Cache: ssn.Cache}
17911790

17921791
allocateAction := allocate.New()
17931792
allocateAction.Execute(ssn)
17941793
}
17951794

1796-
type failingK8sUtils struct {
1797-
k8s_utils.Interface
1795+
// failingBindCache is a test double that embeds a cache.Cache and overrides
// Bind so that every bind attempt fails. The test installs it as ssn.Cache to
// exercise the failure path of the allocate action.
type failingBindCache struct {
1796+
cache.Cache
17981797
}
17991798

1800-
func (f *failingK8sUtils) PatchPodAnnotationsAndLabelsInterface(
1801-
_ kubernetes.Interface, _ *v1.Pod, _, _ map[string]interface{},
1802-
) error {
1799+
// Bind ignores its arguments and always returns a fixed error, simulating a
// bind failure.
func (f *failingBindCache) Bind(podInfo *pod_info.PodInfo, hostname string) error {
18031800
return fmt.Errorf("create pod error")
18041801
}

pkg/scheduler/cache/cache.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,19 @@ func (sc *SchedulerCache) Bind(taskInfo *pod_info.PodInfo, hostname string) erro
224224
startTime := time.Now()
225225
defer metrics.UpdateTaskBindDuration(startTime)
226226

227-
err := sc.createBindRequest(taskInfo, hostname)
228-
return sc.StatusUpdater.Bound(taskInfo.Pod, hostname, err, sc.getNodPoolName())
227+
log.InfraLogger.V(3).Infof(
228+
"Creating bind request for task <%v/%v> to node <%v> gpuGroup: <%v>, requires: <%v> GPUs",
229+
taskInfo.Namespace, taskInfo.Name, hostname, taskInfo.GPUGroups, taskInfo.ResReq)
230+
if bindRequestError := sc.createBindRequest(taskInfo, hostname); bindRequestError != nil {
231+
return sc.StatusUpdater.Bound(taskInfo.Pod, hostname, bindRequestError, sc.getNodPoolName())
232+
}
233+
234+
labelsPatch := sc.nodePoolLabelsChange(taskInfo.Pod.Labels)
235+
if len(labelsPatch) > 0 {
236+
sc.StatusUpdater.PatchPodLabels(taskInfo.Pod, labelsPatch)
237+
}
238+
239+
return sc.StatusUpdater.Bound(taskInfo.Pod, hostname, nil, sc.getNodPoolName())
229240
}
230241

231242
// +kubebuilder:rbac:groups="scheduling.run.ai",resources=bindrequests,verbs=create;update;patch
@@ -276,6 +287,18 @@ func (sc *SchedulerCache) getNodPoolName() string {
276287
return "default"
277288
}
278289

290+
// nodePoolLabelsChange computes the label patch required for a pod to carry
// the node-pool label configured in sc.schedulingNodePoolParams.
//
// currentLabels is the pod's existing label set. The result is an empty map
// when no node-pool label key is configured, or when currentLabels already
// maps that key to the configured value; otherwise it contains the single
// key/value pair that should be set on the pod.
func (sc *SchedulerCache) nodePoolLabelsChange(currentLabels map[string]string) map[string]any {
291+
labels := map[string]any{}
292+
// No node-pool label key configured: there is nothing to patch.
if sc.schedulingNodePoolParams.NodePoolLabelKey == "" {
293+
return labels
294+
}
295+
// The pod already has the configured key with the configured value.
if value, found := currentLabels[sc.schedulingNodePoolParams.NodePoolLabelKey]; found && value == sc.schedulingNodePoolParams.NodePoolLabelValue {
296+
return labels
297+
}
298+
// Label is missing or holds a different value: request that it be set.
labels[sc.schedulingNodePoolParams.NodePoolLabelKey] = sc.schedulingNodePoolParams.NodePoolLabelValue
299+
return labels
300+
}
301+
279302
func (sc *SchedulerCache) String() string {
280303
str := "Cache:\n"
281304

pkg/scheduler/cache/status_updater/default_status_updater.go

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,33 @@ func (su *defaultStatusUpdater) Pipelined(pod *v1.Pod, message string) {
146146
su.recorder.Eventf(pod, v1.EventTypeNormal, "Pipelined", message)
147147
}
148148

149+
// PatchPodLabels builds a JSON patch of the form {"metadata":{"labels":...}}
// from labels and hands it to the update queue for the given pod.
//
// The method returns nothing: if the patch cannot be marshalled the error is
// logged and the update is dropped. Nothing is reported back to the caller
// about whether the patch was eventually applied.
// NOTE(review): the patch is presumably applied later by whatever consumes
// the update queue — confirm in pushToUpdateQueue.
func (su *defaultStatusUpdater) PatchPodLabels(pod *v1.Pod, labels map[string]any) {
150+
log.InfraLogger.V(6).Infof("Patching pod labels for %s/%s", pod.Namespace, pod.Name)
151+
152+
patchBytes, err := json.Marshal(map[string]any{
153+
"metadata": map[string]any{
154+
"labels": labels,
155+
},
156+
})
157+
158+
// Marshalling failure is only logged; the label update is skipped.
if err != nil {
159+
log.InfraLogger.Errorf("Failed to create patch for pod labels <%s/%s>: %v",
160+
pod.Namespace, pod.Name, err)
161+
return
162+
}
163+
164+
su.pushToUpdateQueue(
165+
&updatePayload{
166+
// The "-Labels" suffix gives label patches their own queue key,
// distinct from the pod's "-Status" key used by updatePodCondition.
key: su.keyForPayload(pod.Name, pod.Namespace, pod.UID) + "-Labels",
167+
objectType: podType,
168+
},
169+
&inflightUpdate{
170+
object: pod,
171+
patchData: patchBytes,
172+
},
173+
)
174+
}
175+
149176
func (su *defaultStatusUpdater) RecordJobStatusEvent(job *podgroup_info.PodGroupInfo) error {
150177
var err error
151178
var patchData []byte
@@ -247,7 +274,7 @@ func (su *defaultStatusUpdater) updatePodCondition(pod *v1.Pod, condition *v1.Po
247274

248275
su.pushToUpdateQueue(
249276
&updatePayload{
250-
key: su.keyForPayload(pod.Name, pod.Namespace, pod.UID),
277+
key: su.keyForPayload(pod.Name, pod.Namespace, pod.UID) + "-Status",
251278
objectType: podType,
252279
},
253280
&inflightUpdate{

pkg/scheduler/cache/status_updater/interface.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type Interface interface {
2020
Evicted(evictedPodGroup *enginev2alpha2.PodGroup, evictionMetadata eviction_info.EvictionMetadata, message string)
2121
Bound(pod *v1.Pod, hostname string, bindError error, nodePoolName string) error
2222
Pipelined(pod *v1.Pod, message string)
23+
PatchPodLabels(pod *v1.Pod, labels map[string]interface{})
2324
RecordJobStatusEvent(job *podgroup_info.PodGroupInfo) error
2425

2526
Run(stopCh <-chan struct{})

pkg/scheduler/framework/statement.go

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,11 @@ import (
99
"golang.org/x/exp/slices"
1010
"k8s.io/apimachinery/pkg/types"
1111

12-
commonconstants "github.com/NVIDIA/KAI-scheduler/pkg/common/constants"
1312
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/common_info"
1413
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/eviction_info"
1514
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/node_info"
1615
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_info"
1716
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/pod_status"
18-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/api/podgroup_info"
19-
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/k8s_utils"
2017
"github.com/NVIDIA/KAI-scheduler/pkg/scheduler/log"
2118
)
2219

@@ -345,9 +342,6 @@ func (s *Statement) commitAllocate(task *pod_info.PodInfo) error {
345342
}
346343
}()
347344

348-
log.InfraLogger.V(3).Infof(
349-
"Creating bind request for task <%v/%v> to node <%v> gpuGroup: <%v>, requires: <%v> GPUs",
350-
task.Namespace, task.Name, node.Name, task.GPUGroups, task.ResReq)
351345
if task.IsFractionAllocation() {
352346
for _, gpuGroup := range task.GPUGroups {
353347
if _, found := node.UsedSharedGPUsMemory[gpuGroup]; !found {
@@ -356,17 +350,11 @@ func (s *Statement) commitAllocate(task *pod_info.PodInfo) error {
356350
}
357351
}
358352

359-
podGroup, found := s.ssn.PodGroupInfos[task.Job]
360-
if !found {
361-
return fmt.Errorf("failed to find podGroup <%v> for pod <%v/%v>", task.Job, task.Namespace, task.Name)
353+
if err = s.ssn.BindPod(task); err != nil {
354+
log.InfraLogger.Errorf("Failed to bind task <%v/%v>. Error: %v",
355+
task.Namespace, task.Name, err)
362356
}
363-
labels := getNodePoolLabelsToPatchForPod(task, podGroup)
364357

365-
err = k8s_utils.Helpers.PatchPodAnnotationsAndLabelsInterface(
366-
s.ssn.Cache.KubeClient(), task.Pod, map[string]interface{}{}, labels)
367-
if err == nil {
368-
task.IsVirtualStatus = false
369-
}
370358
return err
371359
}
372360

@@ -547,13 +535,6 @@ func (s *Statement) Commit() error {
547535
s.clearOperations()
548536
return err
549537
}
550-
551-
taskInfo.Pod.Spec.NodeName = taskInfo.NodeName
552-
if err = s.ssn.BindPod(taskInfo); err != nil {
553-
log.InfraLogger.Errorf("Failed to bind task <%v/%v>. Error: %v",
554-
taskInfo.Namespace, taskInfo.Name, err)
555-
}
556-
taskInfo.Pod.Spec.NodeName = ""
557538
}
558539
}
559540

@@ -637,25 +618,6 @@ func (s *Statement) cleanupFailedAllocation(task *pod_info.PodInfo, node *node_i
637618
_ = s.unallocate(task, node.Name, false)
638619
}
639620

640-
// getNodePoolLabelsToPatchForPod returns the label patch that brings the
// pod's NodePoolNameLabel in line with the one on the job's PodGroup.
//
// The result is empty when neither object has the label or when both values
// compare equal. Otherwise it holds one entry for NodePoolNameLabel: the
// job's value when the job has the label, or nil when it does not.
func getNodePoolLabelsToPatchForPod(task *pod_info.PodInfo, job *podgroup_info.PodGroupInfo) map[string]interface{} {
641-
jobsNodePool, jobsNodePoolFound := job.PodGroup.Labels[commonconstants.NodePoolNameLabel]
642-
podsNodePool, podsNodePoolFound := task.Pod.Labels[commonconstants.NodePoolNameLabel]
643-
644-
// Nothing to patch: the label is absent on both, or the values already match.
if (!jobsNodePoolFound && !podsNodePoolFound) || (podsNodePool == jobsNodePool) {
645-
return map[string]interface{}{}
646-
}
647-
648-
labels := map[string]interface{}{}
649-
650-
if !jobsNodePoolFound {
651-
// The job has no node-pool label, so a nil value is emitted in the
// patch to delete the label from the pod.
652-
labels[commonconstants.NodePoolNameLabel] = nil
653-
} else {
654-
labels[commonconstants.NodePoolNameLabel] = jobsNodePool
655-
}
656-
return labels
657-
}
658-
659621
func (s *Statement) operationValid(i int) bool {
660622
for undoIndex, operation := range s.operations {
661623
if operation.Name() != undo {

0 commit comments

Comments
 (0)