Skip to content

Commit 57b7399

Browse files
committed
fix(security): prevent prompt injection in llm_validator
Fixes #2056.

- Add _format_validation_prompt() with XML-style escaping to prevent prompt injection attacks. User input is now wrapped in <user_value> tags to clearly separate it from instructions.
- Add an escape_user_input parameter (default: True) so escaping can be turned off where backward compatibility requires it.
- Fix allow_override logic bug: fixed_value is now returned when validation fails, instead of only when the value is valid.
- Add async_llm_validator() for async validation pipelines.
- Update exports in validation/__init__.py and instructor/__init__.py.
- Add comprehensive unit tests for the security improvements.

Security note: This prevents attackers from crafting inputs that could manipulate the LLM validator into always returning is_valid=true.
1 parent caea4ab commit 57b7399

File tree

5 files changed

+518
-59
lines changed

5 files changed

+518
-59
lines changed

instructor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
IterableModel,
1313
)
1414

15-
from .validation import llm_validator, openai_moderation
15+
from .validation import llm_validator, async_llm_validator, openai_moderation
1616
from .processing.function_calls import OpenAISchema, openai_schema
1717
from .processing.schema import (
1818
generate_openai_schema,
@@ -63,6 +63,7 @@
6363
"BatchRequest",
6464
"BatchJob",
6565
"llm_validator",
66+
"async_llm_validator",
6667
"openai_moderation",
6768
"hooks",
6869
"client", # Backward compatibility

instructor/processing/validators.py

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
1-
"""Validators that extend OpenAISchema for structured outputs."""
1+
"""Validators that extend OpenAISchema for structured outputs.
22
3-
from typing import Optional
3+
This module provides validation classes for LLM-based validation:
4+
- Validator: Original simple validator (backward compatible)
5+
- ValidationResult[T]: Enhanced generic validator with confidence scoring
6+
"""
47

5-
from pydantic import Field
8+
from typing import Any, Generic, Optional, TypeVar
9+
10+
from pydantic import BaseModel, ConfigDict, Field, field_validator
611

712
from .function_calls import OpenAISchema
813

914

15+
# Type variable for generic validation result
16+
T = TypeVar("T")
17+
18+
1019
class Validator(OpenAISchema):
1120
"""
1221
Validate if an attribute is correct and if not,
13-
return a new value with an error message
22+
return a new value with an error message.
23+
24+
This is the original validator class, maintained for backward compatibility.
25+
For new implementations, consider using ValidationResult[T] for better type safety.
1426
"""
1527

1628
is_valid: bool = Field(
@@ -25,3 +37,114 @@ class Validator(OpenAISchema):
2537
default=None,
2638
description="If the attribute is not valid, suggest a new value for the attribute",
2739
)
40+
41+
42+
class ValidationResult(BaseModel, Generic[T]):
    """
    Enhanced validator supporting any type and confidence scoring.

    This is a Pydantic v2-native validator that provides:
    - Generic type support for fixed_value (not just strings)
    - Confidence scoring for LLM-based validation decisions
    - Multiple error message support
    - Strict model configuration (unknown fields are rejected and
      assignments are re-validated)

    Example usage:
    ```python
    from instructor.processing.validators import ValidationResult
    from pydantic import BaseModel

    class UserName(BaseModel):
        first: str
        last: str

    # The LLM can return ValidationResult with properly typed fixed_value
    result: ValidationResult[UserName] = client.chat.completions.create(
        response_model=ValidationResult[UserName],
        messages=[...],
    )

    if not result.is_valid:
        print(f"Confidence: {result.confidence}")
        print(f"Errors: {result.errors}")
        # Compare against None: a falsy fix (0, "", empty model) is still a fix.
        if result.fixed_value is not None:
            corrected_name = result.fixed_value  # Type: UserName
    ```

    Attributes:
        is_valid: Whether the validated value passes all requirements.
        confidence: Confidence score (0.0-1.0) for LLM-based validation decisions.
            Higher values indicate more certainty in the validation result.
        errors: List of validation error messages. Can contain multiple specific issues.
        reason: Primary error message (for backward compatibility with Validator).
        fixed_value: Suggested corrected value of the same type as the input.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
    )

    is_valid: bool = Field(
        default=True,
        description="Whether the attribute is valid based on the requirements",
    )
    confidence: float = Field(
        default=1.0,
        ge=0.0,
        le=1.0,
        description=(
            "Confidence score (0.0-1.0) for the validation result. "
            "1.0 means completely certain, 0.0 means no confidence."
        ),
    )
    errors: list[str] = Field(
        default_factory=list,
        description="List of validation error messages if the attribute is not valid",
    )
    reason: Optional[str] = Field(
        default=None,
        description=(
            "Primary error message if the attribute is not valid. "
            "For backward compatibility with Validator class."
        ),
    )
    fixed_value: Optional[T] = Field(
        default=None,
        description=(
            "Suggested corrected value of the same type as the validated input. "
            "Only provided when is_valid is False and a fix is possible."
        ),
    )

    @field_validator("errors", mode="before")
    @classmethod
    def ensure_errors_list(cls, v: Any) -> Any:
        """Normalize the raw ``errors`` input before field validation.

        Args:
            v: The raw value supplied for ``errors``.

        Returns:
            ``[]`` for ``None``, a one-element list for a single string, a
            list copy for any other iterable, and the value unchanged when
            it is not iterable.
        """
        if v is None:
            return []
        if isinstance(v, str):
            # A bare string is one message, not an iterable of characters.
            return [v]
        try:
            return list(v)
        except TypeError:
            # Not iterable (e.g. an int). Hand it back untouched so the
            # ``list[str]`` field validation rejects it with a regular
            # ValidationError instead of a raw TypeError escaping from here.
            return v

    def model_post_init(self, __context: Any) -> None:
        """Keep ``reason`` and ``errors`` in sync after initialization.

        Both branches write with ``object.__setattr__``: the values are
        already-validated strings, so there is no need to go back through
        the ``validate_assignment`` machinery enabled in ``model_config``.
        """
        if self.reason and not self.errors:
            # Only the legacy single message was given: mirror it into errors.
            object.__setattr__(self, "errors", [self.reason])
        elif self.errors and not self.reason:
            # Only the list was given: expose its first entry as the reason.
            object.__setattr__(self, "reason", self.errors[0])

    @property
    def is_confident(self) -> bool:
        """Check if validation result has high confidence (>= 0.8)."""
        return self.confidence >= 0.8

    def get_summary(self) -> str:
        """Get a human-readable, single-line summary of the validation result.

        Returns:
            ``"Valid (confidence: N%)"`` when ``is_valid`` is true; otherwise
            ``"Invalid (confidence: N%): ..."`` followed by the error messages
            joined with ``"; "``, or ``"Unknown error"`` when there are none.
        """
        if self.is_valid:
            return f"Valid (confidence: {self.confidence:.0%})"
        error_summary = "; ".join(self.errors) if self.errors else "Unknown error"
        return f"Invalid (confidence: {self.confidence:.0%}): {error_summary}"
150+

instructor/validation/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
ASYNC_MODEL_VALIDATOR_KEY,
99
)
1010
from ..core.exceptions import AsyncValidationError
11-
from .llm_validators import Validator, llm_validator, openai_moderation
11+
from .llm_validators import llm_validator, async_llm_validator, openai_moderation
1212

1313
__all__ = [
1414
"AsyncValidationContext",
@@ -17,7 +17,7 @@
1717
"async_model_validator",
1818
"ASYNC_VALIDATOR_KEY",
1919
"ASYNC_MODEL_VALIDATOR_KEY",
20-
"Validator",
2120
"llm_validator",
21+
"async_llm_validator",
2222
"openai_moderation",
2323
]

0 commit comments

Comments
 (0)