Skip to content

Commit c13ea6b

Browse files
committed
feat: Add train() and optimize() methods to TrainJobTemplate
TrainJobTemplate was a passive data container with no way to actually execute jobs. This change lets the template act as an entrypoint to both TrainerClient and OptimizerClient.

- train() delegates to TrainerClient.train() using the template's pre-configured runtime, initializer, and trainer.
- optimize() delegates to OptimizerClient.optimize(), passing self as the trial_template for hyperparameter tuning.

TYPE_CHECKING is used to avoid circular imports, since TrainerClient and OptimizerClient both already import types.py.

Signed-off-by: Sujal Shah <sujalshah28092004@gmail.com>
1 parent 43b9590 commit c13ea6b

File tree

1 file changed

+70
-2
lines changed

1 file changed

+70
-2
lines changed

kubeflow/trainer/types/types.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,24 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import abc
1617
from collections.abc import Callable
1718
from dataclasses import dataclass, field
1819
from datetime import datetime
1920
from enum import Enum
21+
from typing import TYPE_CHECKING, Any
2022
from urllib.parse import urlparse
2123

2224
import kubeflow.common.constants as common_constants
25+
26+
if TYPE_CHECKING:
27+
# Avoid circular imports — these are only used for type hints.
28+
from kubeflow.optimizer.api.optimizer_client import OptimizerClient
29+
from kubeflow.optimizer.types.algorithm_types import BaseAlgorithm
30+
from kubeflow.optimizer.types.optimization_types import Objective, TrialConfig
31+
from kubeflow.trainer.api.trainer_client import TrainerClient
2332
from kubeflow.trainer.constants import constants
2433

2534

@@ -237,7 +246,7 @@ class BuiltinTrainer:
237246

238247

239248
# Change it to list: BUILTIN_CONFIGS, once we support more Builtin Trainer configs.
240-
TORCH_TUNE = BuiltinTrainer.__annotations__["config"].__name__.lower().replace("config", "")
249+
# Derive the runtime name ("torchtune") from the config class name so the
# two can never drift apart; drop the "Config" suffix before lowercasing.
TORCH_TUNE = TorchTuneConfig.__name__.removesuffix("Config").lower()
241250

242251

243252
class TrainerType(Enum):
@@ -492,7 +501,6 @@ class Initializer:
492501
model: HuggingFaceModelInitializer | S3ModelInitializer | None = None
493502

494503

495-
# TODO (andreyvelich): Add train() and optimize() methods to this class.
496504
@dataclass
497505
class TrainJobTemplate:
498506
"""TrainJob template configuration.
@@ -515,3 +523,63 @@ def keys(self):
515523

516524
def __getitem__(self, key):
517525
return getattr(self, key)
526+
527+
def train(
    self,
    client: TrainerClient,
    options: list | None = None,
) -> str:
    """Create a TrainJob using this template's configuration.

    The template supplies its pre-configured runtime, initializer, and
    trainer; submission itself is delegated to the given client.

    Args:
        client: A TrainerClient instance used to submit the job.
        options: Optional list of configuration options to apply to the TrainJob.

    Returns:
        The unique name of the created TrainJob.

    Raises:
        ValueError: Input arguments are invalid.
        TimeoutError: Timeout to create TrainJob.
        RuntimeError: Failed to create TrainJob.
    """
    # Gather the template's pre-configured pieces and hand them to the client.
    job_spec = {
        "runtime": self.runtime,
        "initializer": self.initializer,
        "trainer": self.trainer,
        "options": options,
    }
    return client.train(**job_spec)
552+
553+
def optimize(
    self,
    client: OptimizerClient,
    search_space: dict[str, Any],
    objectives: list[Objective] | None = None,
    algorithm: BaseAlgorithm | None = None,
    trial_config: TrialConfig | None = None,
) -> str:
    """Create an OptimizationJob for hyperparameter tuning using this template.

    The template passes itself as the trial template; submission is
    delegated to the given client.

    Args:
        client: An OptimizerClient instance used to submit the optimization job.
        search_space: Dictionary mapping parameter names to Search specifications
            using Search.uniform(), Search.loguniform(), Search.choice(), etc.
        objectives: List of objectives to optimize (e.g. minimize loss, maximize accuracy).
        algorithm: The optimization algorithm to use. Defaults to RandomSearch.
        trial_config: Optional configuration for how trials are run.

    Returns:
        The unique name of the OptimizationJob (Experiment) that has been created.

    Raises:
        ValueError: Input arguments are invalid.
        TimeoutError: Timeout to create OptimizationJob.
        RuntimeError: Failed to create OptimizationJob.
    """
    # This template itself serves as the trial template for every trial.
    experiment_spec = {
        "trial_template": self,
        "search_space": search_space,
        "objectives": objectives,
        "algorithm": algorithm,
        "trial_config": trial_config,
    }
    return client.optimize(**experiment_spec)

0 commit comments

Comments
 (0)