1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from __future__ import annotations
1415
1516import abc
1617from collections .abc import Callable
1718from dataclasses import dataclass , field
1819from datetime import datetime
1920from enum import Enum
21+ from typing import TYPE_CHECKING , Any
2022from urllib .parse import urlparse
2123
2224import kubeflow .common .constants as common_constants
25+
26+ if TYPE_CHECKING :
27+ # Avoid circular imports — these are only used for type hints.
28+ from kubeflow .optimizer .api .optimizer_client import OptimizerClient
29+ from kubeflow .optimizer .types .algorithm_types import BaseAlgorithm
30+ from kubeflow .optimizer .types .optimization_types import Objective , TrialConfig
31+ from kubeflow .trainer .api .trainer_client import TrainerClient
2332from kubeflow .trainer .constants import constants
2433
2534
@@ -237,7 +246,7 @@ class BuiltinTrainer:
237246
238247
239248# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs.
240- TORCH_TUNE = BuiltinTrainer . __annotations__ [ "config" ] .__name__ .lower ().replace ("config" , "" )
249+ TORCH_TUNE = TorchTuneConfig .__name__ .lower ().replace ("config" , "" )
241250
242251
243252class TrainerType (Enum ):
@@ -492,7 +501,6 @@ class Initializer:
492501 model : HuggingFaceModelInitializer | S3ModelInitializer | None = None
493502
494503
495- # TODO (andreyvelich): Add train() and optimize() methods to this class.
496504@dataclass
497505class TrainJobTemplate :
498506 """TrainJob template configuration.
@@ -515,3 +523,63 @@ def keys(self):
515523
516524 def __getitem__ (self , key ):
517525 return getattr (self , key )
526+
527+ def train (
528+ self ,
529+ client : TrainerClient ,
530+ options : list | None = None ,
531+ ) -> str :
532+ """Create a TrainJob using this template's configuration.
533+
534+ Args:
535+ client: A TrainerClient instance used to submit the job.
536+ options: Optional list of configuration options to apply to the TrainJob.
537+
538+ Returns:
539+ The unique name of the created TrainJob.
540+
541+ Raises:
542+ ValueError: Input arguments are invalid.
543+ TimeoutError: Timeout to create TrainJob.
544+ RuntimeError: Failed to create TrainJob.
545+ """
546+ return client .train (
547+ runtime = self .runtime ,
548+ initializer = self .initializer ,
549+ trainer = self .trainer ,
550+ options = options ,
551+ )
552+
553+ def optimize (
554+ self ,
555+ client : OptimizerClient ,
556+ search_space : dict [str , Any ],
557+ objectives : list [Objective ] | None = None ,
558+ algorithm : BaseAlgorithm | None = None ,
559+ trial_config : TrialConfig | None = None ,
560+ ) -> str :
561+ """Create an OptimizationJob for hyperparameter tuning using this template.
562+
563+ Args:
564+ client: An OptimizerClient instance used to submit the optimization job.
565+ search_space: Dictionary mapping parameter names to Search specifications
566+ using Search.uniform(), Search.loguniform(), Search.choice(), etc.
567+ objectives: List of objectives to optimize (e.g. minimize loss, maximize accuracy).
568+ algorithm: The optimization algorithm to use. Defaults to RandomSearch.
569+ trial_config: Optional configuration for how trials are run.
570+
571+ Returns:
572+ The unique name of the OptimizationJob (Experiment) that has been created.
573+
574+ Raises:
575+ ValueError: Input arguments are invalid.
576+ TimeoutError: Timeout to create OptimizationJob.
577+ RuntimeError: Failed to create OptimizationJob.
578+ """
579+ return client .optimize (
580+ trial_template = self ,
581+ search_space = search_space ,
582+ objectives = objectives ,
583+ algorithm = algorithm ,
584+ trial_config = trial_config ,
585+ )
0 commit comments