Skip to content

Commit 0ec5151

Browse files
committed
feat: add post_training RuntimeConfig
certain APIs require a bunch of runtime arguments per provider. The best way currently to pass these arguments in is via the provider config. This is tricky because it requires a provider to be pre-configured with certain arguments that a client-side user should be able to pass in at runtime. Especially with the advent of out-of-tree providers, it would be great for a generic RuntimeConfig class to allow providers to add and validate their own runtime arguments for things like supervised_fine_tune. For example: https://github.com/opendatahub-io/llama-stack-provider-kft has things like `input-pvc`, `model-path`, etc. in the Provider Config. This is not sustainable, nor is adding each and every field needed to the post_training API spec. RuntimeConfig has a sub-class called Config which allows extra fields to be specified arbitrarily. It is the provider's job to create its own class based on this one and add valid options, parse them, etc. Signed-off-by: Charlie Doern <[email protected]>
1 parent bb1a85c commit 0ec5151

File tree

4 files changed

+37
-0
lines changed

4 files changed

+37
-0
lines changed

docs/_static/llama-stack-spec.html

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10170,6 +10170,11 @@
1017010170
],
1017110171
"title": "OptimizerType"
1017210172
},
10173+
"RuntimeConfig": {
10174+
"type": "object",
10175+
"title": "RuntimeConfig",
10176+
"description": "Provider-specific runtime configuration. Providers should document and parse their own expected fields. This model allows arbitrary extra fields for maximum flexibility."
10177+
},
1017310178
"TrainingConfig": {
1017410179
"type": "object",
1017510180
"properties": {
@@ -10274,6 +10279,9 @@
1027410279
}
1027510280
]
1027610281
}
10282+
},
10283+
"runtime_config": {
10284+
"$ref": "#/components/schemas/RuntimeConfig"
1027710285
}
1027810286
},
1027910287
"additionalProperties": false,
@@ -11375,6 +11383,9 @@
1137511383
},
1137611384
"algorithm_config": {
1137711385
"$ref": "#/components/schemas/AlgorithmConfig"
11386+
},
11387+
"runtime_config": {
11388+
"$ref": "#/components/schemas/RuntimeConfig"
1137811389
}
1137911390
},
1138011391
"additionalProperties": false,

docs/_static/llama-stack-spec.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7000,6 +7000,13 @@ components:
70007000
- adamw
70017001
- sgd
70027002
title: OptimizerType
7003+
RuntimeConfig:
7004+
type: object
7005+
title: RuntimeConfig
7006+
description: >-
7007+
Provider-specific runtime configuration. Providers should document and parse
7008+
their own expected fields. This model allows arbitrary extra fields for maximum
7009+
flexibility.
70037010
TrainingConfig:
70047011
type: object
70057012
properties:
@@ -7060,6 +7067,8 @@ components:
70607067
- type: string
70617068
- type: array
70627069
- type: object
7070+
runtime_config:
7071+
$ref: '#/components/schemas/RuntimeConfig'
70637072
additionalProperties: false
70647073
required:
70657074
- job_uuid
@@ -7755,6 +7764,8 @@ components:
77557764
type: string
77567765
algorithm_config:
77577766
$ref: '#/components/schemas/AlgorithmConfig'
7767+
runtime_config:
7768+
$ref: '#/components/schemas/RuntimeConfig'
77587769
additionalProperties: false
77597770
required:
77607771
- job_uuid

llama_stack/apis/post_training/post_training.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,17 @@ class PostTrainingJobArtifactsResponse(BaseModel):
169169
# TODO(ashwin): metrics, evals
170170

171171

172+
@json_schema_type
173+
class RuntimeConfig(BaseModel):
174+
"""
175+
Provider-specific runtime configuration. Providers should document and parse their own expected fields.
176+
This model allows arbitrary extra fields for maximum flexibility.
177+
"""
178+
179+
class Config:
180+
extra = "allow"
181+
182+
172183
class PostTraining(Protocol):
173184
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
174185
async def supervised_fine_tune(
@@ -183,6 +194,7 @@ async def supervised_fine_tune(
183194
),
184195
checkpoint_dir: Optional[str] = None,
185196
algorithm_config: Optional[AlgorithmConfig] = None,
197+
runtime_config: Optional[RuntimeConfig] = None,
186198
) -> PostTrainingJob: ...
187199

188200
@webmethod(route="/post-training/preference-optimize", method="POST")
@@ -194,6 +206,7 @@ async def preference_optimize(
194206
training_config: TrainingConfig,
195207
hyperparam_search_config: Dict[str, Any],
196208
logger_config: Dict[str, Any],
209+
runtime_config: Optional[RuntimeConfig] = None,
197210
) -> PostTrainingJob: ...
198211

199212
@webmethod(route="/post-training/jobs", method="GET")

llama_stack/providers/inline/post_training/torchtune/post_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
PostTrainingJob,
1919
PostTrainingJobArtifactsResponse,
2020
PostTrainingJobStatusResponse,
21+
RuntimeConfig,
2122
TrainingConfig,
2223
)
2324
from llama_stack.providers.inline.post_training.torchtune.config import (
@@ -80,6 +81,7 @@ async def supervised_fine_tune(
8081
model: str,
8182
checkpoint_dir: Optional[str],
8283
algorithm_config: Optional[AlgorithmConfig],
84+
runtime_config: Optional[RuntimeConfig] = None,
8385
) -> PostTrainingJob:
8486
if isinstance(algorithm_config, LoraFinetuningConfig):
8587

0 commit comments

Comments
 (0)