
Commit 4bcadf4

feat(runtimes): add support for ClusterTrainingRuntimes in Helm chart (#3124)
* feat(runtimes): add support for ClusterTrainingRuntimes in Helm chart
* fix: remove initializerImage from user-configurable values
* chore: regenerate README with helm-docs
* fix: address Copilot review suggestions
* feat: Introduce helper to centralize image tag resolution
* refactor: nest cache image configuration and update copyright year
* chore: run make generate to sync
* feat: add TorchTune distributed runtime, image helper usage, and default runtimes enabled flag
* feat: enable runtime via a new default flag and add a comment
* refactor: Torchtune runtimes to use model specific configurations, add default enabled option
* fix: update README and fix trailing whitespace for CI
* refactor: relocate runtime template and update its configuration
* feat: add JAX distributed training support and update runtime configurations
* fix(docs): remove JAX runtime from default runtimes in README
* fix(docs): update default runtimes in README and values.yaml to include JAX
* fix(workflows): increase Papermill timeout for e2e tests in GPU cluster
* fix(notebooks): add missing newline at end of qwen2.5-1.5B-with-alpaca.ipynb

---------

Signed-off-by: khushiiagrawal <khushisaritaagrawal@gmail.com>
Signed-off-by: Khushi Agrawal <149886195+khushiiagrawal@users.noreply.github.com>
1 parent 7a14ec5 commit 4bcadf4

18 files changed: +1625 −258 lines

.github/workflows/test-e2e-gpu.yaml

Lines changed: 2 additions & 2 deletions

```diff
@@ -56,8 +56,8 @@ jobs:
       - name: Run e2e test on GPU cluster
         run: |
           mkdir -p artifacts/notebooks
-          make test-e2e-notebook NOTEBOOK_INPUT=./examples/torchtune/qwen2_5/qwen2.5-1.5B-with-alpaca.ipynb NOTEBOOK_OUTPUT=./artifacts/notebooks/${{ matrix.kubernetes-version }}_qwen2_5_with_alpaca-trainjob-yaml.ipynb TIMEOUT=600
-          make test-e2e-notebook NOTEBOOK_INPUT=./examples/jax/image-classification/mnist.ipynb NOTEBOOK_OUTPUT=./artifacts/notebooks/${{ matrix.kubernetes-version }}_jax_mnist.ipynb PAPERMILL_PARAMS="-p num_cpu 8 -p num_gpu 1 -p num_nodes 1" TIMEOUT=600
+          make test-e2e-notebook NOTEBOOK_INPUT=./examples/torchtune/qwen2_5/qwen2.5-1.5B-with-alpaca.ipynb NOTEBOOK_OUTPUT=./artifacts/notebooks/${{ matrix.kubernetes-version }}_qwen2_5_with_alpaca-trainjob-yaml.ipynb PAPERMILL_TIMEOUT=1800
+          make test-e2e-notebook NOTEBOOK_INPUT=./examples/jax/image-classification/mnist.ipynb NOTEBOOK_OUTPUT=./artifacts/notebooks/${{ matrix.kubernetes-version }}_jax_mnist.ipynb PAPERMILL_PARAMS="-p num_cpu 8 -p num_gpu 1 -p num_nodes 1" PAPERMILL_TIMEOUT=1800
 
       - name: Upload Artifacts to GitHub
         if: always()
```
charts/kubeflow-trainer/README.md

Lines changed: 89 additions & 0 deletions

````diff
@@ -31,6 +31,64 @@ Alternatively, you can install the latest version from the master branch (e.g. `
 helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer --version 0.0.0-sha-bfccb7b
 ```
 
+### Install with ClusterTrainingRuntimes
+
+You can optionally deploy ClusterTrainingRuntimes as part of the Helm installation. Runtimes are disabled by default to keep the chart lightweight.
+
+To enable all default runtimes (torch, deepspeed, mlx, torchtune):
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  --set runtimes.defaultEnabled=true
+```
+
+To enable specific runtimes:
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  --set runtimes.torchDistributed.enabled=true \
+  --set runtimes.deepspeedDistributed.enabled=true
+```
+
+Or use a custom values file:
+
+```yaml
+# values.yaml
+runtimes:
+  torchDistributed:
+    enabled: true
+  deepspeedDistributed:
+    enabled: true
+  mlxDistributed:
+    enabled: true
+
+# For torch-distributed-with-cache, enable both dataCache.enabled and dataCache.runtimes.torchDistributed.enabled
+dataCache:
+  enabled: true
+  cacheImage:
+    tag: "v2.0.0"
+  runtimes:
+    torchDistributed:
+      enabled: true
+```
+
+Then install with:
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  -f values.yaml
+```
+
+### Available Runtimes
+
+- **torch-distributed**: PyTorch distributed training (no custom images)
+- **torch-distributed-with-cache**: PyTorch with distributed data cache support (requires `dataCache.enabled=true`)
+- **deepspeed-distributed**: DeepSpeed distributed training with MPI
+- **mlx-distributed**: MLX distributed training with MPI
+
 ### Uninstall the chart
 
 ```shell
@@ -72,6 +130,37 @@ See [helm uninstall](https://helm.sh/docs/helm/helm_uninstall) for command docum
 | dataCache.enabled | bool | `false` | Enable/disable data cache support (LWS dependency, ClusterRole). Set to `true` to install data cache components. |
 | dataCache.lws.install | bool | `true` | Whether to install LeaderWorkerSet as a dependency. Set to `false` if LeaderWorkerSet is already installed in the cluster. |
 | dataCache.lws.fullnameOverride | string | `"lws"` | String to fully override LeaderWorkerSet release name. |
+| dataCache.cacheImage.registry | string | `"ghcr.io"` | Data cache image registry |
+| dataCache.cacheImage.repository | string | `"kubeflow/trainer/data-cache"` | Data cache image repository |
+| dataCache.cacheImage.tag | string | `""` | Data cache image tag. Defaults to chart version if empty. |
+| dataCache.runtimes.torchDistributed | object | `{"enabled":false}` | PyTorch distributed training with data cache support |
+| dataCache.runtimes.torchDistributed.enabled | bool | `false` | Enable deployment of torch-distributed-with-cache runtime |
+| runtimes | object | `{"deepspeedDistributed":{"enabled":false,"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/deepspeed-runtime","tag":""}},"defaultEnabled":false,"jaxDistributed":{"enabled":false},"mlxDistributed":{"enabled":false,"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/mlx-runtime","tag":""}},"torchDistributed":{"enabled":false},"torchtuneDistributed":{"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/torchtune-trainer","tag":""},"llama3_2_1B":{"enabled":false},"llama3_2_3B":{"enabled":false},"qwen2_5_1_5B":{"enabled":false}}}` | ClusterTrainingRuntimes configuration These are optional runtime templates that can be deployed with the Helm chart. Each runtime provides a blueprint for different ML frameworks and configurations. |
+| runtimes.defaultEnabled | bool | `false` | Enable all default runtimes (torch, deepspeed, mlx, jax, torchtune) when set to true. Individual runtime settings will be ignored if this is enabled. |
+| runtimes.torchDistributed | object | `{"enabled":false}` | PyTorch distributed training runtime (no custom images required) |
+| runtimes.torchDistributed.enabled | bool | `false` | Enable deployment of torch-distributed runtime |
+| runtimes.deepspeedDistributed | object | `{"enabled":false,"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/deepspeed-runtime","tag":""}}` | DeepSpeed distributed training runtime |
+| runtimes.deepspeedDistributed.enabled | bool | `false` | Enable deployment of deepspeed-distributed runtime |
+| runtimes.deepspeedDistributed.image.registry | string | `"ghcr.io"` | DeepSpeed runtime image registry |
+| runtimes.deepspeedDistributed.image.repository | string | `"kubeflow/trainer/deepspeed-runtime"` | DeepSpeed runtime image repository |
+| runtimes.deepspeedDistributed.image.tag | string | `""` | DeepSpeed runtime image tag. Defaults to chart version if empty. |
+| runtimes.mlxDistributed | object | `{"enabled":false,"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/mlx-runtime","tag":""}}` | MLX distributed training runtime |
+| runtimes.mlxDistributed.enabled | bool | `false` | Enable deployment of mlx-distributed runtime |
+| runtimes.mlxDistributed.image.registry | string | `"ghcr.io"` | MLX runtime image registry |
+| runtimes.mlxDistributed.image.repository | string | `"kubeflow/trainer/mlx-runtime"` | MLX runtime image repository |
+| runtimes.mlxDistributed.image.tag | string | `""` | MLX runtime image tag. Defaults to chart version if empty. |
+| runtimes.jaxDistributed | object | `{"enabled":false}` | JAX distributed training runtime (no custom images required) |
+| runtimes.jaxDistributed.enabled | bool | `false` | Enable deployment of jax-distributed runtime |
+| runtimes.torchtuneDistributed | object | `{"image":{"registry":"ghcr.io","repository":"kubeflow/trainer/torchtune-trainer","tag":""},"llama3_2_1B":{"enabled":false},"llama3_2_3B":{"enabled":false},"qwen2_5_1_5B":{"enabled":false}}` | TorchTune distributed training runtime |
+| runtimes.torchtuneDistributed.image.registry | string | `"ghcr.io"` | TorchTune runtime image registry |
+| runtimes.torchtuneDistributed.image.repository | string | `"kubeflow/trainer/torchtune-trainer"` | TorchTune runtime image repository |
+| runtimes.torchtuneDistributed.image.tag | string | `""` | TorchTune runtime image tag. Defaults to chart version if empty. |
+| runtimes.torchtuneDistributed.llama3_2_1B | object | `{"enabled":false}` | Llama 3.2 1B model configuration |
+| runtimes.torchtuneDistributed.llama3_2_1B.enabled | bool | `false` | Enable deployment of Llama 3.2 1B runtime |
+| runtimes.torchtuneDistributed.llama3_2_3B | object | `{"enabled":false}` | Llama 3.2 3B model configuration |
+| runtimes.torchtuneDistributed.llama3_2_3B.enabled | bool | `false` | Enable deployment of Llama 3.2 3B runtime |
+| runtimes.torchtuneDistributed.qwen2_5_1_5B | object | `{"enabled":false}` | Qwen 2.5 1.5B model configuration |
+| runtimes.torchtuneDistributed.qwen2_5_1_5B.enabled | bool | `false` | Enable deployment of Qwen 2.5 1.5B runtime |
 
 ## Maintainers
````

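Beyond the framework-level toggles shown in the README examples, the values table also documents per-model switches for the TorchTune runtime. A minimal sketch of enabling just one model runtime via a values file, using only keys from the table above (the particular combination is illustrative, not from the chart's documented examples):

```yaml
# Hypothetical values fragment: deploy only the TorchTune Qwen 2.5 1.5B runtime.
runtimes:
  torchtuneDistributed:
    qwen2_5_1_5B:
      enabled: true
```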
charts/kubeflow-trainer/README.md.gotmpl

Lines changed: 59 additions & 1 deletion

````diff
@@ -1,5 +1,5 @@
 {{- /*
-Copyright 2025 The Kubeflow authors.
+Copyright 2026 The Kubeflow authors.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -49,6 +49,64 @@ Alternatively, you can install the latest version from the master branch (e.g. `
 helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer --version 0.0.0-sha-bfccb7b
 ```
 
+### Install with ClusterTrainingRuntimes
+
+You can optionally deploy ClusterTrainingRuntimes as part of the Helm installation. Runtimes are disabled by default to keep the chart lightweight.
+
+To enable all default runtimes (torch, deepspeed, mlx, torchtune):
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  --set runtimes.defaultEnabled=true
+```
+
+To enable specific runtimes:
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  --set runtimes.torchDistributed.enabled=true \
+  --set runtimes.deepspeedDistributed.enabled=true
+```
+
+Or use a custom values file:
+
+```yaml
+# values.yaml
+runtimes:
+  torchDistributed:
+    enabled: true
+  deepspeedDistributed:
+    enabled: true
+  mlxDistributed:
+    enabled: true
+
+# For torch-distributed-with-cache, enable both dataCache.enabled and dataCache.runtimes.torchDistributed.enabled
+dataCache:
+  enabled: true
+  cacheImage:
+    tag: "v2.0.0"
+  runtimes:
+    torchDistributed:
+      enabled: true
+```
+
+Then install with:
+
+```bash
+helm install kubeflow-trainer oci://ghcr.io/kubeflow/charts/kubeflow-trainer \
+  --version 2.1.0 \
+  -f values.yaml
+```
+
+### Available Runtimes
+
+- **torch-distributed**: PyTorch distributed training (no custom images)
+- **torch-distributed-with-cache**: PyTorch with distributed data cache support (requires `dataCache.enabled=true`)
+- **deepspeed-distributed**: DeepSpeed distributed training with MPI
+- **mlx-distributed**: MLX distributed training with MPI
+
 ### Uninstall the chart
 
 ```shell
````

charts/kubeflow-trainer/templates/_helpers.tpl

Lines changed: 45 additions & 8 deletions

```diff
@@ -64,24 +64,61 @@ app.kubernetes.io/name: {{ include "trainer.name" . }}
 app.kubernetes.io/instance: {{ .Release.Name }}
 {{- end }}
 
+{{/*
+Resolve the effective image tag, using a provided tag if present or
+falling back to the default image tag derived from the chart version.
+Usage: include "trainer.resolveImageTag" (dict "tag" .Values.image.tag "context" .)
+*/}}
+{{- define "trainer.resolveImageTag" -}}
+{{- if .tag }}
+{{- .tag -}}
+{{- else -}}
+{{- include "trainer.defaultImageTag" .context -}}
+{{- end -}}
+{{- end }}
+
 {{- define "trainer.image" -}}
 {{- $imageRegistry := .Values.image.registry | default "docker.io" }}
 {{- $imageRepository := .Values.image.repository }}
-{{- $imageTag := .Values.image.tag -}}
-{{- if not $imageTag -}}
-{{- if hasPrefix "0.0.0-" .Chart.Version -}}
-{{- $imageTag = trimPrefix "0.0.0-" .Chart.Version -}}
-{{- else -}}
-{{- $imageTag = printf "v%s" .Chart.Version -}}
-{{- end -}}
-{{- end -}}
+{{- $imageTag := include "trainer.resolveImageTag" (dict "tag" .Values.image.tag "context" .) -}}
 {{- if eq $imageRegistry "docker.io" }}
 {{- printf "%s:%s" $imageRepository $imageTag }}
 {{- else }}
 {{- printf "%s/%s:%s" $imageRegistry $imageRepository $imageTag }}
 {{- end }}
 {{- end }}
 
+{{/*
+Generate the default image tag for runtimes based on chart version
+*/}}
+{{- define "trainer.defaultImageTag" -}}
+{{- if hasPrefix "0.0.0-" .Chart.Version -}}
+{{- trimPrefix "0.0.0-" .Chart.Version -}}
+{{- else -}}
+{{- printf "v%s" .Chart.Version -}}
+{{- end -}}
+{{- end }}
+
+{{/*
+Generate runtime image with registry, repository, and tag from values
+Usage: include "trainer.runtimeImage" (list .Values.runtimes.deepspeedDistributed.image .)
+*/}}
+{{- define "trainer.runtimeImage" -}}
+{{- $imageConfig := index . 0 }}
+{{- $root := index . 1 }}
+{{- $registry := $imageConfig.registry | default "ghcr.io" }}
+{{- $repository := $imageConfig.repository }}
+{{- $tag := include "trainer.resolveImageTag" (dict "tag" ($imageConfig.tag) "context" $root) -}}
+{{- if eq $registry "docker.io" }}
+{{- printf "%s:%s" $repository $tag }}
+{{- else }}
+{{- printf "%s/%s:%s" $registry $repository $tag }}
+{{- end }}
+{{- end }}
+{{/*
+Return the version of the trainer.
+If the version is 0.0.0, we assume it is a development version.
+*/}}
 {{- define "trainer.version" -}}
 {{- if hasPrefix "0.0.0-" .Chart.Version -}}
 dev
```
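The tag-resolution precedence introduced by `trainer.resolveImageTag` and `trainer.defaultImageTag` (explicit tag wins; otherwise `0.0.0-<suffix>` dev charts use `<suffix>`, and released charts use `v<version>`) can be illustrated outside Helm. A plain-shell sketch of the same logic, for illustration only (the function name is ours, not part of the chart):

```shell
# Mirror of the Helm helpers' tag resolution, re-expressed in shell.
resolve_image_tag() {
  local tag="$1" chart_version="$2"
  if [ -n "$tag" ]; then
    # An explicit .image.tag always wins.
    printf '%s\n' "$tag"
  elif [ "${chart_version#0.0.0-}" != "$chart_version" ]; then
    # Development charts are versioned 0.0.0-<suffix>; use the <suffix> part.
    printf '%s\n' "${chart_version#0.0.0-}"
  else
    # Released charts map chart version X.Y.Z to image tag vX.Y.Z.
    printf 'v%s\n' "$chart_version"
  fi
}

resolve_image_tag ""       "2.1.0"              # -> v2.1.0
resolve_image_tag ""       "0.0.0-sha-bfccb7b"  # -> sha-bfccb7b
resolve_image_tag "v2.0.0" "2.1.0"              # -> v2.0.0
```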
Lines changed: 74 additions & 0 deletions

```diff
@@ -0,0 +1,74 @@
+{{- /*
+Copyright 2026 The Kubeflow authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ -}}
+
+{{- if and .Values.dataCache.enabled .Values.dataCache.runtimes.torchDistributed.enabled }}
+apiVersion: trainer.kubeflow.org/v1alpha1
+kind: ClusterTrainingRuntime
+metadata:
+  name: torch-distributed-with-cache
+  labels:
+    trainer.kubeflow.org/framework: torch
+    {{- include "trainer.labels" . | nindent 4 }}
+spec:
+  mlPolicy:
+    numNodes: 1
+    torch:
+      numProcPerNode: auto
+  template:
+    spec:
+      replicatedJobs:
+        - name: dataset-initializer
+          replicas: 1
+          template:
+            metadata:
+              labels:
+                trainer.kubeflow.org/trainjob-ancestor-step: dataset-initializer
+            spec:
+              template:
+                spec:
+                  serviceAccountName: kubeflow-trainer-cache-initializer
+                  containers:
+                    - name: dataset-initializer
+                      image: {{ printf "ghcr.io/kubeflow/trainer/dataset-initializer:%s" (include "trainer.defaultImageTag" .) }}
+                      env:
+                        - name: CACHE_IMAGE
+                          value: {{ include "trainer.runtimeImage" (list .Values.dataCache.cacheImage .) | quote }}
+                        - name: TRAIN_JOB_NAME
+                          valueFrom:
+                            fieldRef:
+                              apiVersion: v1
+                              fieldPath: metadata.labels['jobset.sigs.k8s.io/jobset-name']
+        - name: node
+          dependsOn:
+            - name: dataset-initializer
+              status: Complete
+          template:
+            metadata:
+              labels:
+                trainer.kubeflow.org/trainjob-ancestor-step: trainer
+            spec:
+              template:
+                spec:
+                  containers:
+                    - name: node
+                      image: pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime
+                      env:
+                        - name: TRAIN_JOB_NAME
+                          valueFrom:
+                            fieldRef:
+                              apiVersion: v1
+                              fieldPath: metadata.labels['jobset.sigs.k8s.io/jobset-name']
+{{- end }}
```
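A ClusterTrainingRuntime like this is consumed by referencing it from a TrainJob. A minimal sketch, assuming the `trainer.kubeflow.org/v1alpha1` TrainJob API's `runtimeRef` field (the job name below is hypothetical, and the chart must have been installed with `dataCache.enabled=true` and `dataCache.runtimes.torchDistributed.enabled=true` for this runtime to exist):

```yaml
# Hypothetical TrainJob referencing the runtime rendered above.
apiVersion: trainer.kubeflow.org/v1alpha1
kind: TrainJob
metadata:
  name: torch-cache-demo
spec:
  runtimeRef:
    name: torch-distributed-with-cache
```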
