Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions instructor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from_litellm,
)
from .core import hooks
from .core.exceptions import TokenBudgetExceeded
from .utils.providers import Provider
from .auto_client import from_provider
from .batch import BatchProcessor, BatchRequest, BatchJob
Expand Down Expand Up @@ -65,6 +66,7 @@
"llm_validator",
"openai_moderation",
"hooks",
"TokenBudgetExceeded",
"client", # Backward compatibility
# Backward compatibility exports
"handle_response_model",
Expand Down
63 changes: 63 additions & 0 deletions instructor/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,69 @@ def __init__(
super().__init__(f"{message}{context}", *args, **kwargs)


class TokenBudgetExceeded(InstructorError):
"""Exception raised when retry token budget is exhausted.

This exception is raised when the cumulative token usage across retry
attempts exceeds the configured token budget. This helps prevent
runaway costs from adversarial or malformed LLM responses that
repeatedly fail validation.

Attributes:
token_budget: The maximum token budget that was exceeded
total_tokens_used: The actual tokens used before exceeding the budget
n_attempts: Number of retry attempts made
last_completion: The last completion received before budget exhaustion
failed_attempts: List of FailedAttempt objects with details about
each retry attempt

Security Context:
This exception helps mitigate retry amplification attacks where
an adversarial LLM (or prompt-injected response) crafts outputs
that always fail validation, causing:
- Context growth with each retry (2 messages per retry)
- Token budget exhaustion and cost amplification

Examples:
```python
try:
response = client.chat.completions.create(
response_model=StrictModel,
max_retries=10,
token_budget=10000, # Stop if we use more than 10k tokens
...
)
except TokenBudgetExceeded as e:
print(f"Token budget exceeded after {e.n_attempts} attempts")
print(f"Used {e.total_tokens_used} of {e.token_budget} tokens")
# Implement fallback or alert
```

See Also:
- InstructorRetryException: Raised when retry count is exhausted
"""

def __init__(
self,
*args: Any,
token_budget: int,
total_tokens_used: int,
n_attempts: int,
last_completion: Any | None = None,
failed_attempts: list[FailedAttempt] | None = None,
**kwargs: dict[str, Any],
):
self.token_budget = token_budget
self.total_tokens_used = total_tokens_used
self.n_attempts = n_attempts
self.last_completion = last_completion
message = (
f"Token budget exceeded: used {total_tokens_used} tokens "
f"(budget: {token_budget}) after {n_attempts} attempts"
)
super().__init__(message, *args, failed_attempts=failed_attempts, **kwargs)


class MultimodalError(ValueError, InstructorError):
"""Exception raised for multimodal content processing errors.

Expand Down
4 changes: 4 additions & 0 deletions instructor/core/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ async def new_create_async(
max_retries: int | AsyncRetrying = 1,
strict: bool = True,
hooks: Hooks | None = None,
token_budget: int | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
Expand Down Expand Up @@ -195,6 +196,7 @@ async def new_create_async(
strict=strict,
mode=mode,
hooks=hooks,
token_budget=token_budget,
)

# Store in cache *after* successful call
Expand All @@ -219,6 +221,7 @@ def new_create_sync(
max_retries: int | Retrying = 1,
strict: bool = True,
hooks: Hooks | None = None,
token_budget: int | None = None,
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
Expand Down Expand Up @@ -265,6 +268,7 @@ def new_create_sync(
strict=strict,
kwargs=new_kwargs,
mode=mode,
token_budget=token_budget,
)

# Save to cache
Expand Down
61 changes: 61 additions & 0 deletions instructor/core/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
InstructorRetryException,
AsyncValidationError,
FailedAttempt,
TokenBudgetExceeded,
ValidationError as InstructorValidationError,
)
from .hooks import Hooks
Expand Down Expand Up @@ -140,6 +141,34 @@ def extract_messages(kwargs: dict[str, Any]) -> Any:
return []


def get_total_tokens(usage: Any) -> int:
"""
Extract total token count from usage object.

Handles different usage object formats across providers (OpenAI, Anthropic, etc.)

Args:
usage: Usage object from the provider

Returns:
int: Total tokens used
"""
if usage is None:
return 0

# OpenAI-style usage
if hasattr(usage, "total_tokens"):
return usage.total_tokens or 0

# Anthropic-style usage
if hasattr(usage, "input_tokens") and hasattr(usage, "output_tokens"):
input_tokens = usage.input_tokens or 0
output_tokens = usage.output_tokens or 0
return input_tokens + output_tokens

return 0


def retry_sync(
func: Callable[T_ParamSpec, T_Retval],
response_model: type[T_Model] | None,
Expand All @@ -150,6 +179,7 @@ def retry_sync(
strict: bool | None = None,
mode: Mode = Mode.TOOLS,
hooks: Hooks | None = None,
token_budget: int | None = None,
) -> T_Model | None:
"""
Retry a synchronous function upon specified exceptions.
Expand All @@ -164,12 +194,15 @@ def retry_sync(
strict (Optional[bool], optional): Strict mode flag. Defaults to None.
mode (Mode, optional): The mode of operation. Defaults to Mode.TOOLS.
hooks (Optional[Hooks], optional): Hooks for emitting events. Defaults to None.
token_budget (Optional[int], optional): Maximum total tokens allowed across retries.
If exceeded, raises TokenBudgetExceeded. Helps prevent retry amplification attacks.

Returns:
T_Model | None: The processed response model or None.

Raises:
InstructorRetryException: If all retry attempts fail.
TokenBudgetExceeded: If token_budget is set and exceeded during retries.
"""
hooks = hooks or Hooks()
total_usage = initialize_usage(mode)
Expand All @@ -196,6 +229,18 @@ def retry_sync(
response=response, total_usage=total_usage
)

# Check token budget after each attempt
if token_budget is not None:
current_tokens = get_total_tokens(total_usage)
if current_tokens > token_budget:
raise TokenBudgetExceeded(
token_budget=token_budget,
total_tokens_used=current_tokens,
n_attempts=attempt.retry_state.attempt_number,
last_completion=response,
failed_attempts=failed_attempts,
)

return process_response( # type: ignore
response=response,
response_model=response_model,
Expand Down Expand Up @@ -306,6 +351,7 @@ async def retry_async(
strict: bool | None = None,
mode: Mode = Mode.TOOLS,
hooks: Hooks | None = None,
token_budget: int | None = None,
) -> T_Model | None:
"""
Retry an asynchronous function upon specified exceptions.
Expand All @@ -320,12 +366,15 @@ async def retry_async(
strict (Optional[bool], optional): Strict mode flag. Defaults to None.
mode (Mode, optional): The mode of operation. Defaults to Mode.TOOLS.
hooks (Optional[Hooks], optional): Hooks for emitting events. Defaults to None.
token_budget (Optional[int], optional): Maximum total tokens allowed across retries.
If exceeded, raises TokenBudgetExceeded. Helps prevent retry amplification attacks.

Returns:
T_Model | None: The processed response model or None.

Raises:
InstructorRetryException: If all retry attempts fail.
TokenBudgetExceeded: If token_budget is set and exceeded during retries.
"""
hooks = hooks or Hooks()
total_usage = initialize_usage(mode)
Expand All @@ -352,6 +401,18 @@ async def retry_async(
response=response, total_usage=total_usage
)

# Check token budget after each attempt
if token_budget is not None:
current_tokens = get_total_tokens(total_usage)
if current_tokens > token_budget:
raise TokenBudgetExceeded(
token_budget=token_budget,
total_tokens_used=current_tokens,
n_attempts=attempt.retry_state.attempt_number,
last_completion=response,
failed_attempts=failed_attempts,
)

return await process_response_async(
response=response,
response_model=response_model,
Expand Down
14 changes: 13 additions & 1 deletion instructor/validation/llm_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class User(BaseModel):
"""

def llm(v: str) -> str:
# Sanitize value to prevent prompt injection by escaping delimiters
# and using explicit structured format
sanitized_value = v.replace("```", "\\`\\`\\`").replace("---", "\\-\\-\\-")

resp = client.chat.completions.create(
response_model=Validator,
messages=[
Expand All @@ -57,7 +61,15 @@ def llm(v: str) -> str:
},
{
"role": "user",
"content": f"Does `{v}` follow the rules: {statement}",
"content": f"""Validate the following value against the rules.

---BEGIN VALUE---
{sanitized_value}
---END VALUE---

Rules to validate against: {statement}

Is this value valid according to the rules above?""",
},
],
model=model,
Expand Down
44 changes: 44 additions & 0 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ModeError,
ClientError,
FailedAttempt,
TokenBudgetExceeded,
)


Expand All @@ -26,6 +27,7 @@ def test_all_exceptions_can_be_imported():
assert ConfigurationError is not None
assert ModeError is not None
assert ClientError is not None
assert TokenBudgetExceeded is not None


def test_exception_hierarchy():
Expand All @@ -37,6 +39,7 @@ def test_exception_hierarchy():
assert issubclass(ConfigurationError, InstructorError)
assert issubclass(ModeError, InstructorError)
assert issubclass(ClientError, InstructorError)
assert issubclass(TokenBudgetExceeded, InstructorError)


def test_base_instructor_error_can_be_caught():
Expand Down Expand Up @@ -594,3 +597,44 @@ def test_failed_attempts_exception_chaining():
assert chained_error.failed_attempts is not None
assert len(chained_error.failed_attempts) == 1
assert chained_error.failed_attempts[0].exception.args[0] == "Original failure"


def test_token_budget_exceeded():
"""Test TokenBudgetExceeded attributes and catching."""
token_budget = 10000
total_tokens_used = 12500
n_attempts = 5
last_completion = {"content": "partial response"}
failed_attempts = [
FailedAttempt(1, Exception("Validation failed"), "attempt 1"),
FailedAttempt(2, Exception("Validation failed"), "attempt 2"),
]

with pytest.raises(TokenBudgetExceeded) as exc_info:
raise TokenBudgetExceeded(
token_budget=token_budget,
total_tokens_used=total_tokens_used,
n_attempts=n_attempts,
last_completion=last_completion,
failed_attempts=failed_attempts,
)

exception = exc_info.value
assert exception.token_budget == token_budget
assert exception.total_tokens_used == total_tokens_used
assert exception.n_attempts == n_attempts
assert exception.last_completion == last_completion
assert exception.failed_attempts == failed_attempts
assert "12500" in str(exception)
assert "10000" in str(exception)
assert "5 attempts" in str(exception)


def test_token_budget_exceeded_inherits_from_instructor_error():
"""Test that TokenBudgetExceeded can be caught as InstructorError."""
with pytest.raises(InstructorError):
raise TokenBudgetExceeded(
token_budget=1000,
total_tokens_used=1500,
n_attempts=3,
)
Loading