Conversation
Proposing framework-aware trainer classes (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer) with automatic runtime discovery via the trainer.kubeflow.org/framework label, and a RuntimeConfig dataclass to separate per-job environment settings from training logic. Issue: kubeflow#285 Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Antonin Stefanutti <astefanutti@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
Pull request overview
This PR adds a comprehensive design proposal for specialized trainer abstractions and a RuntimeConfig dataclass to the Kubeflow SDK. The proposal addresses current limitations in the SDK's trainer subsystem by introducing framework-aware trainer classes that bridge the gap between the generic CustomTrainer and the highly specific BuiltinTrainer.
Changes:
- Adds a detailed design proposal document describing a new BaseTrainer abstract interface and specialized framework trainers (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer)
- Proposes a RuntimeConfig dataclass to cleanly separate runtime environment settings from training logic
- Includes comprehensive documentation covering motivation, design details, API examples, migration strategy, test plan, and alternatives considered
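For readers skimming the thread, here is a rough sketch of how the proposed pieces could fit together from a user's point of view. It is assembled from the proposal excerpts quoted below; the exact field names (`func`, `num_nodes`, `env`, `index_url`) are illustrative assumptions, not final signatures.

```python
from kubeflow.trainer import TrainerClient
# Hypothetical imports: the final module layout is defined by the KEP, not this sketch.
from kubeflow.trainer import TorchTrainer, RuntimeConfig, PipConfig

def train_fn():
    # User-defined PyTorch DDP training logic that runs inside each pod.
    ...

client = TrainerClient()
client.train(
    trainer=TorchTrainer(
        func=train_fn,      # training function shipped to the runtime (assumed field name)
        num_nodes=2,        # scaling parameters stay on the trainer (assumed field name)
    ),
    runtime_config=RuntimeConfig(  # per-job environment, separate from training logic
        packages_to_install=["torchvision"],
        env={"NCCL_DEBUG": "INFO"},                                  # assumed field name
        pip_config=PipConfig(index_url="https://pypi.org/simple"),   # assumed field name
    ),
)
```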
| 3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
|    Specialized trainers are an additional option, not a replacement.
| 4. **Tier 2 trainer implementations.** This proposal defines the extension mechanism
|    and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" throughout the document. This applies to references in text and comments, though the class name "HuggingFaceTrainer" would be correct as Python class names don't use spaces.
Suggested change:
- and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
+ and interface. Concrete Tier 2 implementations (Hugging Face, DeepSpeed, Unsloth,
| # Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
|
| @dataclass
| class TransformersTrainer(BaseTrainer):
|     """Trainer for HuggingFace Transformers training.
|
|     Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the comment and docstring text.
Suggested change:
- # Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
- @dataclass
- class TransformersTrainer(BaseTrainer):
-     """Trainer for HuggingFace Transformers training.
-     Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
+ # Example: future Hugging Face trainer (NOT part of this proposal's implementation scope)
+ @dataclass
+ class TransformersTrainer(BaseTrainer):
+     """Trainer for Hugging Face Transformers training.
+     Wraps Hugging Face's Trainer API and maps to a PyTorch runtime.
|       │
| ┌─────┴──────────┐
| │                │
| HuggingFace      DeepSpeed
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the diagram text.
Suggested change:
- HuggingFace      DeepSpeed
+ Hugging Face     DeepSpeed
@szaher @andreyvelich — this proposal is really well thought out, especially the separation between BaseTrainer and framework-specific trainers along with RuntimeConfig. I had a question regarding TorchTrainer extensibility and runtime selection: Given that multiple torch-based runtimes may coexist (as discussed earlier in #287), how do you envision selecting the appropriate runtime for a given TorchTrainer instance? One possible approach could be:
This might help keep the API simple while still supporting multiple backends (e.g., TorchTune vs custom PEFT/TRL runtimes for LLM workflows). Curious if something along these lines aligns with the intended direction. Happy to explore this further or prototype once the design is clearer.
kramaranya
left a comment
Thanks @szaher!
Looks great to me, and it should be a great improvement to the user experience in Kubeflow SDK!
/assign @andreyvelich @astefanutti @briangallagher @Fiona-Waters @MStokluska
| 2. **`RuntimeConfig` dataclass** — A dedicated configuration object that cleanly separates
|    per-job runtime environment settings (packages, pip config, environment variables) from
|    training logic and scaling parameters. This replaces the current pattern where
|    `CustomTrainer` conflates runtime concerns with trainer concerns.
Would this require runtime/controller changes?
No, it shouldn't require any backend changes, but it should align closely with them.
| 3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
|    Specialized trainers are an additional option, not a replacement.
Is the plan to eventually deprecate those or do we want to always maintain both options?
This is a non-goal, not a goal of the KEP.
| if runtime.trainer.framework not in self.supported_frameworks:
|     raise ValueError(
|         f"{type(self).__name__} supports frameworks "
|         f"{self.supported_frameworks}, but runtime '{runtime.name}' "
|         f"has framework '{runtime.trainer.framework}'"
|     )
We would also need to validate runtime.trainer.trainer_type.
is that supported by a backend label or annotation?
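To make the `trainer_type` suggestion concrete, here is a hedged sketch of how the framework check quoted above could be extended, assuming the backend surfaces the trainer type on the runtime object (the `supported_trainer_types` list is an assumption, not part of the proposal):

```python
def validate_runtime(self, runtime) -> None:
    # Framework compatibility check, as quoted from the proposal above.
    if runtime.trainer.framework not in self.supported_frameworks:
        raise ValueError(
            f"{type(self).__name__} supports frameworks "
            f"{self.supported_frameworks}, but runtime '{runtime.name}' "
            f"has framework '{runtime.trainer.framework}'"
        )
    # Hypothetical extension discussed here: also validate the trainer type,
    # if the backend exposes it (e.g. via a label or annotation on the runtime).
    trainer_type = getattr(runtime.trainer, "trainer_type", None)
    if trainer_type is not None and trainer_type not in self.supported_trainer_types:
        raise ValueError(
            f"{type(self).__name__} supports trainer types "
            f"{self.supported_trainer_types}, but runtime '{runtime.name}' "
            f"has trainer type '{trainer_type}'"
        )
```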
| def get_framework_args(self) -> dict:
|     args = {}
|     if self.max_restarts is not None:
|         args["max-restarts"] = str(self.max_restarts)
|     if self.monitor_interval is not None:
|         args["monitor-interval"] = str(self.monitor_interval)
|     return args
Where do these new args go in the TrainJob spec?
Those aren't backend args, they're framework args. They get passed via the entrypoint to the script that runs in the pods.
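As an illustration of "passed via the entrypoint": if the TorchTrainer runtime wraps the user script with `torchrun`, the dict returned by `get_framework_args()` could be rendered as command-line flags roughly like this. The helper below is a hypothetical sketch; how these keys map onto the actual launcher's flag spelling is a backend detail not settled in this thread.

```python
def build_entrypoint(trainer, script: str = "/workspace/train.py") -> list[str]:
    # Hypothetical helper: render framework args as launcher flags.
    cmd = ["torchrun"]
    for flag, value in trainer.get_framework_args().items():
        cmd.append(f"--{flag}={value}")  # e.g. --max-restarts=3, --monitor-interval=5
    cmd.append(script)
    return cmd

# TorchTrainer(max_restarts=3) would roughly yield:
# ["torchrun", "--max-restarts=3", "/workspace/train.py"]
```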
@kramaranya: GitHub didn't allow me to assign the following users: MStokluska. Note that only kubeflow members with read permissions, repo collaborators and people who have commented on this issue/PR can be assigned. Additionally, issues/PRs can only have 10 assignees at the same time.
+1 on the points around validation and argument placement; I had a related question while reading through this. For the framework-specific args (e.g.
This also seems tied to whether
Clarifying this mapping would help understand how far the abstraction goes (SDK-only vs API/CRD impact).
Thanks @szaher - it looks really good to me!
/lgtm
astefanutti
left a comment
Thanks @szaher!
/lgtm
/assign @kubeflow/kubeflow-sdk-team
andreyvelich
left a comment
Thanks @szaher! I left a few comments, sorry for the delay.
| <!--
| This proposal targets the kubeflow/sdk repository.
| Directory: docs/proposals/285-specialized-trainers/README.md
| -->
|
| | | |
| | -------------- | ------------------------------------------------------------ |
| | **Authors** | @szaher |
| | **Status** | Draft |
| | **Created** | 2026-02-11 |
| | **Reviewers** | |
| | **Supersedes** | N/A |
| | **Relevant Issues** | https://github.com/kubeflow/sdk/issues/285 |
Instead of this, can we consider adding a kep.yaml in the format we use for k8s? Check: https://github.com/kubernetes-sigs/jobset/blob/main/keps/463-ElasticJobsets/kep.yaml
Also, you can add implementation history as here: https://github.com/kubeflow/trainer/blob/master/docs/proposals/2170-kubeflow-trainer-v2/README.md#implementation-history
| (`kubeflow/sdk`) trainer subsystem:
|
| 1. **Specialized, framework-aware trainer abstractions** — A new `BaseTrainer` abstract
|    interface and a suite of framework-specific implementations (`TorchTrainer`, `MPITrainer`,
Why are MPI and Torch in the same category?
MPI is just a technology we use for distributed workloads.
As you can see, we create dedicated runtimes like DeepSpeed Distributed and MLX Distributed which leverage MPI: https://github.com/kubeflow/trainer/tree/master/manifests/base/runtimes
It is similar to TorchTrainer: it's a sort of custom trainer where we pass your code to the dedicated runtime as is.
But we don't support an MPI Runtime upstream; we only deploy DeepSpeedRuntime and MLXRuntime, which are MPI-based, so I am not sure if we should have a dedicated MPITrainer.
Can we have a DeepSpeed runtime based on the Torch plugin and another runtime based on MPI?
DeepSpeed can be bootstrapped via mpirun or torchrun: https://www.deepspeed.ai/getting-started/
It depends on how users want to configure it, but today we serve only an MPI-based runtime for DeepSpeed in Trainer: https://github.com/kubeflow/trainer/blob/master/manifests/base/runtimes/deepspeed_distributed.yaml
Shall we start with a dedicated DeepSpeedTrainer which will be MPI-based for now?
@astefanutti Any thoughts?
Maybe @kuizhiqing or @tenzen-y has more information when users want to use mpirun or torchrun to run DeepSpeed workloads?
|    `JAXTrainer`, etc.) that automatically discover and validate the correct
|    `ClusterTrainingRuntime` using the `trainer.kubeflow.org/framework` label. This fills the
|    "missing middle" between the overly generic `CustomTrainer` and the overly narrow
|    `BuiltinTrainer`.
Do we need to make any changes to BuiltinTrainer after this? IIRC, we said that this will be the foundation for Builtins as well.
cc @Sapthagiri777 @Electronic-Waste @khushiiagrawal
This connects directly to the TRL/RLHF integration discussed in #2839: a TRLTrainer (or more generally, a post-training specialized trainer) would sit in Tier 2 here, built on top of BaseTrainer. The BuiltinTrainer today has the hardcoded isinstance(trainer.config, TorchTuneConfig) check; a dynamic registry pattern (as proposed in #2839) would let TRL and other frameworks plug in without modifying that dispatch logic. Happy to draft what that integration point would look like if useful.
I think we don't need the BuiltinTrainer, but I am keeping it for backward compatibility. If we don't care about that, I can update the KEP to remove it.
I suppose the question may be how we handle config-driven trainers for post-training LLM fine-tuning that currently fall under the scope of BuiltInTrainers.
If the BaseTrainer hierarchy is purely for function-based trainers then do we segregate config-driven trainers entirely outside of the BaseTrainer scope?
Or if the BaseTrainer hierarchy is meant to be a catch-all for everything then do we classify these "dynamic LLM trainers" as Tier-2 trainers, and if so, how do we differentiate them from function-based Tier 2 trainers such as the proposed TransformersTrainer?
I am assuming either way we may need to retain an interface for config-driven trainers. But the appropriate placement for them may help to determine the scope of kubeflow/trainer#2839 within the broader scope for this KEP.
Many post-training frameworks (TRL, Unsloth, Axolotl, etc.) are effectively config-driven trainers, where the training entrypoint is a framework trainer object (e.g. SFTTrainer) rather than a user-defined train() function.
While experimenting with a small prototype around the backend registry idea from #2839, one pattern that seemed to work well was resolving the backend based on the trainer config type, e.g.:
@register_backend(TRLConfig)
class TRLBackend(RuntimeBackend):
    ...

This keeps the core BaseTrainer abstraction simple while allowing config-driven frameworks to plug in dynamically without expanding SDK dispatch logic.
Conceptually it would allow both styles to coexist:
- function-based trainers (TorchTrainer(train_fn=...))
- config-driven trainers (TorchTrainer(config=TRLConfig(...)))
with the backend registry selecting the appropriate runtime adapter.
Curious if this aligns with the intended direction for the dynamic trainer framework.
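For context, a minimal sketch of what such a decorator-based backend registry could look like. `register_backend`, `RuntimeBackend`, and `TRLConfig` are names from the commenter's prototype and kubeflow/trainer#2839, not from this KEP, and the implementation below is purely illustrative.

```python
_BACKENDS: dict[type, type] = {}

def register_backend(config_type: type):
    """Associate a trainer config type (e.g. TRLConfig) with a runtime backend class."""
    def decorator(backend_cls: type) -> type:
        _BACKENDS[config_type] = backend_cls
        return backend_cls
    return decorator

def resolve_backend(config):
    """Dispatch on the config's type instead of hardcoded isinstance checks."""
    try:
        return _BACKENDS[type(config)]
    except KeyError:
        raise ValueError(f"No backend registered for config type {type(config).__name__}")

# Usage, mirroring the snippet above:
#   @register_backend(TRLConfig)
#   class TRLBackend(RuntimeBackend): ...
```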
@tariq-hasan is right. @szaher I am trying to understand how we are going to refactor the BuiltinTrainer interface once we implement the BaseTrainer. And how can we dynamically register new LLM fine-tuning framework backends?
| For the majority of distributed training workloads — "run this PyTorch DDP function on
| N nodes" or "run this MPI script across a cluster" — neither abstraction fits well.
| Users must either use the low-level `CustomTrainer` with manual runtime wiring, or
| fall back to raw YAML.
Would love to hear feedback from @vsoch on this proposal.
Do you know if we have any interest from HPC users to leverage our Kubeflow Python SDK for TrainJob submission to manage HPC tasks on k8s?
| 2. Implement Tier 1 framework-specific trainers (`TorchTrainer`, `MPITrainer`,
|    `JAXTrainer`, `XGBoostTrainer`) that auto-discover runtimes by the
|    `trainer.kubeflow.org/framework` label and validate runtime compatibility.
Who will define which frameworks we will support?
For example, we also have DeepSpeed and MLX framework: https://github.com/kubeflow/trainer/blob/master/manifests/base/runtimes/deepspeed_distributed.yaml#L6
@astefanutti @tenzen-y Any thoughts on exposing framework in the Runtime API directly, or it is fine to use labels for now?
The supported frameworks are defined on the trainer side by the training runtimes. Also, my understanding is that we want it to be extensible so users can bring their own runtimes / frameworks / trainers.
I think using the trainer.kubeflow.org/framework is OK. I remember we discussed whether it could be in the TrainingRuntime API when we introduced it: #31 (comment)
Now that we are aligning the SDK to rely on that information, it may warrant promoting it to a proper spec field.
Incidentally, field indexing is now supported for custom resources: https://kubernetes.io/docs/concepts/overview/working-with-objects/field-selectors/#custom-resources-fields
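For illustration, discovering runtimes by that label with the Kubernetes Python client could look roughly like the sketch below; the API group/version and plural are assumptions based on the Trainer v2 CRDs, and the SDK's own client would presumably wrap this.

```python
from kubernetes import client, config

def find_runtimes_by_framework(framework: str) -> list[str]:
    """Return names of ClusterTrainingRuntimes carrying the given framework label."""
    config.load_kube_config()
    api = client.CustomObjectsApi()
    result = api.list_cluster_custom_object(
        group="trainer.kubeflow.org",      # assumed API group
        version="v1alpha1",                # assumed API version
        plural="clustertrainingruntimes",
        label_selector=f"trainer.kubeflow.org/framework={framework}",
    )
    return [item["metadata"]["name"] for item in result.get("items", [])]
```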
That sounds great! We can later refactor this to a dedicated Runtime API field if that is needed.
Also, my understanding is that we want it to be extensible so users can bring their own runtimes / frameworks / trainers.
Yeah, that makes sense, and users can bring their own Trainers via Tier 1 and Tier 2 extensions as @szaher mentioned in the KEP.
|   SDK codebase.
| - `PipConfig` is a separate dataclass rather than inline fields, because pip
|   configuration is a distinct concern with its own options.
| - `packages` replaces `packages_to_install` for brevity.
That will make us inconsistent with KFP.
Do we see any issues since @MStokluska is working on PipelinesClient?
#125
cc @kubeflow/wg-pipeline-leads
The proposal here groups all pip configuration under a specific class which is passed to the train function. Pipelines uses a decorator to capture that, so in their case it might be easier to keep all parameters flat for ease of use, as I think most of the dsl.component parameters have a safe default value.
@dsl.component(base_image="quay.io/my-org/training-image:latest",
               packages_to_install=[packages for customizing pipelines task runtime])
def submit_training(model_path: str, num_epochs: int = 3, nproc_per_node: int = 1):
    from kubeflow.trainer import TrainerClient

    client = TrainerClient()
    client.train(..., pip_config=PipConfig(extra packages, enable verbose pip, extra index),
                 runtime_config=RuntimeConfig(extra packages that go on the training runtime))
    client.wait_for_job_status("train-job")
I am fine with moving these fields under RuntimeConfig, but I suggest keeping the parameter name as packages_to_install.
sure, we can keep it as it is.
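Reflecting the outcome of this exchange (fields grouped under `RuntimeConfig`, parameter name kept as `packages_to_install`), the dataclasses could end up looking roughly like the sketch below; fields other than those named in the thread are placeholders.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PipConfig:
    """Pip-specific options, kept separate since pip configuration is its own concern."""
    index_url: Optional[str] = None                          # placeholder field
    extra_index_urls: list = field(default_factory=list)     # placeholder field
    verbose: bool = False                                    # "enable verbose pip" above

@dataclass
class RuntimeConfig:
    """Per-job runtime environment settings, separate from training logic."""
    packages_to_install: list = field(default_factory=list)  # name kept for KFP consistency
    env: dict = field(default_factory=dict)                  # placeholder field
    pip_config: Optional[PipConfig] = None
```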
| "BaseTrainer", # NEW: accepts any specialized trainer | ||
| ] | ||
| ] = None, | ||
| runtime_config: Optional["RuntimeConfig"] = None, # NEW |
Shall the runtime_config be part of Trainers?
I would imagine a use case where users want to set custom env for the initializers as well.
cc @akshaychitneni
We can move it to trainers if needed. Do we also need to move the pip config?
Do we need to have a separate PipConfig type under RuntimeConfig?
Another feature request I got from users is that sometimes they want to install custom tools on top of the base image (e.g. using apt install <package_name>).
We don't support it today, and users have to rebuild the Docker image manually.
I think this one might need a separate KEP. We can use something like olot
@astefanutti WDYT?
Yes, it would probably make sense for workspace snapshotting (#48) to be part of the dynamic runtime configuration.
Would dynamic runtime configuration make sense for "built-in" / tier-2 trainers? Would we be able to guarantee their runtimes can be configured dynamically?
If not, runtime_config might have to be part of Trainers conceptually, and maybe Initializers eventually.
Yeah, I think this use case should be covered by custom trainers, where users have more control over the training script. But I agree, we can discuss it in follow-up proposals.
| Each Tier 1 trainer maps 1:1 to a framework identified by the
| `trainer.kubeflow.org/framework` label value.
What happens if users deploy two ClusterTrainingRuntimes with the same framework label?
That case is covered in a later section, which specifies that the SDK fails fast and the user has to pass the name explicitly.
But users can still use TorchTrainer if they provide the runtime name, right?
We just validate that Runtime has the correct framework label?
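A small sketch of the resolution behaviour being described: auto-discovery fails fast on an ambiguous framework label, while an explicitly named runtime is accepted and only checked for label compatibility (helper and attribute names are illustrative, not from the proposal text):

```python
def resolve_runtime(trainer, runtimes, runtime_name=None):
    if runtime_name is not None:
        runtime = next(r for r in runtimes if r.name == runtime_name)
        trainer.validate_runtime(runtime)  # still verify the framework label matches
        return runtime

    matches = [r for r in runtimes if r.trainer.framework in trainer.supported_frameworks]
    if not matches:
        raise ValueError(f"No runtime found for frameworks {trainer.supported_frameworks}")
    if len(matches) > 1:
        # Fail fast: the user must disambiguate by passing the runtime name explicitly.
        raise ValueError(
            f"Multiple runtimes match {trainer.supported_frameworks}: "
            f"{[r.name for r in matches]}; pass the runtime name explicitly"
        )
    return matches[0]
```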
| 1. **Framework label check** (in `BaseTrainer.validate_runtime()`): Ensures the
|    runtime's `trainer.kubeflow.org/framework` label value is in the trainer's
|    `supported_frameworks` list.
|
| 2. **Framework-specific checks** (in subclass overrides): For example, `MPITrainer`
|    could verify that the runtime's MPI policy source is configured correctly.
I think we should talk more about the validation. Usually, we offload this to the control plane (e.g. webhook).
If that is only framework compatibility, we can directly fetch the desired runtime using the label.
Sure, I think we need to validate the job before submitting too, and the backend can validate that everything is intact before running the job. SDK validation could just raise warnings.
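If SDK-side checks stay advisory as suggested here, they could surface as warnings while the control plane (webhook) remains the authoritative validator; a small sketch under that assumption:

```python
import warnings

def validate_before_submit(trainer, runtime) -> None:
    """Advisory SDK-side check; the control plane still performs the strict validation."""
    if runtime.trainer.framework not in trainer.supported_frameworks:
        warnings.warn(
            f"Runtime '{runtime.name}' has framework '{runtime.trainer.framework}', "
            f"which {type(trainer).__name__} does not list as supported; "
            "the control plane may reject this TrainJob.",
            UserWarning,
        )
```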
| ```python
| @dataclass
| class TorchTrainer(BaseTrainer):
@kubeflow/kubeflow-sdk-team @Fiona-Waters @abhijeet-dhumal @akshaychitneni Did we ever explore what capabilities Ray Train implements in their trainers, e.g. TorchTrainer?
I know that they do some changes to the dataset and model to attach it to the Distributor.
Additionally (as a future extension), it would be interesting to validate users code in a way that it is correctly configured for distributed training with Torch or any other framework. For example, we can use AI Agents to analyze users' code before submission and suggest changes/enhancements, since we have a lot of context around distributed configuration on k8s.
Ray trainers have a different philosophy: the trainer gets everything (runtime config, scale config, job config, job script(s), worker-specific config) and then calls trainer.fit(), so there is no client initialization similar to Kubeflow's (we need that since we don't have a pre-deployed cluster).
We can discuss it if we want.
Signed-off-by: Saad Zaher <szaher@redhat.com>
New changes are detected. LGTM label has been removed.
Signed-off-by: Saad Zaher <szaher@redhat.com>
@szaher,
Hi @szaher, great work on KEP-285! I wanted to flag a related KEP that's complementary to this proposal. KEP-2839: Dynamic LLM Trainer Framework (tracking issue: kubeflow/trainer#2839) introduces a pluggable
In your KEP-285 terminology, these would be Tier 2 config-driven trainers. The two KEPs are designed to be compatible:
This also relates to @tariq-hasan's and @krishdef7's questions about how config-driven trainers fit into the
Happy to coordinate on the design. The KEP PR is here: kubeflow/trainer#3263