Skip to content

Implement adapter retry for Pydantic Validation Error #8050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.retry_adapter import RetryAdapter
from dspy.adapters.types import Image, History

DEFAULT_ADAPTER = RetryAdapter(main_adapter=ChatAdapter(), fallback_adapter=JSONAdapter())

__all__ = [
"Adapter",
"ChatAdapter",
"JSONAdapter",
"RetryAdapter",
"Image",
"History",
]
9 changes: 4 additions & 5 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Any, Optional, Type

import logging
from dspy.adapters.types import History
from dspy.adapters.types.image import try_expand_image_tags
from dspy.signatures.signature import Signature
Expand All @@ -8,6 +8,7 @@
if TYPE_CHECKING:
from dspy.clients.lm import LM

logger = logging.getLogger(__name__)

class Adapter:
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
Expand All @@ -28,11 +29,9 @@ def __call__(
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
inputs = self.format(signature, demos, inputs)

outputs = lm(messages=inputs, **lm_kwargs)
messages = self.format(signature=signature, demos=demos, inputs=inputs)
outputs = lm(messages=messages, **lm_kwargs)
values = []

for output in outputs:
output_logprobs = None

Expand Down
22 changes: 0 additions & 22 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import textwrap
from typing import Any, Dict, NamedTuple, Optional, Type

from litellm import ContextWindowExceededError
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
Expand All @@ -13,7 +12,6 @@
parse_value,
translate_field_type,
)
from dspy.clients.lm import LM
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback

Expand All @@ -29,26 +27,6 @@ class ChatAdapter(Adapter):
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
super().__init__(callbacks)

def __call__(
self,
lm: LM,
lm_kwargs: dict[str, Any],
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
try:
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
except Exception as e:
# fallback to JSONAdapter
from dspy.adapters.json_adapter import JSONAdapter

if isinstance(e, ContextWindowExceededError) or isinstance(self, JSONAdapter):
# On context window exceeded error or already using JSONAdapter, we don't want to retry with a different
# adapter.
raise e
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)

def format_field_description(self, signature: Type[Signature]) -> str:
return (
f"Your input fields are:\n{get_field_description_string(signature.input_fields)}\n"
Expand Down
145 changes: 145 additions & 0 deletions dspy/adapters/retry_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@

from typing import TYPE_CHECKING, Any, Optional, Type
import logging

from dspy.adapters.base import Adapter
from dspy.signatures.signature import Signature
from dspy.adapters.utils import create_signature_for_retry

if TYPE_CHECKING:
from dspy.clients.lm import LM

logger = logging.getLogger(__name__)

class RetryAdapter(Adapter):
"""
RetryAdapter is an adapter that retries the execution of another adapter for
a specified number of times if it fails to parse completion outputs.
"""

def __init__(self, main_adapter: Adapter, fallback_adapter: Optional[Adapter] = None, max_retries: int = 3):
"""
Initializes the RetryAdapter.

Args:
main_adapter (Adapter): The main adapter to use.
fallback_adapter (Optional[Adapter]): The fallback adapter to use if the main adapter fails.
max_retries (int): The maximum number of retries. Defaults to 3.
"""
self.main_adapter = main_adapter
self.fallback_adapter = fallback_adapter
self.max_retries = max_retries

def __call__(
self,
lm: "LM",
lm_kwargs: dict[str, Any],
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
"""
Execute main_adapter and fallback_adapter in the following procedure:
1. Call the main_adapter.
2. If the main_adapter fails, call the fallback_adapter.
3. If the fallback_adapter fails, retry the main_adapter including previous response for `max_retries` times.

Args:
lm (LM): The dspy.LM to use.
lm_kwargs (dict[str, Any]): Additional arguments for the lm.
signature (Type[Signature]): The signature of the function.
demos (list[dict[str, Any]]): A list of demo examples.
inputs (dict[str, Any]): A list representating the user input.

Returns:
A list of parsed completions. The size of the list is equal to `n` argument. Defaults to 1.

Raises:
Exception: If fail to parse outputs after the maximum number of retries.
"""
outputs = []
max_retries = max(self.max_retries, 0)
n_completion = lm_kwargs.get("n", 1)

values, parse_failures = self._call_adapter(
self.main_adapter,
lm,
lm_kwargs,
signature,
demos,
inputs,
)
outputs.extend(values)

if len(outputs) == n_completion:
return outputs

lm_kwargs["n"] = n_completion - len(outputs)
if self.fallback_adapter is not None:
outputs.extend(self._call_adapter(
self.fallback_adapter,
lm,
lm_kwargs,
signature,
demos,
inputs,
)[0])
if len(outputs) == n_completion:
return outputs

# Retry the main adapter with previous response for `max_retries` times
lm_kwargs["n"] = 1
signature = create_signature_for_retry(signature)
if parse_failures:
inputs["previous_response"] = parse_failures[0][0]
inputs["error_message"] = str(parse_failures[0][1])
for i in range(max_retries):
values, parse_failures = self._call_adapter(
self.main_adapter,
lm,
lm_kwargs,
signature,
demos,
inputs,
)
outputs.extend(values)
if len(outputs) == n_completion:
return outputs
logger.warning(f"Retry {i+1}/{max_retries} for {self.main_adapter.__class__.__name__} failed with error: {parse_failures[0][1]}")
inputs["previous_response"] = parse_failures[0][0]
inputs["error_message"] = str(parse_failures[0][1])

# raise the last error
raise ValueError("Failed to parse LM outputs for maximum retries.") from parse_failures[0][1]

def _call_adapter(
self,
adapter: Adapter,
lm: "LM",
lm_kwargs: dict[str, Any],
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
):
values = []
parse_failures = []
messages = adapter.format(signature=signature, demos=demos, inputs=inputs)
outputs = lm(messages=messages, **lm_kwargs)
for i, output in enumerate(outputs):
try:
output_logprobs = None

if isinstance(output, dict):
output, output_logprobs = output["text"], output["logprobs"]

value = adapter.parse(signature, output)

if output_logprobs is not None:
value["logprobs"] = output_logprobs

values.append(value)
except ValueError as e:
logger.warning(f"Failed to parse the {i+1}/{lm_kwargs.get('n', 1)} LM output with adapter {adapter.__class__.__name__}. Error: {e}")
parse_failures.append((outputs[i], e))

return values, parse_failures
13 changes: 13 additions & 0 deletions dspy/adapters/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Type
import ast
import enum
import inspect
Expand All @@ -11,6 +12,8 @@
from pydantic.fields import FieldInfo

from dspy.signatures.utils import get_dspy_field_type
from dspy.signatures.field import InputField
from dspy.signatures.signature import Signature


def serialize_for_json(value: Any) -> Any:
Expand Down Expand Up @@ -237,3 +240,13 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
else:
# Neither => enclose in single quotes
return f"'{s}'"

def create_signature_for_retry(signature: Type[Signature]):
signature = signature.append("previous_response", InputField(
prefix="Previous Response",
desc="Previous response with format errors. You should avoid the same type of error as the previous response.",
)).append("error_message", InputField(
prefix="Validation Error Message",
desc="Error message for the previous response.",
))
return signature
4 changes: 2 additions & 2 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters import DEFAULT_ADAPTER
from dspy.clients.base_lm import BaseLM
from dspy.clients.lm import LM
from dspy.dsp.utils import settings
Expand Down Expand Up @@ -103,7 +103,7 @@ def forward(self, **kwargs):
missing,
)

adapter = settings.adapter or ChatAdapter()
adapter = settings.adapter or DEFAULT_ADAPTER
completions = adapter(
lm,
lm_kwargs=config,
Expand Down
3 changes: 2 additions & 1 deletion dspy/predict/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import dspy
from dspy.adapters.utils import get_field_description_string
from dspy.adapters import DEFAULT_ADAPTER
from dspy.predict.predict import Prediction
from dspy.signatures import InputField, OutputField, Signature

Expand Down Expand Up @@ -100,7 +101,7 @@ def forward(self, **kwargs):
temps = list(dict.fromkeys(temps))[: self.N]
best_pred, best_trace, best_reward = None, None, -float("inf")
advice = None
adapter = dspy.settings.adapter or dspy.ChatAdapter()
adapter = dspy.settings.adapter or DEFAULT_ADAPTER

for idx, t in enumerate(temps):
lm_ = lm.copy(temperature=t)
Expand Down
74 changes: 0 additions & 74 deletions dspy/predict/retry.py

This file was deleted.

4 changes: 2 additions & 2 deletions dspy/teleprompt/bootstrap_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dspy
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters import DEFAULT_ADAPTER
from dspy.clients.lm import LM
from dspy.clients.utils_finetune import infer_data_format
from dspy.dsp.utils.settings import settings
Expand Down Expand Up @@ -162,7 +162,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
logger.info(f"After filtering with the metric, {len(trace_data)} examples remain")

data = []
adapter = self.adapter[lm] or settings.adapter or ChatAdapter()
adapter = self.adapter[lm] or settings.adapter or DEFAULT_ADAPTER
data_format = infer_data_format(adapter)
for item in trace_data:
for pred_ind, _ in enumerate(item["trace"]):
Expand Down
Loading
Loading