Commit 7215a7c

Author: Siba Rajendran (committed)

dspy changes
1 parent 6853f8d commit 7215a7c

10 files changed (+138 -76 lines)
+1 -1

@@ -1,3 +1,3 @@
 from fmcore.experimental.metrics.base_metric import BaseMetric
 from fmcore.experimental.metrics.deepeval_geval import DeepEvalGEvalMetric
-from fmcore.experimental.metrics.custom_metric import CustomMetric
+from fmcore.experimental.metrics.custom_metric import CustomMetric

src/fmcore/experimental/metrics/custom_metric.py  +10 -5

@@ -1,21 +1,19 @@
 import json_repair
 from typing import Dict, List
-from langchain_core.messages import BaseMessage,HumanMessage
+from langchain_core.messages import BaseMessage, HumanMessage
 from jinja2 import Template
 from fmcore.experimental.llm.base_llm import BaseLLM
 from fmcore.experimental.metrics.base_metric import BaseMetric
 from fmcore.experimental.types.enums.metric_enums import SupportedMetrics
 from fmcore.experimental.types.metric_types import CustomMetricResult, MetricConfig, MetricResult
 
 
-
 class CustomMetric(BaseMetric):
     aliases = [SupportedMetrics.CUSTOM]
 
     metric_name: str
     llm: BaseLLM
     prompt_template: Template
-
 
     @classmethod
     def _get_constructor_parameters(cls, *, metric_config: MetricConfig) -> Dict:
@@ -25,8 +23,13 @@ def _get_constructor_parameters(cls, *, metric_config: MetricConfig) -> Dict:
         prompt_template = Template(prompt)
 
         metric_name: str = metric_config.metric_params["name"]
-
-        return {"config": metric_config, "prompt_template": prompt_template, "llm": llm, "metric_name": metric_name}
+
+        return {
+            "config": metric_config,
+            "prompt_template": prompt_template,
+            "llm": llm,
+            "metric_name": metric_name,
+        }
 
     def evaluate(self, data: Dict) -> MetricResult:
         prompt: str = self.prompt_template.render(**data)
@@ -43,3 +46,5 @@ async def aevaluate(self, data: Dict) -> MetricResult:
         result: Dict = json_repair.loads(response.content)
 
         return CustomMetricResult(**result)
+
+
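The CustomMetric flow kept by this change is: render the Jinja2 prompt template with the row's fields, send the prompt to the configured LLM, then parse the (possibly malformed) JSON reply with json_repair. Below is a minimal, self-contained sketch of that flow, not the committed code; call_llm is a hypothetical stand-in for the real BaseLLM (whose interface is not shown in this diff), and the prompt and response shapes are invented for illustration.

# Sketch of the render -> call LLM -> repair JSON pattern used by CustomMetric.evaluate.
import json_repair
from jinja2 import Template

PROMPT = Template("Rate the answer '{{ answer }}' to the question '{{ question }}'. Reply as JSON.")

def call_llm(prompt: str) -> str:
    # Placeholder only: a real implementation would invoke the configured BaseLLM.
    return '{"score": 0.9, "reason": "relevant and correct"}'

def evaluate(data: dict) -> dict:
    prompt = PROMPT.render(**data)            # fill the template with the row's fields
    response_text = call_llm(prompt)          # get the model's raw text reply
    return json_repair.loads(response_text)   # tolerate slightly malformed JSON

print(evaluate({"question": "What is 2 + 2?", "answer": "4"}))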

src/fmcore/experimental/metrics/deepeval_geval.py  +7 -3

@@ -7,7 +7,11 @@
 from fmcore.experimental.adapters.deepeval_adapter import DeepEvalLLMAdapter
 from fmcore.experimental.llm.base_llm import BaseLLM
 from fmcore.experimental.metrics.base_metric import BaseMetric
-from fmcore.experimental.types.enums.metric_enums import MetricFramework, SupportedMetrics, EvaluationFieldType
+from fmcore.experimental.types.enums.metric_enums import (
+    MetricFramework,
+    SupportedMetrics,
+    EvaluationFieldType,
+)
 from fmcore.experimental.types.llm_types import LLMConfig
 from fmcore.experimental.types.metric_types import (
     MetricConfig,
@@ -36,12 +40,12 @@ def _get_constructor_parameters(cls, *, metric_config: MetricConfig) -> Dict:
         geval_metric_params["evaluation_params"] = DeepEvalUtils.infer_evaluation_params(
             field_mapping=metric_config.field_mapping
         )
-
+
         if not metric_config.framework:
             metric_config.framework = MetricFramework.DEEPEVAL
 
         geval_metric_params["model"] = model
-
+
         return {"config": metric_config, "geval_metric_params": geval_metric_params}
 
     def evaluate(self, data: Dict) -> MetricResult:
@@ -1,13 +1,13 @@
 import dspy
+import mlflow.dspy
 from pandas import DataFrame
-from typing import Callable, Dict, Optional, List, Tuple, Type
+from typing import Callable, Dict, Optional, List, Tuple, Type, Any, Union
 
 from dspy.teleprompt import Teleprompter
 from dspy.teleprompt.mipro_optimizer_v2 import MIPROv2
 from dspy.teleprompt.bootstrap import BootstrapFewShot
 from dspy import Signature, Module
 
-
 from fmcore.experimental.metrics.base_metric import BaseMetric
 from fmcore.experimental.prompt_tuner.base_prompt_tuner import BasePromptTuner
 from fmcore.experimental.types.enums.prompt_tuner_enums import PromptTunerFramework
@@ -16,27 +16,52 @@
 from fmcore.experimental.adapters.dspy_adapter import DSPyLLMAdapter
 from fmcore.experimental.utils.introspection_utils import IntrospectionUtils
 from fmcore.experimental.prompt_tuner.utils.dspy_utils import DSPyUtils
-from py_expression_eval import Parser
 from asteval import Interpreter
 
 
-
-
 class DSPyPromptTuner(BasePromptTuner):
+    """
+    A prompt tuner implementation using the DSPy framework.
+
+    This class provides functionality to optimize prompts using various DSPy optimizers
+    such as MIPROv2 or BootstrapFewShot. It uses a student model for generating responses
+    and evaluates them using a configured metric to iteratively improve the prompt.
+
+    Attributes:
+        aliases (List[PromptTunerFramework]): Framework identifiers for this tuner.
+        student (dspy.LM): The student language model used for prompt optimization.
+        teacher (Optional[dspy.LM]): The teacher language model used in some optimization techniques.
+        optimizer_metric (BaseMetric): The metric used to evaluate prompt performance.
+    """
+
     aliases = [PromptTunerFramework.DSPY]
     student: dspy.LM
     teacher: Optional[dspy.LM]
     optimizer_metric: BaseMetric
 
     @classmethod
-    def _get_constructor_parameters(cls, *, config: PromptTunerConfig) -> Dict:
+    def _get_constructor_parameters(cls, *, config: PromptTunerConfig) -> Dict[str, Any]:
+        """
+        Creates and configures the necessary components for DSPy prompt tuning.
+
+        Args:
+            config: Configuration containing all necessary parameters for the prompt tuner.
+                Must include student model config and optionally teacher model config.
+
+        Returns:
+            Dictionary of parameters needed to initialize the DSPyPromptTuner instance.
+        """
+        # Initialize student model and configure DSPy to use it
         student_model = DSPyLLMAdapter(llm_config=config.optimzer_config.student_config)
         dspy.configure(lm=student_model)
 
+        # Initialize teacher model (or use student if not specified)
         if config.optimzer_config.teacher_config:
             teacher_model = DSPyLLMAdapter(llm_config=config.optimzer_config.teacher_config)
         else:
             teacher_model = student_model
+
+        # Initialize metric for optimization
         optimizer_metric = BaseMetric.of(metric_config=config.optimzer_config.metric_config)
 
         return {
@@ -46,19 +71,33 @@ def _get_constructor_parameters(cls, *, config: PromptTunerConfig) -> Dict:
             "config": config,
         }
 
-    def _create_evaluation_function(self):
+    def _create_evaluation_function(self) -> Callable:
         """
         Creates an evaluation function that uses the configured metric.
-
+
+        The function evaluates DSPy predictions by applying the metric and interpreting
+        the criteria expression to determine the quality of the prediction.
+
         Returns:
-            Evaluation function that takes an example and prediction
+            A callable function that takes an example and prediction and returns a
+            numerical or boolean evaluation score.
         """
-
         # Store criteria once to avoid re-fetching it in each evaluation call
         criteria = self.optimizer_metric.config.metric_params["criteria"]
 
-        def evaluate_func(example: dspy.Example, prediction: dspy.Prediction, trace=None):
-            # Get evaluation results
+        def evaluate_func(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> Union[float, bool]:
+            """
+            Evaluates a single example-prediction pair using the configured metric.
+
+            Args:
+                example: The DSPy example containing input data
+                prediction: The model's prediction to evaluate
+                trace: Optional trace information from DSPy (not used)
+
+            Returns:
+                Evaluation score as determined by the configured criteria
+            """
+            # Get evaluation results from the metric
             evaluation_response: dict = DSPyUtils.evaluate(
                 example=example,
                 prediction=prediction,
@@ -72,23 +111,33 @@ def evaluate_func(example: dspy.Example, prediction: dspy.Prediction, trace=None
 
         return evaluate_func
 
-
-
     def tune(self, data: DataFrame) -> str:
         """
-        Tunes a prompt using the configured DSPy optimizer.
-
+        Tunes a prompt using the configured DSPy optimizer and training data.
+
+        This method:
+        1. Converts the input data to DSPy examples
+        2. Creates a DSPy signature and module based on the prompt configuration
+        3. Configures an evaluation function using the specified metric
+        4. Applies the DSPy optimizer to generate an optimized prompt
+
         Args:
-            data: DataFrame containing the training data
-            prompt_config: Configuration containing input and output fields
-
+            data: DataFrame containing the training data with input and expected output fields
+
         Returns:
             The optimized prompt as a string
+
+        Raises:
+            ValueError: If the optimization process fails or returns invalid results
         """
 
+        import mlflow
+        mlflow.dspy.autolog(log_traces=True, log_traces_from_compile=True, log_traces_from_eval=True, disable=False, silent=False)
+
         # Convert data to DSPy examples
         dspy_examples = DSPyUtils.convert_to_dspy_examples(
-            data=data, prompt_config=self.config.prompt_config
+            data=data,
+            prompt_config=self.config.prompt_config
         )
 
         # Create signature and module separately
@@ -108,13 +157,20 @@ def tune(self, data: DataFrame) -> str:
             evaluate_func=evaluate_func,
         )
 
+        # Filter optimizer parameters to only include those accepted by the compile method
        filtered_optimizer_params = IntrospectionUtils.filter_params(
-            func=optimizer.compile, params=self.config.optimzer_config.params
+            func=optimizer.compile,
+            params=self.config.optimzer_config.params
        )
+
         # Compile the module with the optimizer
         optimized_module = optimizer.compile(
             student=module,
             trainset=dspy_examples,
             requires_permission_to_run=False,
             **filtered_optimizer_params,
         )
+
+        dspy.inspect_history(optimized_module)
+
+        optimized_module.signature.prompt
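The evaluate_func defined in this file follows the metric contract DSPy optimizers expect: a callable taking (example, prediction, trace=None) and returning a score that compile() can use. Below is a minimal sketch of that contract with plain dspy objects; the keyword check and field names are invented for illustration and are not fmcore's criteria logic. Note also that the final added line of tune(), optimized_module.signature.prompt, is evaluated but never returned, although the docstring promises the optimized prompt as a string, so a return statement appears to be missing.

# Sketch of the metric callable shape that BootstrapFewShot / MIPROv2 consume.
import dspy

def keyword_metric(example: dspy.Example, prediction: dspy.Prediction, trace=None) -> bool:
    # Toy criterion: the predicted answer must contain the expected answer.
    return example.answer.lower() in prediction.answer.lower()

example = dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question")
prediction = dspy.Prediction(answer="The answer is 4.")
print(keyword_metric(example, prediction))  # True

# With an LM configured, the same callable plugs straight into an optimizer:
# optimizer = dspy.teleprompt.BootstrapFewShot(metric=keyword_metric)
# optimized = optimizer.compile(student=dspy.Predict("question -> answer"), trainset=[example])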

src/fmcore/experimental/prompt_tuner/utils/dspy_utils.py  +22 -21

@@ -20,14 +20,17 @@ def get_optimizer(
         optimzer_config: OptimizerConfig,
         evaluate_func: Callable,
     ) -> Teleprompter:
-        if optimzer_config.type == DspyOptimizerType.MIPRO_V2:
-            return MIPROv2(
+
+        # This will be a factory method that returns the appropriate optimizer based on the optimizer type
+        # TODO: Add more optimizers
+        optimizer: Teleprompter = MIPROv2(
             prompt_model=teacher,
             task_model=student,
             metric=evaluate_func,
             **optimzer_config.params,
         )
-        return BootstrapFewShot(**optimzer_config.params, metric=evaluate_func)
+
+        return optimizer
 
     @staticmethod
     def create_dspy_signature(prompt_config: PromptConfig) -> Type[dspy.Signature]:
@@ -43,24 +46,24 @@ def create_dspy_signature(prompt_config: PromptConfig) -> Type[dspy.Signature]:
 
         # Create a DSPy Signature class dictionary with annotations
         attrs = {
-            '__annotations__': {},
-            '__doc__': prompt_config.prompt if prompt_config.prompt else ""
+            "__annotations__": {},
+            "__doc__": prompt_config.prompt if prompt_config.prompt else "",
         }
 
         # Dynamically add input and output fields with type annotations
         for field in prompt_config.input_fields:
             # Assume field has a type attribute, otherwise default to str
-            field_type = getattr(field, 'type', str)
-            attrs['__annotations__'][field.name] = field_type
+            field_type = getattr(field, "type", str)
+            attrs["__annotations__"][field.name] = field_type
             attrs[field.name] = dspy.InputField(desc=field.description)
-
+
         for field in prompt_config.output_fields:
-            field_type = getattr(field, 'type', str)
-            attrs['__annotations__'][field.name] = field_type
+            field_type = getattr(field, "type", str)
+            attrs["__annotations__"][field.name] = field_type
             attrs[field.name] = dspy.OutputField(desc=field.description)
 
         # Create the class dynamically with type annotations
-        TaskSignature = type('TaskSignature', (dspy.Signature,), attrs)
+        TaskSignature = type("TaskSignature", (dspy.Signature,), attrs)
 
         return TaskSignature
 
@@ -85,7 +88,7 @@ def __init__(self, signature: dspy.Signature):
             def forward(self, **kwargs):
                 prediction = self.predictor(**kwargs)
                 return prediction
-
+
         return TaskModule(signature=signature)
 
     @staticmethod
@@ -106,20 +109,18 @@ def convert_to_dspy_examples(
 
         loader = DataLoader()
         input_keys = [field.name for field in prompt_config.input_fields]
-        examples = loader.from_pandas(
-            data,
-            fields=data.columns.tolist(),
-            input_keys=input_keys
-        )
-
+        examples = loader.from_pandas(data, fields=data.columns.tolist(), input_keys=input_keys)
+
         return examples
 
     @staticmethod
-    def evaluate(example: dspy.Example, prediction: dspy.Prediction, metric: BaseMetric) -> MetricResult:
+    def evaluate(
+        example: dspy.Example, prediction: dspy.Prediction, metric: BaseMetric
+    ) -> MetricResult:
         row = {
             EvaluationFieldType.INPUT.name: example.toDict(),
             EvaluationFieldType.OUTPUT.name: prediction.toDict(),
         }
-
+
         metric_result: MetricResult = metric.evaluate(data=row)
-        return metric_result.model_dump(exclude_none=True)
+        return metric_result.model_dump(exclude_none=True)
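Two things worth noting in this file: after the change, get_optimizer always constructs MIPROv2 (the BootstrapFewShot branch was dropped, with a TODO to add more optimizers), and create_dspy_signature builds the Signature class at runtime via type(), filling __annotations__ and attaching dspy.InputField/OutputField attributes. A self-contained sketch of that dynamic-signature pattern follows; the field names and docstring are made up for illustration and are not taken from the diff.

# Sketch of building a dspy.Signature subclass dynamically, as create_dspy_signature does.
import dspy

attrs = {
    "__annotations__": {"question": str, "answer": str},   # type annotations per field
    "__doc__": "Answer the question concisely.",           # becomes the signature's instructions
    "question": dspy.InputField(desc="User question"),
    "answer": dspy.OutputField(desc="Short answer"),
}

TaskSignature = type("TaskSignature", (dspy.Signature,), attrs)

print(TaskSignature.instructions)         # "Answer the question concisely."
print(list(TaskSignature.input_fields))   # ["question"]
print(list(TaskSignature.output_fields))  # ["answer"]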

0 commit comments
