SIMBA Improvements #8077

Open · wants to merge 18 commits into main
14 changes: 13 additions & 1 deletion dspy/adapters/utils.py
@@ -3,7 +3,7 @@
import inspect
import json
from collections.abc import Mapping
from typing import Any, List, Literal, Union, get_args, get_origin
from typing import Any, List, Literal, get_args, get_origin, Union

import json_repair
import pydantic
@@ -12,6 +12,7 @@

from dspy.signatures.utils import get_dspy_field_type

NoneType = type(None)

def serialize_for_json(value: Any) -> Any:
"""
@@ -130,8 +131,19 @@ def find_enum_member(enum, identifier):

raise ValueError(f"{identifier} is not a valid name or value for the enum {enum.__name__}")

def _strip_optional(ann):
"""If ann is Union[..., NoneType] return the non‑None part, else ann."""
if get_origin(ann) is Union and NoneType in get_args(ann):
# keep the first non‑None member (there will be only one in Optional[T])
return next(a for a in get_args(ann) if a is not NoneType)
return ann

def parse_value(value, annotation):
annotation = _strip_optional(annotation)
Author comment: This adds support for Optional fields (i.e. fields whose value may be None or a str), which is the case for KIE and previously raised errors. A short sketch follows at the end of this hunk.


if value is None:
return None

if annotation is str:
return str(value)
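For illustration (not part of the diff), a minimal sketch of what the Optional handling above enables; the values are hypothetical:

from typing import Optional
from dspy.adapters.utils import _strip_optional, parse_value

assert _strip_optional(Optional[int]) is int  # Union[int, None] -> int
assert _strip_optional(str) is str            # non-Optional annotations pass through unchanged

parse_value(None, Optional[str])     # -> None
parse_value("hello", Optional[str])  # -> "hello" (previously Optional annotations could raise, per the comment above)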

25 changes: 16 additions & 9 deletions dspy/clients/lm.py
@@ -35,8 +35,8 @@ def __init__(
self,
model: str,
model_type: Literal["chat", "text"] = "chat",
temperature: float = 0.0,
max_tokens: int = 1000,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
Author comment: Updates the defaults, which previously raised errors for reasoning models. Instead of always defaulting to temperature=0.0 and max_tokens=1000 (which failed immediately for o3-mini), temperature and max_tokens are now set based on whether the model is a reasoning model. If the user explicitly sets a value the reasoning model cannot handle (e.g. temperature=0.7), an error is still raised. A usage sketch follows at the end of this file's diff.

cache: bool = True,
cache_in_memory: bool = True,
callbacks: Optional[List[BaseCallback]] = None,
@@ -81,17 +81,24 @@ def __init__(
self.launch_kwargs = launch_kwargs or {}
self.train_kwargs = train_kwargs or {}

# Handle model-specific configuration for different model families
# Identify reasoning models (e.g., o1, o3 variants)
model_family = model.split("/")[-1].lower() if "/" in model else model.lower()

# Match pattern: o[1,3] at the start, optionally followed by -mini and anything else
model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)

if model_pattern:
# Handle OpenAI reasoning models (o1, o3)
is_reasoning_model = bool(re.match(r"^o([13])(?:-mini)?", model_family))

# Set defaults
if temperature is None:
temperature = 1.0 if is_reasoning_model else 0.0
if max_tokens is None:
max_tokens = 20_000 if is_reasoning_model else 1_000

# Check to make sure temperature and max_tokens work for the model
if is_reasoning_model:
assert (
max_tokens >= 20_000 and temperature == 1.0
), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 20_000 to `dspy.LM(...)`"

# Set kwargs based on reasoning model check
if is_reasoning_model:
self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
else:
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
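For illustration (not part of the diff), a sketch of how the new defaults behave; the model names are just examples:

import dspy

# Reasoning model: defaults now resolve to temperature=1.0 and max_tokens=20_000.
lm_reasoning = dspy.LM("openai/o3-mini")

# Non-reasoning model: defaults remain temperature=0.0 and max_tokens=1_000.
lm_chat = dspy.LM("openai/gpt-4o-mini")

# An explicitly passed value the reasoning model cannot handle still fails fast:
# dspy.LM("openai/o3-mini", temperature=0.7)  # raises AssertionError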
10 changes: 8 additions & 2 deletions dspy/teleprompt/simba.py
@@ -6,6 +6,7 @@
from typing import Callable
from dspy.teleprompt.teleprompt import Teleprompter
from dspy.teleprompt.simba_utils import prepare_models_for_resampling, wrap_program, append_a_demo, append_a_rule
from typing import Optional, Any, Dict

logger = logging.getLogger(__name__)

@@ -20,6 +21,8 @@ def __init__(
num_candidates=6,
max_steps=8,
max_demos=4,
prompt_model: Optional[Any] = None,
teacher_settings: Optional[Dict] = None,
demo_input_field_maxlen=100_000,
num_threads=None,
temperature_for_sampling=0.2,
@@ -41,6 +44,8 @@ def __init__(
self.num_candidates = num_candidates
self.max_steps = max_steps
self.max_demos = max_demos
Author comment: Adds support for prompt and teacher models; a usage sketch follows this hunk.

self.prompt_model = prompt_model if prompt_model else dspy.settings.lm
self.teacher_settings = teacher_settings
self.demo_input_field_maxlen = demo_input_field_maxlen
self.num_threads = num_threads
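For illustration (not part of the diff), a sketch of how the new arguments might be passed. It assumes, per the simba_utils change below, that teacher_settings holds keyword arguments forwarded to dspy.LM; the metric and model names are placeholders:

import dspy

def exact_match(example, prediction):
    # Placeholder metric; SIMBA accepts any callable returning a number
    # (or a dspy.Prediction carrying a score, per the wrap_program change below).
    return float(example.answer == prediction.answer)

optimizer = dspy.SIMBA(
    metric=exact_match,
    max_steps=8,
    num_candidates=6,
    prompt_model=dspy.LM("openai/gpt-4o"),        # used when generating rules/advice
    teacher_settings={"model": "openai/gpt-4o"},  # forwarded to dspy.LM(**teacher_settings)
)
# optimized_program = optimizer.compile(program, trainset=trainset)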

@@ -137,7 +142,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):

# We'll generate (program, model) pairs for the trajectory sampling.
# Prepare distinct LMs (with different temperatures, etc.) from the baseline=programs[0].
models = prepare_models_for_resampling(programs[0], self.num_candidates)
models = prepare_models_for_resampling(programs[0], self.num_candidates, self.teacher_settings)
top_programs = top_k_plus_baseline(self.num_candidates)

exec_pairs = []
@@ -240,6 +245,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
name2predictor=name2predictor,
batch_10p_score=batch_10th_percentile_score,
batch_90p_score=batch_90th_percentile_score,
prompt_model=self.prompt_model,
)
except Exception as e:
logger.error(f"Strategy failed with error: {e}")
@@ -310,7 +316,7 @@ def register_new_program(prog: dspy.Module, score_list: list[float]):
trial_logs[idx_prog-1]["train_score"] = avg_score

best_idx = scores.index(max(scores)) if scores else 0
best_program = candidate_programs[best_idx]
Author comment: Fixes a maximum recursion depth error by returning a deep copy of the best program.

best_program = candidate_programs[best_idx].deepcopy()
logger.info(
f"Final trainset scores: {scores}, Best: {max(scores) if scores else 'N/A'} "
f"(at index {best_idx if scores else 'N/A'})\n\n\n"
70 changes: 55 additions & 15 deletions dspy/teleprompt/simba_utils.py
@@ -3,20 +3,37 @@
import inspect
import logging
import textwrap
import re

from dspy.adapters.utils import get_field_description_string
from dspy.signatures import InputField, OutputField
from typing import Callable
from typing import Callable, Optional, Dict

logger = logging.getLogger(__name__)


def prepare_models_for_resampling(program: dspy.Module, n: int):
def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: Optional[Dict] = None):
lm = program.get_lm() or dspy.settings.lm
temps = [lm.kwargs["temperature"]] + [0.5 + i * (0.5 / n) for i in range(n)]
temps = list(dict.fromkeys(temps))[:n]
return [lm.copy(temperature=t) for t in temps]

# Check to see if our model is a reasoning model, which means temp must stay as 1.0
model_family = lm.model.split("/")[-1].lower() if "/" in lm.model else lm.model.lower()
model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)

Author comment: This function has been updated to:

1. add support for a teacher model (used for one of the N trajectories), and
2. add support for reasoning models by varying the seed instead of the temperature.

A usage sketch follows this function.

models = []
if teacher_settings:
models.append(dspy.LM(**teacher_settings))

if model_pattern: # Vary the seed
start_seed = 0 if "seed" not in lm.kwargs else lm.kwargs["seed"]
seeds = [start_seed + 1 + i for i in range(n-len(models))]
seeds = list(dict.fromkeys(seeds))[:(n-len(models))]
models.extend([lm.copy(seed=seed) for seed in seeds])
else: # Vary the temperature
start_temp = 0 if "temperature" not in lm.kwargs else lm.kwargs["temperature"]
temps = [start_temp + 0.5 + i * (0.5 / n) for i in range(n-len(models))]
temps = list(dict.fromkeys(temps))[:(n-len(models))]
models.extend([lm.copy(temperature=t) for t in temps])

return models
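For illustration (not part of the diff), a sketch of what this now produces, assuming an o3-style reasoning model with no seed set and a teacher configured:

import dspy
from dspy.teleprompt.simba_utils import prepare_models_for_resampling

program = dspy.Predict("question -> answer")
program.set_lm(dspy.LM("openai/o3-mini"))  # hypothetical reasoning model

models = prepare_models_for_resampling(program, n=4, teacher_settings={"model": "openai/gpt-4o"})
# models[0] is the teacher LM built from teacher_settings; the remaining three are
# copies of the program's LM with seeds 1, 2, 3 (temperature stays at 1.0, as
# reasoning models require). For a non-reasoning model, the copies would instead
# vary the temperature.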

def wrap_program(program: dspy.Module, metric: Callable):
def wrapped_program(example):
@@ -28,30 +45,48 @@ def wrapped_program(example):
print(e)
trace = dspy.settings.trace.copy()

Author comment: Updated to handle additional metric metadata in addition to the score. The metric's output is checked: if it is an int or float, it is used directly as the score; if it is a dspy.Prediction, its score attribute is required and used as the score, and the remaining fields are kept as output metadata. A sketch of both metric shapes follows wrap_program.

output = None
score = 0.0
output_metadata = {}

try:
score = metric(example, prediction)
output = metric(example, prediction)
if isinstance(output, (int, float)):
score = output
elif isinstance(output, dspy.Prediction):
if not hasattr(output, 'score'):
raise ValueError("dspy.Prediction must contain a 'score' attribute")
score = output.score
# Just extract fields from _store, excluding 'score'
output_metadata = {
k: v for k, v in output._store.items() if k != "score"
}
except Exception as e:
print(e)

# Include the `example` in the output for subsequent usage in buckets/strategies.
return {
"prediction": prediction,
"trace": trace,
"score": score,
"example": example
"example": example,
"output_metadata": output_metadata
}

return wrapped_program
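For illustration (not part of the diff), the two metric shapes wrap_program now accepts; the feedback field name is arbitrary:

import dspy

# Plain numeric metric: the return value is used directly as the score (as before).
def accuracy(example, prediction):
    return float(example.answer == prediction.answer)

# Richer metric: a score attribute is required; every other field ends up in
# output_metadata and is later passed to append_a_rule as better_reward_info /
# worse_reward_info.
def accuracy_with_feedback(example, prediction):
    correct = example.answer == prediction.answer
    return dspy.Prediction(
        score=float(correct),
        feedback="Exact match." if correct else f"Expected {example.answer!r}.",
    )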



def append_a_demo(demo_input_field_maxlen):
def append_a_demo_(bucket, system, **kwargs):
predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
batch_10p_score = kwargs["batch_10p_score"]

trace = bucket[0]["trace"]
good = bucket[0]
trace = good["trace"]
name2demo = {}

Author comment (klopsahlong, Apr 17, 2025): Double-checks that the demo being appended does not score below the 10th percentile of batch scores.

if good["score"] < batch_10p_score:
logger.info(f"Skipping appending a demo as good score {good['score']} is below the 10th percentile.")
return False

for step in trace:
predictor, _inputs, _outputs = step

@@ -62,7 +97,6 @@ def append_a_demo_(bucket, system, **kwargs):
demo = dspy.Example(augmented=True, **_inputs, **_outputs)
name = predictor2name[id(predictor)]
name2demo[name] = demo # keep the last demo for each predictor

for name, demo in name2demo.items():
predictor = name2predictor[name]
predictor.demos.append(demo)
@@ -76,6 +110,7 @@
def append_a_rule(bucket, system, **kwargs):
predictor2name = kwargs["predictor2name"]
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
prompt_model = kwargs["prompt_model"]

module_names = [name for name, _ in system.named_predictors()]
good, bad = bucket[0], bucket[-1]
@@ -116,12 +151,16 @@ def append_a_rule(bucket, system, **kwargs):
worse_program_outputs=dict(bad["prediction"] or {}),
worse_reward_value=bad["score"],
better_reward_value=good["score"],
Author comment: Again, the additional metric metadata is passed through here.

worse_reward_info=bad["output_metadata"],
better_reward_info=good["output_metadata"],
module_names=module_names,
)

kwargs = {k: v if isinstance(v, str) else ujson.dumps(recursive_mask(v), indent=2)
for k, v in kwargs.items()}
advice = dspy.Predict(OfferFeedback)(**kwargs).module_advice

advice_program = dspy.Predict(OfferFeedback)
advice = advice_program(**kwargs).module_advice

for name, predictor in system.named_predictors():
if name in advice:
@@ -155,11 +194,13 @@ class OfferFeedback(dspy.Signature):
)
worse_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
worse_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
worse_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
better_program_trajectory: str = InputField(
desc="The trajectory of the program's execution, showing each module's I/O"
)
better_program_outputs: str = InputField(desc="The outputs of the program that we are analyzing")
better_reward_value: float = InputField(desc="The reward value assigned to the program's outputs")
better_reward_info: str = InputField(desc="Additional information that might be helpful to understanding the assigned reward value.")
module_names: list[str] = InputField(desc="The names of the modules in the program, for which we seek advice")
discussion: str = OutputField(desc="Discussing blame of where each module went wrong, if it did")
module_advice: dict[str, str] = OutputField(
@@ -169,7 +210,6 @@ class OfferFeedback(dspy.Signature):
"like the successful trajectory rather than the lower-scoring trajectory."
)


def inspect_modules(program):
separator = "-" * 80
output = [separator]