Skip to content

Commit 1fe8f48

Browse files
author
Siba Rajendran
committed
prompt tuning improvements
1 parent d2b92d9 commit 1fe8f48

File tree

4 files changed

+29
-24
lines changed

4 files changed

+29
-24
lines changed

src/fmcore/experimental/prompt_tuner/dspy/adapters/chat_adapter.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,18 @@
55
from typing import Any, Dict, Optional, Type
66

77
from dspy import ChatAdapter
8-
from dspy.adapters.chat_adapter import enumerate_fields, format_fields, get_dspy_field_type, prepare_schema, prepare_instructions
8+
from dspy.adapters.chat_adapter import (
9+
enumerate_fields,
10+
format_fields,
11+
get_dspy_field_type,
12+
prepare_schema,
13+
prepare_instructions,
14+
)
915
from dspy.adapters.chat_adapter import FieldInfoWithName, BuiltInCompletedOutputFieldInfo, FieldInfo
1016
from dspy.signatures.signature import Signature, SignatureMeta
1117
from dspy.utils.callback import BaseCallback
1218

19+
1320
def custom_prepare_instructions(signature: SignatureMeta):
1421
parts = []
1522

@@ -19,7 +26,9 @@ def custom_prepare_instructions(signature: SignatureMeta):
1926

2027
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
2128
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
22-
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
29+
parts.append(
30+
"All interactions will be structured in the following way, with the appropriate values filled in."
31+
)
2332

2433
def field_metadata(field_name, field_info):
2534
field_type = field_info.annotation
@@ -48,7 +57,9 @@ def field_metadata(field_name, field_info):
4857
def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
4958
return format_fields(
5059
fields_with_values={
51-
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
60+
FieldInfoWithName(name=field_name, info=field_info): field_metadata(
61+
field_name, field_info
62+
)
5263
for field_name, field_info in fields.items()
5364
},
5465
)

src/fmcore/experimental/prompt_tuner/dspy/optimizers/miprov2_optimizer.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -46,47 +46,42 @@ def optimize(
4646
"""
4747
# Initialize MIPROv2 optimizer with filtered constructor params
4848
constructor_params = IntrospectionUtils.filter_params(
49-
func=MIPROv2,
50-
params=optimizer_params or {}
49+
func=MIPROv2, params=optimizer_params or {}
5150
)
5251
optimizer = MIPROv2(
5352
metric=self.evaluate,
5453
prompt_model=self.teacher,
5554
task_model=self.student,
56-
**constructor_params
55+
**constructor_params,
5756
)
5857

5958
# Run optimization with filtered compile params
6059
compile_params = IntrospectionUtils.filter_params(
61-
func=MIPROv2.compile,
62-
params=optimizer_params or {}
60+
func=MIPROv2.compile, params=optimizer_params or {}
6361
)
6462
optimized_program = optimizer.compile(
6563
student=self.module,
6664
trainset=dataset.train,
6765
valset=dataset.dev,
6866
requires_permission_to_run=False,
69-
**compile_params
67+
**compile_params,
7068
)
7169

7270
optimized_prompts = []
73-
71+
7472
# MIPROv2 returns a list of candidate programs, each containing:
7573
# - A "program" object with the optimized DSPy module in predictor.predict
7674
# - A "score" indicating how well that program performed
7775
for candidate in optimized_program.candidate_programs:
7876
# Extract the optimized DSPy module from the candidate program
7977
optimized_module = candidate["program"].predictor.predict
8078
prompt_template = DSPyUtils.convert_module_to_prompt(module=optimized_module)
81-
82-
optimized_prompt = OptimizedPrompt(
83-
template=prompt_template,
84-
score=candidate["score"]
85-
)
86-
79+
80+
optimized_prompt = OptimizedPrompt(template=prompt_template, score=candidate["score"])
81+
8782
optimized_prompts.append(optimized_prompt)
8883

8984
# Sort prompts by score in descending order
9085
optimized_prompts.sort(key=lambda x: x.score, reverse=True)
91-
86+
9287
return PromptTunerResult(prompts=optimized_prompts)

src/fmcore/experimental/prompt_tuner/dspy/utils/commons.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -232,11 +232,11 @@ def evaluate_func(
232232

233233
# Evaluate the criteria expression to get the final score
234234
expression_evaluator = Interpreter()
235-
expression_evaluator .symtable.update(evaluation_response)
235+
expression_evaluator.symtable.update(evaluation_response)
236236
return expression_evaluator(criteria)
237237

238238
return evaluate_func
239-
239+
240240
@staticmethod
241241
def convert_module_to_messages(module: dspy.Module) -> List[Dict[str, str]]:
242242
"""
@@ -259,12 +259,11 @@ def convert_module_to_messages(module: dspy.Module) -> List[Dict[str, str]]:
259259
# Get input field names from signature and create template variables
260260
signature: dspy.Signature = module.signature
261261
inputs = {
262-
field_name: f"{{{{{field_name}}}}}"
263-
for field_name in signature.input_fields.keys()
262+
field_name: f"{{{{{field_name}}}}}" for field_name in signature.input_fields.keys()
264263
}
265264

266265
# Format the module into chat messages using the adapter
267-
messages = adapter.format(signature=signature,demos=module.demos,inputs=inputs)
266+
messages = adapter.format(signature=signature, demos=module.demos, inputs=inputs)
268267

269268
return messages
270269

src/fmcore/experimental/prompt_tuner/dspy_prompt_tuner.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121

2222
import dspy.adapters.chat_adapter as chat_adapter_module
2323
from fmcore.experimental.prompt_tuner.dspy.adapters.chat_adapter import custom_prepare_instructions
24+
2425
chat_adapter_module.prepare_instructions = custom_prepare_instructions
2526

27+
2628
class DSPyPromptTuner(BasePromptTuner):
2729
"""
2830
A prompt tuner implementation using the DSPy framework.
@@ -94,8 +96,6 @@ def tune(self, data: DataFrame) -> str:
9496
Raises:
9597
ValueError: If the optimization process fails or returns invalid results
9698
"""
97-
98-
9999

100100
# Convert data to DSPy examples
101101
dataset: DspyDataset = DspyDataset(data=data, prompt_config=self.config.prompt_config)

0 commit comments

Comments
 (0)