
Commit 6b2e131

[Prompt Optimization Backend PR #1] Wrap prompt optimize in mlflow job (mlflow#20001)
1 parent 209c23b commit 6b2e131

5 files changed: +601 -0 lines changed

mlflow/genai/optimize/job.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable

from mlflow.exceptions import MlflowException
from mlflow.genai.datasets import get_dataset
from mlflow.genai.optimize import optimize_prompts
from mlflow.genai.optimize.optimizers import (
    BasePromptOptimizer,
    GepaPromptOptimizer,
    MetaPromptOptimizer,
)
from mlflow.genai.prompts import load_prompt
from mlflow.genai.scorers import builtin_scorers
from mlflow.genai.scorers.base import Scorer
from mlflow.genai.scorers.registry import get_scorer
from mlflow.server.jobs import job
from mlflow.telemetry.events import OptimizePromptsJobEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.tracking.client import MlflowClient
from mlflow.tracking.fluent import set_experiment, start_run

_logger = logging.getLogger(__name__)

_DEFAULT_OPTIMIZATION_JOB_MAX_WORKERS = 2


class OptimizerType(str, Enum):
    """Supported prompt optimizer types."""

    GEPA = "gepa"
    METAPROMPT = "metaprompt"


@dataclass
class PromptOptimizationJobResult:
    run_id: str
    source_prompt_uri: str
    optimized_prompt_uri: str | None
    optimizer_name: str
    initial_eval_score: float | None
    final_eval_score: float | None
    dataset_id: str
    scorer_names: list[str]


def _create_optimizer(
    optimizer_type: str,
    optimizer_config: dict[str, Any] | None,
) -> BasePromptOptimizer:
    """
    Create an optimizer instance from type string and configuration dict.

    Args:
        optimizer_type: The optimizer type string (e.g., "gepa", "metaprompt").
        optimizer_config: Optimizer-specific configuration dictionary.

    Returns:
        An instantiated optimizer.

    Raises:
        MlflowException: If optimizer type is not supported.
    """
    config = optimizer_config or {}
    optimizer_type_lower = optimizer_type.lower() if optimizer_type else ""

    if optimizer_type_lower == OptimizerType.GEPA:
        reflection_model = config.get("reflection_model")
        if not reflection_model:
            raise MlflowException.invalid_parameter_value(
                "Missing required optimizer configuration: 'reflection_model' must be specified "
                "in optimizer_config for the GEPA optimizer (e.g., 'openai:/gpt-4o')."
            )
        return GepaPromptOptimizer(
            reflection_model=reflection_model,
            max_metric_calls=config.get("max_metric_calls", 100),
            display_progress_bar=config.get("display_progress_bar", False),
            gepa_kwargs=config.get("gepa_kwargs"),
        )
    elif optimizer_type_lower == OptimizerType.METAPROMPT:
        reflection_model = config.get("reflection_model")
        if not reflection_model:
            raise MlflowException.invalid_parameter_value(
                "Missing required optimizer configuration: 'reflection_model' must be specified "
                "in optimizer_config for the MetaPrompt optimizer (e.g., 'openai:/gpt-4o')."
            )
        return MetaPromptOptimizer(
            reflection_model=reflection_model,
            lm_kwargs=config.get("lm_kwargs"),
            guidelines=config.get("guidelines"),
        )
    elif not optimizer_type:
        supported_types = [t.value for t in OptimizerType]
        raise MlflowException.invalid_parameter_value(
            f"Optimizer type must be specified. Supported types: {supported_types}"
        )
    else:
        supported_types = [t.value for t in OptimizerType]
        raise MlflowException.invalid_parameter_value(
            f"Unsupported optimizer type: '{optimizer_type}'. Supported types: {supported_types}"
        )


def _load_scorers(scorer_names: list[str], experiment_id: str) -> list[Scorer]:
    """
    Load scorers by name.

    For each scorer name, first tries to load it as a built-in scorer (by class name),
    and if not found, falls back to loading from the registered scorer store.

    Args:
        scorer_names: List of scorer names. Can be built-in scorer class names
            (e.g., "Correctness", "Safety") or registered scorer names.
        experiment_id: The experiment ID to load registered scorers from.

    Returns:
        List of Scorer instances.

    Raises:
        MlflowException: If a scorer cannot be found as either built-in or registered.
    """

    scorers = []
    for name in scorer_names:
        scorer_class = getattr(builtin_scorers, name, None)
        if scorer_class is not None:
            try:
                scorer = scorer_class()
                scorers.append(scorer)
                continue
            except Exception as e:
                _logger.debug(f"Failed to instantiate built-in scorer {name}: {e}")

        # Load from the registered scorer store if not a built-in scorer
        try:
            scorer = get_scorer(name=name, experiment_id=experiment_id)
            scorers.append(scorer)
        except Exception as e:
            raise MlflowException.invalid_parameter_value(
                f"Scorer '{name}' not found. It is neither a built-in scorer "
                f"(e.g., 'Correctness', 'Safety') nor a registered scorer in "
                f"experiment '{experiment_id}'. Error: {e}"
            )

    return scorers


def _build_predict_fn(prompt_uri: str) -> Callable[..., Any]:
    """
    Build a predict function for single-prompt optimization.

    This creates a simple LLM call using the prompt's model configuration.
    The predict function loads the prompt, formats it with inputs, and
    calls the LLM via litellm.

    Args:
        prompt_uri: The URI of the prompt to use for prediction.

    Returns:
        A callable that takes an inputs dict and returns the LLM response.
    """
    try:
        import litellm
    except ImportError as e:
        raise MlflowException(
            "The 'litellm' package is required for prompt optimization but is not installed. "
            "Please install it using: pip install litellm"
        ) from e

    prompt = load_prompt(prompt_uri)
    try:
        model_config = prompt.model_config
        provider = model_config["provider"]
        model_name = model_config["model_name"]
    except (KeyError, TypeError, AttributeError) as e:
        raise MlflowException(
            f"Prompt {prompt_uri} doesn't have a model configuration that sets provider and "
            "model_name, which are required for optimization."
        ) from e

    litellm_model = f"{provider}/{model_name}"

    def predict_fn(**kwargs: Any) -> Any:
        response = litellm.completion(
            model=litellm_model,
            messages=[{"role": "user", "content": prompt.format(**kwargs)}],
        )
        return response.choices[0].message.content

    return predict_fn


@record_usage_event(OptimizePromptsJobEvent)
@job(name="optimize_prompts", max_workers=_DEFAULT_OPTIMIZATION_JOB_MAX_WORKERS)
def optimize_prompts_job(
    run_id: str,
    experiment_id: str,
    prompt_uri: str,
    dataset_id: str,
    optimizer_type: str,
    optimizer_config: dict[str, Any] | None,
    scorer_names: list[str],
) -> PromptOptimizationJobResult:
    """
    Job function for async single-prompt optimization.

    This function is executed as a background job by the MLflow server.
    It resumes an existing MLflow run (created by the handler) and calls
    `mlflow.genai.optimize_prompts()`, which reuses the active run.

    Note: This job only supports single-prompt optimization. The predict_fn
    is automatically built using the prompt's model_config (provider/model_name)
    via litellm, making the optimization self-contained without requiring users
    to serialize their own predict function.

    Args:
        run_id: The MLflow run ID to track the optimization configs and metrics.
        experiment_id: The experiment ID to track the optimization in.
        prompt_uri: The URI of the prompt to optimize.
        dataset_id: The ID of the EvaluationDataset containing training data.
        optimizer_type: The optimizer type string (e.g., "gepa", "metaprompt").
        optimizer_config: Optimizer-specific configuration dictionary.
        scorer_names: List of scorer names. Can be built-in scorer class names
            (e.g., "Correctness", "Safety") or registered scorer names.
            For custom scorers, use mlflow.genai.make_judge() to create a judge,
            then register it using scorer.register(experiment_id=experiment_id),
            and pass the registered scorer name here.

    Returns:
        PromptOptimizationJobResult containing optimization results and metadata.
    """
    set_experiment(experiment_id=experiment_id)

    dataset = get_dataset(dataset_id=dataset_id)
    predict_fn = _build_predict_fn(prompt_uri)
    optimizer = _create_optimizer(optimizer_type, optimizer_config)
    loaded_scorers = _load_scorers(scorer_names, experiment_id)
    source_prompt = load_prompt(prompt_uri)

    # Resume the given run ID. Params have already been logged by the handler.
    with start_run(run_id=run_id):
        # Link source prompt to run for lineage
        client = MlflowClient()
        client.link_prompt_version_to_run(run_id=run_id, prompt=source_prompt)
        result = optimize_prompts(
            predict_fn=predict_fn,
            train_data=dataset,
            prompt_uris=[prompt_uri],
            optimizer=optimizer,
            scorers=loaded_scorers,
            enable_tracking=True,
        )

    return PromptOptimizationJobResult(
        run_id=run_id,
        source_prompt_uri=prompt_uri,
        optimized_prompt_uri=result.optimized_prompts[0].uri if result.optimized_prompts else None,
        optimizer_name=result.optimizer_name,
        initial_eval_score=result.initial_eval_score,
        final_eval_score=result.final_eval_score,
        dataset_id=dataset_id,
        scorer_names=scorer_names,
    )
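
As a usage sketch for reviewers: the snippet below shows how the job's arguments fit together, using the optimizer_config keys read by _create_optimizer and built-in scorer names from the docstring. All IDs and URIs are hypothetical placeholders, and in practice the function is scheduled by the MLflow server's job runner (via the handler added in a follow-up PR) rather than called directly.

# Hypothetical sketch; placeholder IDs/URIs, not values from this PR.
from mlflow.genai.optimize.job import optimize_prompts_job

result = optimize_prompts_job(
    run_id="<run-id-created-by-the-handler>",
    experiment_id="<experiment-id>",
    prompt_uri="prompts:/qa_prompt/3",          # prompt must define provider/model_name in model_config
    dataset_id="<evaluation-dataset-id>",
    optimizer_type="gepa",
    optimizer_config={
        "reflection_model": "openai:/gpt-4o",   # required for both GEPA and MetaPrompt
        "max_metric_calls": 100,
        "display_progress_bar": False,
    },
    scorer_names=["Correctness", "Safety"],     # built-in scorer class names
)
print(result.optimized_prompt_uri, result.final_eval_score)

For a custom metric, the docstring's suggested flow is to build a judge with mlflow.genai.make_judge(), register it with scorer.register(experiment_id=...), and then pass the registered name in scorer_names.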

mlflow/server/jobs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
     "mlflow.genai.scorers.job.invoke_scorer_job",
     "mlflow.genai.scorers.job.run_online_trace_scorer_job",
     "mlflow.genai.scorers.job.run_online_session_scorer_job",
+    "mlflow.genai.optimize.job.optimize_prompts_job",
 ]

 if supported_job_function_list_env := os.environ.get("_MLFLOW_SUPPORTED_JOB_FUNCTION_LIST"):
@@ -32,6 +33,7 @@
     "invoke_scorer",
     "run_online_trace_scorer",
     "run_online_session_scorer",
+    "optimize_prompts",
 ]

 if allowed_job_name_list_env := os.environ.get("_MLFLOW_ALLOWED_JOB_NAME_LIST"):

mlflow/telemetry/events.py

Lines changed: 19 additions & 0 deletions
@@ -599,6 +599,25 @@ def parse_result(cls, result: Any) -> dict[str, Any] | None:
         }


+class OptimizePromptsJobEvent(Event):
+    name: str = "optimize_prompts_job"
+
+    @classmethod
+    def parse(cls, arguments: dict[str, Any]) -> dict[str, Any] | None:
+        result = {}
+
+        if optimizer_type := arguments.get("optimizer_type"):
+            result["optimizer_type"] = optimizer_type
+
+        if "scorer_names" in arguments:
+            scorer_names = arguments["scorer_names"]
+            # `scorer_count` is useful for indicating zero-shot vs few-shot optimization, and to
+            # track the pattern of how users use prompt optimization.
+            result["scorer_count"] = len(scorer_names)
+
+        return result or None
+
+
 class ScorerCallEvent(Event):
     name: str = "scorer_call"
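
A quick illustration of the telemetry payload the new event produces, mirroring the parse logic above (argument values are made up):

from mlflow.telemetry.events import OptimizePromptsJobEvent

# Only coarse, non-identifying fields are collected.
payload = OptimizePromptsJobEvent.parse(
    {"optimizer_type": "gepa", "scorer_names": ["Correctness", "Safety"]}
)
assert payload == {"optimizer_type": "gepa", "scorer_count": 2}

# With neither field present, parse() returns None and nothing is recorded.
assert OptimizePromptsJobEvent.parse({"run_id": "abc"}) is None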
