
Commit 9ecc6da

feat: add post_training RuntimeConfig
Certain APIs require a number of runtime arguments per provider. Currently, the best way to pass these arguments in is via the provider config. This is awkward because it requires a provider to be pre-configured with arguments that a client-side user should be able to pass in at runtime. Especially with the advent of out-of-tree providers, a generic RuntimeConfig class would allow providers to add and validate their own runtime arguments for operations like supervised_fine_tune.

For example, https://github.com/opendatahub-io/llama-stack-provider-kft has fields like `input-pvc`, `model-path`, etc. in its provider config. This is not sustainable, nor is adding each and every such field to the post_training API spec.

RuntimeConfig has an inner Config class which allows extra fields to be specified arbitrarily. It is the provider's job to create its own class based on this one, add valid options, parse them, etc.

Signed-off-by: Charlie Doern <[email protected]>
1 parent bb1a85c commit 9ecc6da
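
To make the intended workflow concrete, here is a minimal sketch (not part of this commit) of the pattern the message describes: the base RuntimeConfig accepts arbitrary extra fields, and a provider re-parses them against its own subclass. The `KFTRuntimeConfig` name and its fields are hypothetical, loosely modeled on the provider config linked above.

```python
from typing import Optional

from pydantic import BaseModel


class RuntimeConfig(BaseModel):
    """Mirrors the class added by this commit: arbitrary extra fields are allowed."""

    class Config:
        extra = "allow"


class KFTRuntimeConfig(RuntimeConfig):
    """Hypothetical provider-owned subclass declaring the fields it understands."""

    input_pvc: Optional[str] = None
    model_path: Optional[str] = None


# A client passes a generic RuntimeConfig; the provider re-validates it against
# its own schema before using the values.
generic = RuntimeConfig(input_pvc="training-data", model_path="/mnt/models/llama")
parsed = KFTRuntimeConfig(**generic.dict())
assert parsed.input_pvc == "training-data"
```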

File tree

llama_stack/apis/post_training/post_training.py
llama_stack/providers/inline/post_training/torchtune/post_training.py

2 files changed (+14, -0 lines)


llama_stack/apis/post_training/post_training.py

Lines changed: 12 additions & 0 deletions
@@ -169,6 +169,16 @@ class PostTrainingJobArtifactsResponse(BaseModel):
     # TODO(ashwin): metrics, evals


+@json_schema_type
+class RuntimeConfig(BaseModel):
+    """
+    Provider-specific runtime configuration. Providers should document and parse their own expected fields.
+    This model allows arbitrary extra fields for maximum flexibility.
+    """
+    class Config:
+        extra = "allow"
+
+
 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune", method="POST")
     async def supervised_fine_tune(
@@ -183,6 +193,7 @@ async def supervised_fine_tune(
         ),
         checkpoint_dir: Optional[str] = None,
         algorithm_config: Optional[AlgorithmConfig] = None,
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize", method="POST")
@@ -194,6 +205,7 @@ async def preference_optimize(
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs", method="GET")
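
For callers, the new parameter means provider-specific knobs can travel with the request instead of being baked into the provider config. The sketch below is hypothetical: argument names not visible in this diff (such as `job_uuid`) and all concrete values are assumptions; only the `runtime_config` parameter itself comes from this change.

```python
from llama_stack.apis.post_training import (
    PostTraining,
    PostTrainingJob,
    RuntimeConfig,
    TrainingConfig,
)


# Hypothetical call site; only the runtime_config argument is introduced by this commit.
async def launch_job(post_training: PostTraining, training_config: TrainingConfig) -> PostTrainingJob:
    return await post_training.supervised_fine_tune(
        job_uuid="job-123",              # illustrative value
        model="example-model",           # illustrative value
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        checkpoint_dir=None,
        algorithm_config=None,
        runtime_config=RuntimeConfig(
            # Arbitrary extra fields are allowed; the provider decides which
            # of them it understands and how to validate them.
            input_pvc="training-data",
            model_path="/mnt/models/llama",
        ),
    )
```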

llama_stack/providers/inline/post_training/torchtune/post_training.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     PostTrainingJobArtifactsResponse,
     PostTrainingJobStatusResponse,
     TrainingConfig,
+    RuntimeConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
@@ -80,6 +81,7 @@ async def supervised_fine_tune(
         model: str,
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[AlgorithmConfig],
+        runtime_config: Optional[RuntimeConfig] = None,
     ) -> PostTrainingJob:
         if isinstance(algorithm_config, LoraFinetuningConfig):
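
The torchtune implementation now accepts `runtime_config` but, per this diff, does not yet consume it. A provider that wants to act on it could narrow the generic config into its own strictly validated schema, the mirror image of the permissive base class. Everything below is a hypothetical sketch: `TorchtuneRuntimeConfig`, its fields, and the helper are not part of this commit.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class RuntimeConfig(BaseModel):
    """Redefined locally so the sketch is self-contained; mirrors the API class."""

    class Config:
        extra = "allow"


class TorchtuneRuntimeConfig(RuntimeConfig):
    """Hypothetical provider schema; field names are illustrative."""

    compile_model: bool = False
    checkpoint_format: Optional[str] = None

    class Config:
        extra = "forbid"  # reject runtime options this provider does not understand


def parse_runtime_config(runtime_config: Optional[RuntimeConfig]) -> TorchtuneRuntimeConfig:
    """Re-validate the client-supplied config against the provider's own schema."""
    if runtime_config is None:
        return TorchtuneRuntimeConfig()
    try:
        return TorchtuneRuntimeConfig(**runtime_config.dict())
    except ValidationError as err:
        raise ValueError(f"Unsupported runtime_config for the torchtune provider: {err}") from err
```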
