Skip to content

Commit 44d118d

Browse files
committed
feat: support multiple replicas for non-trainer replicatedJobs
Support .template.spec.replicatedJobs[*].replicas > 1 to allow multiple replicated Jobs instead of a single Job with thousands of completions, which causes kube-controller-manager memory leaks and reconciliation delays at scale. Changes: - trainingruntime.go: read replicas from each replicatedJob and multiply with parallelism for PodGroup MinMember count; trainer ancestor uses NumNodes directly and is unaffected - builder.go: preserve Replicas field from runtime template instead of unconditionally overwriting with 1 - jobset.go: split Parallelism/Completions assignment in Build() — trainer uses count directly, non-trainer divides by replicas to get per-replica value Note: endpoint generation in IdentifyPodNetwork for multi-replica non-trainer jobs is tracked separately; initializer jobs do not participate in training network topology so this is currently harmless. Fixes #2318 Signed-off-by: krishdef7 <gargkrish06@gmail.com>
1 parent 66d0b0b commit 44d118d

File tree

6 files changed

+124
-19
lines changed

6 files changed

+124
-19
lines changed

pkg/runtime/core/trainingruntime.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,9 +166,8 @@ func (r *TrainingRuntime) newRuntimeInfo(
166166
}
167167

168168
for i, rJob := range jobSetSpecApply.ReplicatedJobs {
169-
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
170-
// REF: https://github.com/kubeflow/trainer/issues/2318
171-
count := ptr.Deref(rJob.Template.Spec.Parallelism, 1)
169+
replicas := ptr.Deref(rJob.Replicas, 1)
170+
count := ptr.Deref(rJob.Template.Spec.Parallelism, 1) * replicas
172171
var ancestor *string
173172
if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil {
174173
if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok {

pkg/runtime/core/trainingruntime_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,6 +2065,102 @@ test-job-node-0-1.test-job slots=8
20652065
Obj(),
20662066
},
20672067
},
2068+
"succeeded to build PodGroup and JobSet with multiple replicas for non-trainer replicatedJob.": {
2069+
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").
2070+
RuntimeSpec(
2071+
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
2072+
WithMLPolicy(
2073+
testingutil.MakeMLPolicyWrapper().
2074+
WithNumNodes(10).
2075+
Obj(),
2076+
).
2077+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
2078+
Replicas(3, constants.DatasetInitializer).
2079+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2080+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2081+
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2082+
Obj(),
2083+
).Obj(),
2084+
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
2085+
UID("uid").
2086+
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
2087+
Trainer(
2088+
testingutil.MakeTrainJobTrainerWrapper().
2089+
NumNodes(10).
2090+
Obj(),
2091+
).
2092+
Obj(),
2093+
wantObjs: []runtime.Object{
2094+
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
2095+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
2096+
Replicas(3, constants.DatasetInitializer).
2097+
Replicas(1, constants.ModelInitializer, constants.Node, constants.Launcher).
2098+
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
2099+
Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
2100+
NumNodes(10).
2101+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2102+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2103+
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2104+
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
2105+
Obj(),
2106+
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
2107+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
2108+
MinMember(14). // 14 = 10 (trainer nodes) + 3 (DatasetInitializer replicas * 1 pod each) + 1 (ModelInitializer)
2109+
MinResources(corev1.ResourceList{
2110+
corev1.ResourceCPU: resource.MustParse("14"),
2111+
}).
2112+
SchedulingTimeout(120).
2113+
Obj(),
2114+
},
2115+
},
2116+
"succeeded to build PodGroup and JobSet with trainer replicatedJob Replicas ignored when NumNodes is set.": {
2117+
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").
2118+
RuntimeSpec(
2119+
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
2120+
WithMLPolicy(
2121+
testingutil.MakeMLPolicyWrapper().
2122+
WithNumNodes(100).
2123+
Obj(),
2124+
).
2125+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
2126+
Replicas(4, constants.Node).
2127+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2128+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2129+
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2130+
Obj(),
2131+
).Obj(),
2132+
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
2133+
UID("uid").
2134+
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
2135+
Trainer(
2136+
testingutil.MakeTrainJobTrainerWrapper().
2137+
NumNodes(5).
2138+
Obj(),
2139+
).
2140+
Obj(),
2141+
wantObjs: []runtime.Object{
2142+
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
2143+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
2144+
Replicas(4, constants.Node).
2145+
Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
2146+
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
2147+
Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
2148+
NumNodes(5).
2149+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2150+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2151+
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
2152+
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
2153+
Obj(),
2154+
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
2155+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
2156+
MinMember(7). // 7 = 5 (NumNodes, NOT 5*4=20) + 1 (DatasetInitializer) + 1 (ModelInitializer)
2157+
MinResources(corev1.ResourceList{
2158+
corev1.ResourceCPU: resource.MustParse("7"),
2159+
}).
2160+
SchedulingTimeout(120).
2161+
Obj(),
2162+
},
2163+
},
20682164
// Failed test cases.
20692165
"missing trainingRuntime resource": {
20702166
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job-3").

pkg/runtime/framework/plugins/jobset/builder.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder {
4747
}
4848
// Update values for the Dataset Initializer Job.
4949
if ancestor, ok := jobMetadata.Labels[constants.LabelTrainJobAncestor]; ok && ancestor == constants.DatasetInitializer {
50-
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
51-
// REF: https://github.com/kubeflow/trainer/issues/2318
52-
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
50+
if b.Spec.ReplicatedJobs[i].Replicas == nil {
51+
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
52+
}
5353
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
5454
// Update values for the dataset initializer container.
5555
if *container.Name == constants.DatasetInitializer && trainJob.Spec.Initializer != nil && trainJob.Spec.Initializer.Dataset != nil {
@@ -73,9 +73,9 @@ func (b *Builder) Initializer(trainJob *trainer.TrainJob) *Builder {
7373
}
7474
// Update values for the Model Initializer Job.
7575
if ancestor, ok := jobMetadata.Labels[constants.LabelTrainJobAncestor]; ok && ancestor == constants.ModelInitializer {
76-
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
77-
// REF: https://github.com/kubeflow/trainer/issues/2318
78-
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
76+
if b.Spec.ReplicatedJobs[i].Replicas == nil {
77+
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
78+
}
7979
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
8080
// Update values for the model initializer container.
8181
if *container.Name == constants.ModelInitializer && trainJob.Spec.Initializer != nil && trainJob.Spec.Initializer.Model != nil {
@@ -118,9 +118,9 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build
118118
ancestor = jobMetadata.Labels[constants.LabelTrainJobAncestor]
119119
}
120120
if ancestor == constants.AncestorTrainer {
121-
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
122-
// REF: https://github.com/kubeflow/trainer/issues/2318
123-
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
121+
if b.Spec.ReplicatedJobs[i].Replicas == nil {
122+
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
123+
}
124124
// Update values for the Trainer container.
125125
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
126126
if *container.Name == constants.Node {

pkg/runtime/framework/plugins/jobset/jobset.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,22 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
273273

274274
for psIdx, ps := range info.TemplateSpec.PodSets {
275275
if ps.Count != nil {
276-
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count
277-
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count
276+
rJob := &jobSetSpec.ReplicatedJobs[psIdx]
277+
jobMetadata := rJob.Template.ObjectMetaApplyConfiguration
278+
isTrainer := jobMetadata != nil && jobMetadata.Labels != nil &&
279+
jobMetadata.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer
280+
if isTrainer {
281+
// For trainer: count = NumNodes, use directly as Parallelism/Completions.
282+
rJob.Template.Spec.Parallelism = ps.Count
283+
rJob.Template.Spec.Completions = ps.Count
284+
} else {
285+
// For non-trainer: count = parallelism * replicas.
286+
// Parallelism/Completions must be per-replica, not total pod count.
287+
replicas := ptr.Deref(rJob.Replicas, 1)
288+
perReplica := *ps.Count / replicas
289+
rJob.Template.Spec.Parallelism = &perReplica
290+
rJob.Template.Spec.Completions = &perReplica
291+
}
278292
}
279293
apply.UpsertVolumes(&jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...)
280294
for containerIdx, container := range ps.Containers {

pkg/webhooks/trainingruntime_webhook.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorLis
7676
}
7777

7878
if labelAncestor, ok := rJob.Template.Labels[constants.LabelTrainJobAncestor]; ok && ancestors.Has(labelAncestor) {
79-
if rJob.Replicas != 1 {
79+
if labelAncestor == constants.AncestorTrainer && rJob.Replicas != 1 {
8080
allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"), rJob.Replicas, rJobReplicasErrorMsg))
8181
}
8282

pkg/webhooks/trainingruntime_webhook_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,6 @@ func TestValidateReplicatedJobs(t *testing.T) {
5252
Replicas(2, constants.Launcher, constants.Node, constants.DatasetInitializer, constants.ModelInitializer).
5353
Obj().Spec.ReplicatedJobs,
5454
wantError: field.ErrorList{
55-
field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(0).Child("replicas"),
56-
"2", ""),
57-
field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"),
58-
"2", ""),
5955
field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(3).Child("replicas"),
6056
"2", ""),
6157
},

0 commit comments

Comments
 (0)