SIMBA Improvements #8077
Changes to the dspy.LM client:
@@ -35,8 +35,8 @@ def __init__(
        self,
        model: str,
        model_type: Literal["chat", "text"] = "chat",
        temperature: float = 0.0,
        max_tokens: int = 1000,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
Review comment: Updating the defaults, which were throwing errors for reasoning models. Instead of defaulting to temperature=0.0 and max_tokens=1000 (and automatically erroring out for o3-mini), temperature and max_tokens are now set based on whether the model is a reasoning model. If the user has intentionally set one of the values to something the reasoning model can't handle (e.g., temperature=0.7), an error is still raised.
        cache: bool = True,
        cache_in_memory: bool = True,
        callbacks: Optional[List[BaseCallback]] = None,
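For context, a rough usage sketch of the new defaults (the model names below are illustrative, not part of the diff):

import dspy

# Non-reasoning model: falls back to the previous defaults (temperature=0.0, max_tokens=1000).
lm = dspy.LM("openai/gpt-4o-mini")

# Reasoning model: defaults become temperature=1.0 and max_tokens=20_000,
# so constructing the LM without arguments no longer errors out.
reasoning_lm = dspy.LM("openai/o3-mini")

# Explicitly passing a value the reasoning model can't handle still fails fast.
try:
    dspy.LM("openai/o3-mini", temperature=0.7)
except AssertionError as err:
    print(err)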
@@ -81,17 +81,24 @@ def __init__(
        self.launch_kwargs = launch_kwargs or {}
        self.train_kwargs = train_kwargs or {}

        # Handle model-specific configuration for different model families
        # Identify reasoning models (e.g., o1, o3 variants)
        model_family = model.split("/")[-1].lower() if "/" in model else model.lower()

        # Match pattern: o[1,3] at the start, optionally followed by -mini and anything else
        model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)

        if model_pattern:
        # Handle OpenAI reasoning models (o1, o3)
        is_reasoning_model = bool(re.match(r"^o([13])(?:-mini)?", model_family))

        # Set defaults
        if temperature is None:
            temperature = 1.0 if is_reasoning_model else 0.0
        if max_tokens is None:
            max_tokens = 20_000 if is_reasoning_model else 1_000

        # Check to make sure temperature and max_tokens work for the model
        if is_reasoning_model:
            assert (
                max_tokens >= 20_000 and temperature == 1.0
            ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"

        # Set kwargs based on reasoning model check
        if is_reasoning_model:
            self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
        else:
            self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
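As a quick illustration of the detection pattern used above (the model-name strings are examples only):

import re

pattern = r"^o([13])(?:-mini)?"
for name in ["o1", "o3-mini", "o1-preview", "gpt-4o-mini", "o4-mini"]:
    print(name, bool(re.match(pattern, name)))
# o1, o3-mini, and o1-preview match; gpt-4o-mini and o4-mini do not.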
Changes to the SIMBA optimizer (in dspy.teleprompt):
@@ -6,6 +6,7 @@
from typing import Callable
from dspy.teleprompt.teleprompt import Teleprompter
from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule
from typing import Optional, Any, Dict

logger = logging.getLogger(__name__)
@@ -20,6 +21,8 @@ def __init__(
        num_candidates=6,
        max_steps=8,
        max_demos=4,
        prompt_model: Optional[Any] = None,
        teacher_settings: Optional[Dict] = None,
        demo_input_field_maxlen=100_000,
        num_threads=None,
        temperature_for_sampling=0.2,
@@ -41,6 +44,8 @@ def __init__(
        self.num_candidates = num_candidates
        self.max_steps = max_steps
        self.max_demos = max_demos
Review comment: Adding support for prompt and teacher models.
        self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
        self.teacher_settings = teacher_settings
        self.demo_input_field_maxlen = demo_input_field_maxlen
        self.num_threads = num_threads
@@ -137,7 +142,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):

            # We'll generate (program, model) pairs for the trajectory sampling.
            # Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0].
            models = prepare_models_for_resampling(programs[0], self.num_candidates)
            models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings)
            top_programs = top_k_plus_baseline(self.num_candidates)

            exec_pairs = []
@@ -240,6 +245,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
                        name2predictor=name2predictor,
                        batch_10p_score=batch_10th_percentile_score,
                        batch_90p_score=batch_90th_percentile_score,
                        prompt_model=self.prompt_model,
                    )
                except Exception as e:
                    logger.error(f"Strategy failed with error: {e}")
@@ -310,7 +316,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
            trial_logs[idx_prog-1]["train_score"] = avg_score

        best_idx = scores.index(max(scores)) if scores else 0
        best_program = candidate_programs[best_idx]
Review comment: Fixing a max recursion depth error.
        best_program = candidate_programs[best_idx].deepcopy()
        logger.info(
            f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} "
            f"(at index {best_idx if scores else 'N/A'})\n\n\n"
Changes to the SIMBA utilities (dspy.teleprompt.simba_utils):
@@ -3,20 +3,37 @@
import inspect
import logging
import textwrap
import re

from dspy.adapters.utils import get_field_description_string
from dspy.signatures import InputField, OutputField
from typing import Callable
from typing import Callable, Optional, Dict

logger = logging.getLogger(__name__)

def prepare_models_for_resampling(program: dspy.Module, n: int):
def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None):
    lm = program.get_lm() or dspy.settings.lm
    temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)]
    temps = list(dict.fromkeys(temps))[:n]
    return [lm.copy(temperature=t) for t in temps]

    # Check to see if our model is a reasoning model, which means temp must stay as 1.0
    model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower()
    model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)
Review comment: This function has been updated to optionally include a teacher LM built from teacher_settings, to vary the seed (rather than the temperature, which must stay at 1.0) for reasoning models, and to vary the temperature as before for all other models.
    models = []
    if teacher_settings:
        models.append(dspy.LM(**teacher_settings))

    if model_pattern:  # Vary the seed
        start_seed = 0 if "seed" not in lm.kwargs else lm.kwargs["seed"]
        seeds = [start_seed + 1 + i for i in range(n-len(models))]
        seeds = list(dict.fromkeys(seeds))[:(n-len(models))]
        models.extend([lm.copy(seed=seed) for seed in seeds])
    else:  # Vary the temperature
        start_temp = 0 if "temperature" not in lm.kwargs else lm.kwargs["temperature"]
        temps = [start_temp + 0.5 + i * (0.5 / n) for i in range(n-len(models))]
        temps = list(dict.fromkeys(temps))[:(n-len(models))]
        models.extend([lm.copy(temperature=t) for t in temps])

    return models


def wrap_program(program: dspy.Module, metric: Callable):
    def wrapped_program(example):
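A rough sketch of what the updated resampling yields under different conditions (the model names are assumptions; the seed/temperature behavior follows from the logic above):

import dspy
from dspy.teleprompt.simba_utils import prepare_models_for_resampling

program = dspy.Predict("question -> answer")
program.set_lm(dspy.LM("openai/o3-mini"))

# Reasoning model: temperature stays at 1.0 and the candidates differ only by seed.
models = prepare_models_for_resampling(program, n=4)
print([m.kwargs.get("seed") for m in models])  # e.g. [1, 2, 3, 4] when no seed was set

# With teacher_settings, the teacher LM occupies the first slot and only the
# remaining n - 1 copies come from the seed (or temperature) sweep.
models = prepare_models_for_resampling(program, n=4, teacher_settings={"model": "openai/gpt-4o"})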
@@ -28,30 +45,48 @@ def wrapped_program(example):
            print(e)
        trace = dspy.settings.trace.copy()
Review comment: Updated to handle additional metric metadata in addition to the score. To do this, we check whether the metric's output is a float or int (in which case it is used directly as the score) or a dspy.Prediction object (in which case its score attribute is used and the remaining fields are kept as metadata).
        output = None
        score = 0.0
        output_metadata = {}

        try:
            score = metric(example, prediction)
            output = metric(example, prediction)
            if isinstance(output, (int, float)):
                score = output
            elif isinstance(output, dspy.Prediction):
                if not hasattr(output, 'score'):
                    raise ValueError("dspy.Prediction must contain a 'score' attribute")
                score = output.score
                # Just extract fields from _store, excluding 'score'
                output_metadata = {
                    k: v for k, v in output._store.items() if k != "score"
                }
        except Exception as e:
            print(e)

        # Include the `example` in the output for subsequent usage in buckets/strategies.
        return {
            "prediction": prediction,
            "trace": trace,
            "score": score,
            "example": example
            "example": example,
            "output_metadata": output_metadata
        }

    return wrapped_program
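To illustrate the metric handling described in the review comment above, a minimal sketch (the field names in the rich metric are made up):

import dspy

# A metric may still return a plain number, which is used directly as the score...
def simple_metric(example, prediction):
    return float(example.answer == prediction.answer)

# ...or a dspy.Prediction with a required `score` plus arbitrary extra fields,
# which wrap_program now surfaces as `output_metadata` for later feedback steps.
def rich_metric(example, prediction):
    return dspy.Prediction(
        score=float(example.answer == prediction.answer),
        feedback="Answer missed the unit conversion.",  # hypothetical metadata field
    )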

def append_a_demo(demo_input_field_maxlen):
    def append_a_demo_(bucket, system, **kwargs):
        predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
        batch_10p_score = kwargs["batch_10p_score"]

        trace = bucket[0]["trace"]
        good = bucket[0]
        trace = good["trace"]
        name2demo = {}
Review comment: Double-checking that the demo we're appending is not below the 10th percentile of scores.
        if good["score"] < batch_10p_score:
            logger.info(f"Skipping appending a demo as good score {good['score']} is below the 10th percentile.")
            return False

        for step in trace:
            predictor, _inputs, _outputs = step
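A small numeric illustration of the guard above (the values are made up):

# Suppose the best ("good") trace in this bucket scored 0.2, while the
# 10th-percentile score across the whole batch is 0.35. The demo is skipped:
good = {"score": 0.2}
batch_10p_score = 0.35
should_skip = good["score"] < batch_10p_score  # True -> append_a_demo_ returns False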
@@ -62,7 +97,6 @@ def append_a_demo_(bucket, system, **kwargs):
            demo = dspy.Example(augmented=True, **_inputs, **_outputs)
            name = predictor2name[id(predictor)]
            name2demo[name] = demo  # keep the last demo for each predictor

        for name, demo in name2demo.items():
            predictor = name2predictor[name]
            predictor.demos.append(demo)
@@ -76,6 +110,7 @@ def append_a_demo_(bucket, system, **kwargs):
def append_a_rule(bucket, system, **kwargs):
    predictor2name = kwargs["predictor2name"]
    batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
    prompt_model = kwargs["prompt_model"]

    module_names = [name for name, _ in system.named_predictors()]
    good, bad = bucket[0], bucket[-1]
@@ -116,12 +151,16 @@ def append_a_rule(bucket, system, **kwargs):
        worse_program_outputs=dict(bad["prediction"] or {}),
        worse_reward_value=bad["score"],
        better_reward_value=good["score"],
Review comment: Again, passing the additional metric metadata through here.
        worse_reward_info=bad["output_metadata"],
        better_reward_info=good["output_metadata"],
        module_names=module_names,
    )

    kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
              for k, v in kwargs.items()}
    advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice

    advice_program = dspy.Predict(OfferFeedback)
    advice = advice_program(**kwargs).module_advice

    for name, predictor in system.named_predictors():
        if name in advice:
@@ -155,11 +194,13 @@ class OfferFeedback(dspy.Signature):
    )
    worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
    worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
    worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
    better_program_trajectory: str = InputField(
        desc="The trajectory of the program's execution, showing each module's I/O"
    )
    better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
    better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
    better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
    module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
    discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
    module_advice: dict[str, str] = OutputField(
@@ -169,7 +210,6 @@ class OfferFeedback(dspy.Signature):
        "like the successful trajectory rather than the lower-scoring trajectory."
    )


def inspect_modules(program):
    separator = "-" * 80
    output = [separator]
Review comment: This is to allow support for Optional fields (i.e., a field that could be None or str), which is the case for KIE and was throwing errors before.