Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import abc
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse

import kubeflow.common.constants as common_constants

if TYPE_CHECKING:
# Avoid circular imports — these are only used for type hints.
from kubeflow.optimizer.api.optimizer_client import OptimizerClient
from kubeflow.optimizer.types.algorithm_types import BaseAlgorithm
from kubeflow.optimizer.types.optimization_types import Objective, TrialConfig
from kubeflow.trainer.api.trainer_client import TrainerClient
from kubeflow.trainer.constants import constants


Expand Down Expand Up @@ -237,7 +246,7 @@ class BuiltinTrainer:


# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs.
TORCH_TUNE = BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "")
TORCH_TUNE = TorchTuneConfig.__name__.lower().replace("config", "")


class TrainerType(Enum):
Expand Down Expand Up @@ -492,7 +501,6 @@ class Initializer:
model: HuggingFaceModelInitializer | S3ModelInitializer | None = None


# TODO (andreyvelich): Add train() and optimize() methods to this class.
@dataclass
class TrainJobTemplate:
"""TrainJob template configuration.
Expand All @@ -515,3 +523,63 @@ def keys(self):

def __getitem__(self, key):
return getattr(self, key)

def train(
self,
client: TrainerClient,
options: list | None = None,
) -> str:
"""Create a TrainJob using this template's configuration.
Comment on lines +527 to +532
Copy link

Copilot AI Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TODO comment above TrainJobTemplate about adding train()/optimize() is now outdated after introducing these methods; remove or update it to avoid misleading future readers.

Copilot uses AI. Check for mistakes.

Args:
client: A TrainerClient instance used to submit the job.
options: Optional list of configuration options to apply to the TrainJob.

Returns:
The unique name of the created TrainJob.

Raises:
ValueError: Input arguments are invalid.
TimeoutError: Timeout to create TrainJob.
RuntimeError: Failed to create TrainJob.
"""
return client.train(
runtime=self.runtime,
initializer=self.initializer,
trainer=self.trainer,
options=options,
)

def optimize(
self,
client: OptimizerClient,
search_space: dict[str, Any],
objectives: list[Objective] | None = None,
algorithm: BaseAlgorithm | None = None,
trial_config: TrialConfig | None = None,
) -> str:
"""Create an OptimizationJob for hyperparameter tuning using this template.

Args:
client: An OptimizerClient instance used to submit the optimization job.
search_space: Dictionary mapping parameter names to Search specifications
using Search.uniform(), Search.loguniform(), Search.choice(), etc.
objectives: List of objectives to optimize (e.g. minimize loss, maximize accuracy).
algorithm: The optimization algorithm to use. Defaults to RandomSearch.
trial_config: Optional configuration for how trials are run.

Returns:
The unique name of the OptimizationJob (Experiment) that has been created.

Raises:
ValueError: Input arguments are invalid.
TimeoutError: Timeout to create OptimizationJob.
RuntimeError: Failed to create OptimizationJob.
"""
return client.optimize(
trial_template=self,
search_space=search_space,
objectives=objectives,
algorithm=algorithm,
trial_config=trial_config,
)