|
18 | 18 | import random |
19 | 19 | import string |
20 | 20 | import uuid |
21 | | -from typing import Dict, List, Optional |
| 21 | +from typing import Dict, List, Optional, Union |
22 | 22 |
|
23 | 23 | import kubeflow.trainer.models as models |
24 | 24 | from kubeflow.trainer.constants import constants |
@@ -153,20 +153,21 @@ def train( |
153 | 153 | self, |
154 | 154 | runtime: types.Runtime = types.DEFAULT_RUNTIME, |
155 | 155 | initializer: Optional[types.Initializer] = None, |
156 | | - trainer: Optional[types.CustomTrainer] = None, |
| 156 | + trainer: Optional[Union[types.CustomTrainer, types.BuiltinTrainer]] = None, |
157 | 157 | ) -> str: |
158 | 158 | """ |
159 | 159 | Create the TrainJob. You can configure these types of training task: |
160 | 160 |
|
161 | 161 | - Custom Training Task: Training with a self-contained function that encapsulates |
162 | 162 | the entire model training process, e.g. `CustomTrainer`. |
| 163 | + - Builtin Training Task: Configures a post-training job using torchtune. |
163 | 164 |
|
164 | 165 | Args: |
165 | 166 | runtime (`types.Runtime`): Reference to one of existing Runtimes. |
166 | 167 | initializer (`Optional[types.Initializer]`): |
167 | 168 | Configuration for the dataset and model initializers. |
168 | | - trainer (`Optional[types.CustomTrainer]`): |
169 | | - Configuration for Custom Training Task. |
| 169 | + trainer (`Union[types.CustomTrainer, types.BuiltinTrainer, None]`): |
| 170 | + Configuration for Custom or Builtin Training Task. |
170 | 171 |
|
171 | 172 | Returns: |
172 | 173 | str: The unique name of the TrainJob that has been generated. |
|
0 commit comments