feat(operator): support multi-slice TPU by enabling trainer replicas > 1 #3408
krishdef7 wants to merge 1 commit into kubeflow:master
Conversation
This PR (#3284) and #3408 conflict on `trainingruntime_webhook.go` and `jobset.go`.
Pull request overview
This PR enables multi-slice TPU support by allowing the trainer ancestor to have replicas > 1 in JobSet configurations. Previously, the operator unconditionally restricted all replicatedJobs to replicas = 1. The changes preserve the replicas value from the runtime template, validate it conditionally (trainer can have replicas > 1, non-trainer must be 1), and adjust per-slice parallelism calculations to account for multiple replicas.
Changes:
- Webhook validation now allows trainer replicatedJobs to have replicas > 1 while keeping non-trainer ancestors restricted to replicas = 1
- Builder preserves replicas from the runtime template instead of unconditionally overwriting with 1 (nil-guard pattern; see the sketch after this list)
- JobSet Build() computes per-slice parallelism by dividing total node count by replica count for trainers
- Adds a comprehensive test case for 4-slice TPU with 32 total nodes (8 per slice)
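
A minimal sketch of that nil-guard, with hypothetical names (the actual change is inline in `builder.go` and its internal types may differ):

```go
package jobset

import "k8s.io/utils/ptr"

// trainerReplicas is a hypothetical helper showing the nil-guard from
// builder.go: keep the runtime template's value when present, and only
// fall back to 1 (single slice) when the template leaves it unset.
// Previously the builder overwrote the field with 1 unconditionally.
func trainerReplicas(fromTemplate *int32) *int32 {
	if fromTemplate == nil {
		return ptr.To[int32](1) // default: single-slice trainer
	}
	return fromTemplate // e.g. 4 for a 4-slice TPU runtime
}
```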
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `pkg/webhooks/trainingruntime_webhook.go` | Conditional validation: trainer ancestor can have replicas > 1, others remain restricted to 1 |
| `pkg/webhooks/trainingruntime_webhook_test.go` | Test updated to expect trainer with replicas > 1 to be valid |
| `pkg/runtime/framework/plugins/jobset/builder.go` | Nil-guard for trainer Replicas field to preserve template value |
| `pkg/runtime/framework/plugins/jobset/jobset.go` | Per-slice parallelism calculation divides total count by replica count for trainers |
| `pkg/runtime/core/trainingruntime.go` | Comment updates clarifying multi-slice TPU semantics |
| `pkg/runtime/core/trainingruntime_test.go` | New test case validating 4-slice configuration with correct parallelism and pod group sizing |
For multi-slice TPU, JobSet models each TPU slice as a ReplicatedJob replica, with parallelism = hosts per slice and replicas = slice count. The operator previously blocked this with two hard constraints:

1. builder.go unconditionally set trainer Replicas = 1, destroying any value from the runtime template.
2. trainingruntime_webhook.go rejected replicas != 1 for all ancestors including trainer.

Changes:

- builder.go: nil-guard for trainer Replicas, preserving the value from the runtime template instead of unconditional overwrite.
- jobset.go: in Build(), compute perSlice = numNodes / replicas for the trainer ancestor so each slice runs the correct number of hosts.
- trainingruntime_webhook.go: allow trainer ancestor replicas > 1 to enable multi-slice configurations to pass admission.
- trainingruntime_webhook_test.go: update invalid_replicas test to reflect that trainer replicas > 1 is now valid.
- trainingruntime_test.go: add test case for 4-slice x 8 hosts (NumNodes=32), verifying Parallelism=8 per slice and MinMember=34.

Semantics: numNodes = total hosts across all slices. Per-slice hosts = numNodes / replicas.

REF: kubeflow#3407

Signed-off-by: krishdef7 <gargkrish06@gmail.com>
krishdef7 force-pushed from e2e5600 to 41a8ca9.
@andreyvelich @siyuanfoundation, this implements the multi-slice TPU support from #3407. The E2E 1.34.0 failure is the pre-existing single-version flake (1.32.3, 1.33.1, and 1.35.0 all pass). All unit tests pass. Could you take a look when you get a chance?
Summary
Resolves #3407.
For multi-slice TPU, JobSet models each TPU slice as a `ReplicatedJob` replica (`replicas` = slice count, `parallelism` = hosts per slice). The operator previously blocked this with two hard constraints:

1. `builder.go` unconditionally set `Replicas = 1` for the trainer ancestor, silently destroying any value from the runtime template.
2. `trainingruntime_webhook.go` rejected `replicas != 1` for all known ancestors, including trainer.

Changes
- `pkg/runtime/framework/plugins/jobset/builder.go` — nil-guard for trainer `Replicas`; preserves the value from the runtime template instead of unconditionally overwriting it
- `pkg/runtime/framework/plugins/jobset/jobset.go` — in `Build()`, compute `perSlice = numNodes / replicas` for the trainer ancestor so each slice gets the correct `Parallelism`/`Completions`
- `pkg/webhooks/trainingruntime_webhook.go` — allow trainer ancestor `replicas > 1`; non-trainer ancestors are unchanged (still validated; see the sketch after this list)
- `pkg/webhooks/trainingruntime_webhook_test.go` — update `invalid_replicas` to reflect that trainer `replicas > 1` is now valid
- `pkg/runtime/core/trainingruntime_test.go` — new test case for 4 slices × 8 hosts (`NumNodes=32`); verifies `Replicas=4`, `Parallelism=8` per slice, `MinMember=34`
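
As referenced in the webhook bullet above, here is a sketch of the conditional rule under assumed identifiers (`"trainer-node"` as the trainer ancestor name is a placeholder; the real check in `trainingruntime_webhook.go` may key off different names):

```go
package webhooks

import (
	"k8s.io/apimachinery/pkg/util/validation/field"
	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

// validateReplicas is a hypothetical sketch, not the PR's actual code:
// the trainer ancestor may now declare replicas > 1 (one replica per
// TPU slice), while every other known ancestor must keep replicas = 1.
func validateReplicas(rjs []jobsetv1alpha2.ReplicatedJob, specPath *field.Path) field.ErrorList {
	var allErrs field.ErrorList
	for _, rj := range rjs {
		if rj.Name == "trainer-node" { // assumed trainer ancestor name
			continue // multi-slice: replicas > 1 is valid here
		}
		if rj.Replicas > 1 {
			allErrs = append(allErrs, field.Invalid(
				specPath.Child("replicatedJobs").Key(rj.Name).Child("replicas"),
				rj.Replicas, "must be 1 for non-trainer ancestors"))
		}
	}
	return allErrs
}
```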
Semantics

`numNodes` = total hosts across all slices. The operator derives `perSlice = numNodes / replicas` and sets it as `Parallelism` and `Completions` per `ReplicatedJob`. Users configure `replicas = 4` (slice count) in the runtime template and `numNodes = 32` (total hosts) on the TrainJob; the operator computes `perSlice = 32 / 4 = 8` and applies it to each replica's `Parallelism`/`Completions`.
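
A runnable sketch of that derivation against the JobSet v1alpha2 types; the helper name and wiring are illustrative (and assume `replicas > 0`), while the real computation lives in `Build()`:

```go
package main

import (
	"fmt"

	batchv1 "k8s.io/api/batch/v1"
	"k8s.io/utils/ptr"
	jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

// applyPerSlice divides the TrainJob's total host count evenly across
// the trainer's replicas (one replica per TPU slice).
func applyPerSlice(rj *jobsetv1alpha2.ReplicatedJob, numNodes int32) {
	perSlice := numNodes / rj.Replicas // assumes Replicas > 0
	rj.Template.Spec.Parallelism = ptr.To(perSlice)
	rj.Template.Spec.Completions = ptr.To(perSlice)
}

func main() {
	trainer := jobsetv1alpha2.ReplicatedJob{
		Name:     "trainer-node", // assumed trainer ancestor name
		Replicas: 4,              // slice count from the runtime template
		Template: batchv1.JobTemplateSpec{},
	}
	applyPerSlice(&trainer, 32)                     // numNodes from the TrainJob
	fmt.Println(*trainer.Template.Spec.Parallelism) // 8 hosts per slice
}
```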
Parallelism/Completions.Relationship to PR #3284
Both PRs touch `trainingruntime_webhook.go` and `jobset.go` in opposite directions:

- #3284 allows `replicas > 1` for non-trainer ancestors (DatasetInitializer, ModelInitializer) and keeps the trainer blocked.
- This PR allows `replicas > 1` for the trainer ancestor and keeps non-trainer ancestors blocked.

The correct combined final state is to allow `replicas > 1` for all ancestors and remove the per-ancestor split in the webhook. Suggested merge strategy: merge #3284 first, then rebase this PR on top of it with the webhook restriction removed entirely.

Testing
All unit tests pass (`go test ./pkg/...`). The new test exercises the full path through `newRuntimeInfo` → `EnforceMLPolicy` → `Build()`.
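
For orientation, a standalone sketch of the quantities that test pins down; the reading of `MinMember=34` as 32 trainer hosts plus 2 initializer pods is an assumption, and the real test drives the full runtime pipeline rather than raw arithmetic:

```go
package core

import "testing"

// TestMultiSliceQuantities is a hypothetical distillation of the new
// 4-slice test case's expected numbers.
func TestMultiSliceQuantities(t *testing.T) {
	const (
		numNodes = int32(32) // total hosts requested by the TrainJob
		replicas = int32(4)  // TPU slices from the runtime template
	)
	if perSlice := numNodes / replicas; perSlice != 8 {
		t.Fatalf("perSlice = %d, want 8", perSlice)
	}
	// MinMember=34: presumably the 32 trainer hosts plus the dataset
	// and model initializer pods (one each).
	if minMember := numNodes + 2; minMember != 34 {
		t.Fatalf("MinMember = %d, want 34", minMember)
	}
}
```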