
feat(operator): support multi-slice TPU by enabling trainer replicas > 1#3408

Open
krishdef7 wants to merge 1 commit into kubeflow:master from krishdef7:feat/trainer-multi-slice-tpu

Conversation

@krishdef7
Contributor

Summary

Resolves #3407.

For multi-slice TPU, JobSet models each TPU slice as a ReplicatedJob replica (replicas = slice count, parallelism = hosts per slice). The operator previously blocked this with two hard constraints:

  1. builder.go unconditionally set Replicas = 1 for the trainer ancestor, silently destroying any value from the runtime template.
  2. trainingruntime_webhook.go rejected replicas != 1 for all known ancestors including trainer.

Changes

| File | What changed |
| --- | --- |
| pkg/runtime/framework/plugins/jobset/builder.go | nil-guard for trainer Replicas; preserves the value from the runtime template instead of unconditionally overwriting it (see the sketch below) |
| pkg/runtime/framework/plugins/jobset/jobset.go | in Build(), compute perSlice = numNodes / replicas for the trainer ancestor so each slice gets the correct Parallelism/Completions |
| pkg/webhooks/trainingruntime_webhook.go | allow trainer ancestor replicas > 1; non-trainer ancestors are unchanged (still validated) |
| pkg/webhooks/trainingruntime_webhook_test.go | update invalid_replicas to reflect that trainer replicas > 1 is now valid |
| pkg/runtime/core/trainingruntime_test.go | new test: 4 slices × 8 hosts (NumNodes=32); verifies Replicas=4, Parallelism=8 per slice, MinMember=34 |
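
A minimal sketch of the builder.go nil-guard. The operator's actual field layout differs; assume here that the replica count arrives from the runtime template as a *int32 that is nil when unset:

```go
package builderrules

// trainerReplicas illustrates the nil-guard: default to 1 only when the
// runtime template left the trainer's replica count unset, instead of
// unconditionally overwriting it with 1 (the pre-PR behavior).
func trainerReplicas(fromTemplate *int32) int32 {
	if fromTemplate == nil {
		return 1 // single-slice default
	}
	return *fromTemplate // e.g. 4 for a four-slice TPU topology
}
```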

Semantics

numNodes = total hosts across all slices.

The operator derives perSlice = numNodes / replicas and sets it as Parallelism and Completions per ReplicatedJob. Users configure:

```yaml
# ClusterTrainingRuntime
replicatedJobs:
  - name: node
    replicas: 4        # number of TPU slices
    template:
      spec:
        parallelism: 1 # placeholder; overwritten by operator

# TrainJob
trainer:
  numNodes: 32         # total hosts across all 4 slices (8 per slice)
```

The operator computes perSlice = 32 / 4 = 8 and applies it to each
replica's Parallelism/Completions.
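
For concreteness, the same arithmetic as a runnable Go sketch (field access simplified; the real computation lives in Build() in jobset.go):

```go
package main

import "fmt"

func main() {
	numNodes := int32(32) // TrainJob: trainer.numNodes, total hosts across all slices
	replicas := int32(4)  // runtime template: replicas of the "node" replicatedJob, i.e. slice count

	perSlice := numNodes / replicas // hosts per slice: 32 / 4 = 8
	// Each ReplicatedJob replica (one TPU slice) is stamped with
	// Parallelism = perSlice and Completions = perSlice.
	fmt.Printf("Replicas=%d Parallelism=%d Completions=%d\n", replicas, perSlice, perSlice)
	// Output: Replicas=4 Parallelism=8 Completions=8
}
```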

Relationship to PR #3284

⚠️ This PR conflicts with #3284 (feat: support multiple replicas for non-trainer replicatedJobs).

Both PRs touch trainingruntime_webhook.go and jobset.go and pull in opposite directions: this PR relaxes the replicas restriction for the trainer ancestor, while #3284 relaxes it for non-trainer replicatedJobs.

The correct combined final state is to allow replicas > 1 for all ancestors and remove the per-ancestor split in the webhook. Suggested merge strategy: merge #3284 first, then rebase this PR on top of it with the webhook restriction removed entirely.

Testing

All unit tests pass (go test ./pkg/...). New test exercises the full path through newRuntimeInfoEnforceMLPolicyBuild().
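
As a sketch only, here is the arithmetic contract the new case pins down (the real test drives the full object graph through newRuntimeInfoEnforceMLPolicyBuild(); this toy test checks just the derivation):

```go
package jobset_test

import "testing"

// TestPerSliceDerivation mirrors the numbers from the new test case:
// 4 slices x 8 hosts, NumNodes=32, expecting Parallelism=8 per slice.
func TestPerSliceDerivation(t *testing.T) {
	numNodes, replicas := int32(32), int32(4)
	if perSlice := numNodes / replicas; perSlice != 8 {
		t.Fatalf("perSlice = %d, want 8", perSlice)
	}
}
```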

Copilot AI review requested due to automatic review settings (April 3, 2026 20:13)
@google-oss-prow

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign johnugeorge for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@krishdef7
Contributor Author

This PR (#3408) and #3284 conflict on trainingruntime_webhook.go and jobset.go; they make opposite changes to the trainer/non-trainer split. The correct combined final state once both are merged is to remove the per-ancestor distinction in the webhook entirely (allow replicas > 1 for all ancestors) and unify the Build() path. Suggested merge order: #3284 first, then rebase the multi-slice PR on top with the webhook restriction removed. Flagging so maintainers can coordinate.

Contributor

Copilot AI left a comment


Pull request overview

This PR enables multi-slice TPU support by allowing the trainer ancestor to have replicas > 1 in JobSet configurations. Previously, the operator unconditionally restricted all replicatedJobs to replicas = 1. The changes preserve the replicas value from the runtime template, validate it conditionally (trainer can have replicas > 1, non-trainer must be 1), and adjust per-slice parallelism calculations to account for multiple replicas.

Changes:

  • Webhook validation now allows trainer replicatedJobs to have replicas > 1 while keeping non-trainer ancestors restricted to replicas = 1 (see the sketch after this list)
  • Builder preserves replicas from runtime template instead of unconditionally overwriting with 1 (nil-guard pattern)
  • JobSet Build() computes per-slice parallelism by dividing total node count by replica count for trainers
  • Adds comprehensive test case for 4-slice TPU with 32 total nodes (8 per slice)
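
A hedged sketch of that conditional rule (simplified; the real webhook validates concrete field paths and returns a field.ErrorList, and the ancestor name used here is illustrative):

```go
package webhookrules

import "fmt"

// validateAncestorReplicas keeps the original replicas == 1 rule for
// non-trainer ancestors while letting the trainer ancestor scale out to
// one replica per TPU slice.
func validateAncestorReplicas(ancestor string, replicas int32) error {
	if ancestor == "trainer" {
		if replicas < 1 {
			return fmt.Errorf("trainer ancestor: replicas must be >= 1, got %d", replicas)
		}
		return nil // replicas > 1 is now valid
	}
	if replicas != 1 {
		return fmt.Errorf("%s ancestor: replicas must be 1, got %d", ancestor, replicas)
	}
	return nil
}
```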

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| pkg/webhooks/trainingruntime_webhook.go | Conditional validation: trainer ancestor can have replicas > 1, others remain restricted to 1 |
| pkg/webhooks/trainingruntime_webhook_test.go | Test updated to expect trainer with replicas > 1 to be valid |
| pkg/runtime/framework/plugins/jobset/builder.go | Nil-guard for trainer Replicas field to preserve template value |
| pkg/runtime/framework/plugins/jobset/jobset.go | Per-slice parallelism calculation divides total count by replica count for trainers |
| pkg/runtime/core/trainingruntime.go | Comment updates clarifying multi-slice TPU semantics |
| pkg/runtime/core/trainingruntime_test.go | New test case validating 4-slice configuration with correct parallelism and pod group sizing |

Comment thread: pkg/runtime/framework/plugins/jobset/jobset.go

Commit message (41a8ca9):
For multi-slice TPU, JobSet models each TPU slice as a ReplicatedJob
replica, with parallelism = hosts per slice and replicas = slice count.
The operator previously blocked this with two hard constraints:

1. builder.go unconditionally set trainer Replicas = 1, destroying any
   value from the runtime template.
2. trainingruntime_webhook.go rejected replicas != 1 for all ancestors
   including trainer.

Changes:
- builder.go: nil-guard for trainer Replicas, preserving the value from
  the runtime template instead of unconditional overwrite.
- jobset.go: in Build(), compute perSlice = numNodes / replicas for the
  trainer ancestor so each slice runs the correct number of hosts.
- trainingruntime_webhook.go: allow trainer ancestor replicas > 1 to
  enable multi-slice configurations to pass admission.
- trainingruntime_webhook_test.go: update invalid_replicas test to
  reflect that trainer replicas > 1 is now valid.
- trainingruntime_test.go: add test case for 4-slice x 8 hosts
  (NumNodes=32), verifying Parallelism=8 per slice and MinMember=34.

Semantics: numNodes = total hosts across all slices.
Per-slice hosts = numNodes / replicas.

REF: kubeflow#3407
Signed-off-by: krishdef7 <gargkrish06@gmail.com>
krishdef7 force-pushed the feat/trainer-multi-slice-tpu branch from e2e5600 to 41a8ca9 (April 3, 2026 20:20)
@krishdef7
Contributor Author

@andreyvelich @siyuanfoundation, this implements the multi-slice TPU support from #3407. The E2E 1.34.0 failure is the pre-existing single-version flake (1.32.3, 1.33.1, and 1.35.0 all pass). All unit tests pass. Could you take a look when you get a chance?


Development

Successfully merging this pull request may close these issues: Support multi-slice TPU in trainer (#3407).