Skip to content

Commit c536d88

Browse files
committed
Added rd.CreateDistributedBatchJob and options
Signed-off-by: itsomri <omric@nvidia.com>
1 parent 08625cf commit c536d88

4 files changed

Lines changed: 208 additions & 107 deletions

File tree

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright 2026 NVIDIA CORPORATION
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package rd
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"maps"
10+
"time"
11+
12+
batchv1 "k8s.io/api/batch/v1"
13+
v1 "k8s.io/api/core/v1"
14+
"k8s.io/apimachinery/pkg/api/errors"
15+
"k8s.io/apimachinery/pkg/types"
16+
"k8s.io/apimachinery/pkg/util/wait"
17+
"k8s.io/utils/ptr"
18+
runtimeClient "sigs.k8s.io/controller-runtime/pkg/client"
19+
20+
v2 "github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2"
21+
"github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
22+
pgconstants "github.com/kai-scheduler/KAI-scheduler/pkg/podgrouper/podgrouper/plugins/constants"
23+
)
24+
25+
// JobNameLabel is the label the k8s Job controller sets on every pod it creates.
const JobNameLabel = "batch.kubernetes.io/job-name"

// Polling parameters shared by waitForPodGroup and waitForJobPods.
const (
	// podGroupFetchTimeout is the upper bound on a single wait loop.
	podGroupFetchTimeout = 30 * time.Second
	// podGroupFetchPoll is the delay between two consecutive API queries.
	podGroupFetchPoll = 250 * time.Millisecond
)
32+
33+
// DistributedBatchJobOptions configures CreateDistributedBatchJob. Every field is optional
34+
// — pass DistributedBatchJobOptions{} to get a single-pod gang Job with no resource requests.
35+
type DistributedBatchJobOptions struct {
36+
// Parallelism is the number of pods the Job spawns. nil means 1.
37+
Parallelism *int32
38+
// MinMember is the PodGroup MinAvailable. nil means Parallelism (gang).
39+
// Gang: MinMember == Parallelism
40+
// Elastic: 1 <= MinMember < Parallelism
41+
MinMember *int32
42+
// Resources applied to each pod. Zero value means no requests/limits.
43+
Resources v1.ResourceRequirements
44+
// NamePrefix is prepended to the generated Job name.
45+
NamePrefix string
46+
// TopologyConstraint is propagated to the auto-created PodGroup via annotations.
47+
TopologyConstraint *v2alpha2.TopologyConstraint
48+
// PriorityClassName is set on the pod template; the podgrouper reads it onto the PodGroup.
49+
PriorityClassName string
50+
// Preemptibility is set as a Job label; the podgrouper reads it onto the PodGroup.
51+
Preemptibility v2alpha2.Preemptibility
52+
// ExtraLabels are merged into pod template labels (e.g. for test filtering).
53+
ExtraLabels map[string]string
54+
// PodSpecMutator is applied to the pod template spec after defaults are set. Scale
55+
// tests use this to inject KWOK tolerations/affinity without importing scale into rd.
56+
PodSpecMutator func(*v1.PodSpec)
57+
}
58+
59+
// CreateDistributedBatchJob submits a batch Job annotated with kai.scheduler/batch-min-member
60+
// so the podgrouper produces a single PodGroup with MinAvailable=opts.MinMember. Returns the
61+
// Job, the PodGroup (once the podgrouper has created it), and the pods the Job spawned.
62+
func CreateDistributedBatchJob(
63+
ctx context.Context,
64+
kubeClient runtimeClient.Client,
65+
jobQueue *v2.Queue,
66+
opts DistributedBatchJobOptions,
67+
) (*batchv1.Job, *v2alpha2.PodGroup, []*v1.Pod, error) {
68+
parallelism := ptr.Deref(opts.Parallelism, 1)
69+
minMember := ptr.Deref(opts.MinMember, parallelism)
70+
71+
job := buildDistributedBatchJob(jobQueue, opts, parallelism, minMember)
72+
if err := kubeClient.Create(ctx, job); err != nil {
73+
return nil, nil, nil, fmt.Errorf("create Job: %w", err)
74+
}
75+
76+
podGroup, err := waitForPodGroup(ctx, kubeClient, job)
77+
if err != nil {
78+
return job, nil, nil, err
79+
}
80+
81+
pods, err := waitForJobPods(ctx, kubeClient, job, parallelism)
82+
if err != nil {
83+
return job, podGroup, nil, err
84+
}
85+
86+
return job, podGroup, pods, nil
87+
}
88+
89+
func buildDistributedBatchJob(
90+
jobQueue *v2.Queue, opts DistributedBatchJobOptions, parallelism, minMember int32,
91+
) *batchv1.Job {
92+
job := CreateBatchJobObject(jobQueue, opts.Resources)
93+
job.Name = opts.NamePrefix + job.Name
94+
job.Spec.Parallelism = ptr.To(parallelism)
95+
job.Spec.Completions = ptr.To(parallelism)
96+
97+
if job.Annotations == nil {
98+
job.Annotations = map[string]string{}
99+
}
100+
job.Annotations[pgconstants.MinMemberOverrideKey] = fmt.Sprintf("%d", minMember)
101+
102+
if tc := opts.TopologyConstraint; tc != nil {
103+
if tc.Topology != "" {
104+
job.Annotations[pgconstants.TopologyKey] = tc.Topology
105+
}
106+
if tc.RequiredTopologyLevel != "" {
107+
job.Annotations[pgconstants.TopologyRequiredPlacementKey] = tc.RequiredTopologyLevel
108+
}
109+
if tc.PreferredTopologyLevel != "" {
110+
job.Annotations[pgconstants.TopologyPreferredPlacementKey] = tc.PreferredTopologyLevel
111+
}
112+
}
113+
114+
if opts.Preemptibility != "" {
115+
job.Labels[pgconstants.PreemptibilityLabelKey] = string(opts.Preemptibility)
116+
}
117+
118+
if opts.PriorityClassName != "" {
119+
job.Spec.Template.Spec.PriorityClassName = opts.PriorityClassName
120+
}
121+
122+
maps.Copy(job.Spec.Template.ObjectMeta.Labels, opts.ExtraLabels)
123+
124+
if opts.PodSpecMutator != nil {
125+
opts.PodSpecMutator(&job.Spec.Template.Spec)
126+
}
127+
128+
return job
129+
}
130+
131+
func waitForPodGroup(
132+
ctx context.Context, kubeClient runtimeClient.Client, job *batchv1.Job,
133+
) (*v2alpha2.PodGroup, error) {
134+
name := PodGroupNameForJob(job)
135+
pg := &v2alpha2.PodGroup{}
136+
key := types.NamespacedName{Namespace: job.Namespace, Name: name}
137+
138+
err := wait.PollUntilContextTimeout(ctx, podGroupFetchPoll, podGroupFetchTimeout, true,
139+
func(ctx context.Context) (bool, error) {
140+
err := kubeClient.Get(ctx, key, pg)
141+
if errors.IsNotFound(err) {
142+
return false, nil
143+
}
144+
return err == nil, err
145+
})
146+
if err != nil {
147+
return nil, fmt.Errorf("wait for PodGroup %s: %w", name, err)
148+
}
149+
return pg, nil
150+
}
151+
152+
func waitForJobPods(
153+
ctx context.Context, kubeClient runtimeClient.Client, job *batchv1.Job, expected int32,
154+
) ([]*v1.Pod, error) {
155+
var pods []*v1.Pod
156+
err := wait.PollUntilContextTimeout(ctx, podGroupFetchPoll, podGroupFetchTimeout, true,
157+
func(ctx context.Context) (bool, error) {
158+
list := &v1.PodList{}
159+
err := kubeClient.List(ctx, list,
160+
runtimeClient.InNamespace(job.Namespace),
161+
runtimeClient.MatchingLabels{JobNameLabel: job.Name},
162+
)
163+
if err != nil {
164+
return false, err
165+
}
166+
if int32(len(list.Items)) < expected {
167+
return false, nil
168+
}
169+
pods = make([]*v1.Pod, 0, len(list.Items))
170+
for i := range list.Items {
171+
pods = append(pods, &list.Items[i])
172+
}
173+
return true, nil
174+
})
175+
if err != nil {
176+
return nil, fmt.Errorf("wait for %d pods of Job %s: %w", expected, job.Name, err)
177+
}
178+
return pods, nil
179+
}
180+
181+
// PodGroupNameForJob returns the deterministic name the podgrouper uses for a Job-owned PodGroup.
182+
func PodGroupNameForJob(job *batchv1.Job) string {
183+
return fmt.Sprintf("%s-%s-%s", pgconstants.PodGroupNamePrefix, job.Name, job.UID)
184+
}

test/e2e/scale/kwok_job_creation.go

Lines changed: 16 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@ package scale
55

66
import (
77
"context"
8-
"fmt"
9-
"sync"
108

119
. "github.com/onsi/gomega"
12-
"go.uber.org/multierr"
13-
"golang.org/x/exp/maps"
1410
batchv1 "k8s.io/api/batch/v1"
1511
v1 "k8s.io/api/core/v1"
1612
"k8s.io/utils/ptr"
@@ -19,9 +15,6 @@ import (
1915
"github.com/kai-scheduler/KAI-scheduler/pkg/apis/scheduling/v2alpha2"
2016
testcontext "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/context"
2117
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd"
22-
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/pod_group"
23-
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/queue"
24-
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/utils"
2518
)
2619

2720
func createJobObjectForKwok(
@@ -30,74 +23,28 @@ func createJobObjectForKwok(
3023
resources v1.ResourceRequirements,
3124
extraLabels map[string]string,
3225
) *batchv1.Job {
33-
job := rd.CreateBatchJobObject(jobQueue, resources)
34-
addKWOKTaintsAndAffinity(&job.Spec.Template.Spec)
35-
36-
maps.Copy(job.Spec.Template.ObjectMeta.Labels, extraLabels)
37-
38-
Expect(createObjectWithRetries(ctx, testCtx.ControllerClient, job)).To(Succeed())
39-
26+
job, _, _, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, jobQueue,
27+
rd.DistributedBatchJobOptions{
28+
Resources: resources,
29+
ExtraLabels: extraLabels,
30+
PodSpecMutator: addKWOKTaintsAndAffinity,
31+
})
32+
Expect(err).To(Succeed())
4033
return job
4134
}
4235

43-
// createDistributedJobForKwok creates one distributed job with podsPerDistributedJob batch jobs each with one pod
4436
func createDistributedJobForKwok(
4537
ctx context.Context, testCtx *testcontext.TestContext,
4638
jobQueue *v2.Queue, resourcesPerPod v1.ResourceRequirements, numberOfTasks int,
4739
extraLabels map[string]string, topologyConstraint *v2alpha2.TopologyConstraint,
4840
) (*v2alpha2.PodGroup, []*v1.Pod, error) {
49-
namespace := queue.GetConnectedNamespaceToQueue(jobQueue)
50-
podGroup := pod_group.Create(
51-
namespace, "distributed-job-"+utils.GenerateRandomK8sName(10), jobQueue.Name,
52-
)
53-
podGroup.Spec.MinMember = ptr.To(int32(numberOfTasks))
54-
maps.Copy(podGroup.Labels, extraLabels)
55-
if topologyConstraint != nil {
56-
podGroup.Spec.TopologyConstraint = *topologyConstraint
57-
}
58-
59-
err := createObjectWithRetries(ctx, testCtx.ControllerClient, podGroup)
60-
if err != nil {
61-
return nil, nil, err
62-
}
63-
64-
var pods []*v1.Pod
65-
var creationError error
66-
podsLock := sync.Mutex{}
67-
var wg sync.WaitGroup
68-
69-
for i := range numberOfTasks {
70-
wg.Add(1)
71-
go func(i int) {
72-
defer wg.Done()
73-
74-
pod := rd.CreatePodObject(jobQueue, resourcesPerPod)
75-
pod.Name = fmt.Sprintf("distributed-pod-%d-%s", i, utils.GenerateRandomK8sName(10))
76-
77-
if pod.Annotations == nil {
78-
pod.Annotations = map[string]string{}
79-
}
80-
pod.Annotations[pod_group.PodGroupNameAnnotation] = podGroup.Name
81-
82-
maps.Copy(pod.Labels, extraLabels)
83-
addKWOKTaintsAndAffinity(&pod.Spec)
84-
85-
err := createObjectWithRetries(ctx, testCtx.ControllerClient, pod)
86-
87-
podsLock.Lock()
88-
if err != nil {
89-
creationError = multierr.Append(creationError, err)
90-
} else {
91-
pods = append(pods, pod)
92-
}
93-
podsLock.Unlock()
94-
}(i)
95-
}
96-
wg.Wait()
97-
98-
if creationError != nil {
99-
return nil, nil, fmt.Errorf("failed to create some pods: %w", creationError)
100-
}
101-
102-
return podGroup, pods, nil
41+
_, pg, pods, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, jobQueue,
42+
rd.DistributedBatchJobOptions{
43+
Parallelism: ptr.To(int32(numberOfTasks)),
44+
Resources: resourcesPerPod,
45+
ExtraLabels: extraLabels,
46+
TopologyConstraint: topologyConstraint,
47+
PodSpecMutator: addKWOKTaintsAndAffinity,
48+
})
49+
return pg, pods, err
10350
}

test/e2e/scale/kwok_test_utils.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,6 @@ var (
4242
}
4343
)
4444

45-
func createObjectWithRetries(ctx context.Context, kubeClient runtimeClient.Client, obj runtimeClient.Object) error {
46-
key := runtimeClient.ObjectKeyFromObject(obj)
47-
err := kubeClient.Get(ctx, key, obj)
48-
if err == nil {
49-
// object is not expected to exist in the cluster
50-
return fmt.Errorf("object %v already exists in the cluster", key)
51-
}
52-
53-
for i := 0; i < operationAttemptsRetries; i++ {
54-
err = kubeClient.Create(ctx, obj)
55-
if err == nil || errors.IsAlreadyExists(err) {
56-
return nil
57-
}
58-
time.Sleep(retryInterval)
59-
}
60-
return err
61-
}
62-
6345
func deleteObjectWithRetries(
6446
ctx context.Context, kubeClient runtimeClient.Client,
6547
obj runtimeClient.Object, opts ...runtimeClient.DeleteOption) error {

test/e2e/suites/allocate/topology/topology_test.go

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import (
1515
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/configurations/feature_flags"
1616
testcontext "github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/context"
1717
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd"
18-
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/pod_group"
1918
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/resources/rd/queue"
2019
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/utils"
2120
"github.com/kai-scheduler/KAI-scheduler/test/e2e/modules/wait"
@@ -260,24 +259,13 @@ var _ = Describe("Topology", Ordered, func() {
260259

261260
func createDistributedWorkload(ctx context.Context, testCtx *testcontext.TestContext,
262261
podCount int, podResource v1.ResourceList, topologyConstraint v2alpha2.TopologyConstraint) []*v1.Pod {
263-
namespace := queue.GetConnectedNamespaceToQueue(testCtx.Queues[0])
264-
queueName := testCtx.Queues[0].Name
265-
266-
podGroup := pod_group.Create(namespace, "distributed-pod-group"+utils.GenerateRandomK8sName(10), queueName)
267-
podGroup.Spec.MinMember = ptr.To(int32(podCount))
268-
podGroup.Spec.TopologyConstraint = topologyConstraint
269-
270-
pods := []*v1.Pod{}
271-
Expect(testCtx.ControllerClient.Create(ctx, podGroup)).To(Succeed())
272-
for i := 0; i < podCount; i++ {
273-
pod := rd.CreatePodObject(testCtx.Queues[0], v1.ResourceRequirements{Requests: podResource, Limits: podResource})
274-
pod.Name = "distributed-pod-" + utils.GenerateRandomK8sName(10)
275-
pod.Annotations[pod_group.PodGroupNameAnnotation] = podGroup.Name
276-
pod.Labels[pod_group.PodGroupNameAnnotation] = podGroup.Name
277-
_, err := rd.CreatePod(ctx, testCtx.KubeClientset, pod)
278-
Expect(err).To(Succeed())
279-
pods = append(pods, pod)
280-
}
281-
262+
_, _, pods, err := rd.CreateDistributedBatchJob(ctx, testCtx.ControllerClient, testCtx.Queues[0],
263+
rd.DistributedBatchJobOptions{
264+
Parallelism: ptr.To(int32(podCount)),
265+
NamePrefix: "distributed-" + utils.GenerateRandomK8sName(5) + "-",
266+
Resources: v1.ResourceRequirements{Requests: podResource, Limits: podResource},
267+
TopologyConstraint: &topologyConstraint,
268+
})
269+
Expect(err).To(Succeed())
282270
return pods
283271
}

0 commit comments

Comments
 (0)