
Commit 6035fba

Add helper functions to detect bad responses (#31)
1 parent cd6f2f8 commit 6035fba

4 files changed: +534 -0 lines changed


pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,8 @@ extra-dependencies = [
   "pytest",
   "llama-index-core",
   "smolagents",
+  "cleanlab-studio",
+  "thefuzz",
   "langchain-core",
 ]
 [tool.hatch.envs.types.scripts]
@@ -54,6 +56,8 @@ allow-direct-references = true
 extra-dependencies = [
   "llama-index-core",
   "smolagents; python_version >= '3.10'",
+  "cleanlab-studio",
+  "thefuzz",
   "langchain-core",
 ]
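
Both new packages are extras for the project's Hatch environments rather than runtime dependencies of the package: as the new module below shows, thefuzz and cleanlab_studio are imported lazily inside the individual check functions and surface as a MissingDependencyError when absent. A minimal sketch of handling that path (the import path of the new module is assumed here, since the file's path is not shown on this page):

    from cleanlab_codex.utils.errors import MissingDependencyError
    from cleanlab_codex.validation import is_fallback_response  # assumed module path

    try:
        is_bad = is_fallback_response("Sorry, I don't know.")
    except MissingDependencyError:
        # thefuzz is not installed in this environment; install it (it is listed
        # above as an environment extra) or skip the fallback check.
        is_bad = False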

Lines changed: 299 additions & 0 deletions
@@ -0,0 +1,299 @@
"""
This module provides validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives.
"""

from __future__ import annotations

from typing import (
    Any,
    Callable,
    Dict,
    Optional,
    Protocol,
    Sequence,
    Union,
    cast,
    runtime_checkable,
)

from pydantic import BaseModel, ConfigDict, Field

from cleanlab_codex.utils.errors import MissingDependencyError
from cleanlab_codex.utils.prompt import default_format_prompt


@runtime_checkable
class TLM(Protocol):
    def get_trustworthiness_score(
        self,
        prompt: Union[str, Sequence[str]],
        response: Union[str, Sequence[str]],
        **kwargs: Any,
    ) -> Dict[str, Any]: ...

    def prompt(
        self,
        prompt: Union[str, Sequence[str]],
        /,
        **kwargs: Any,
    ) -> Dict[str, Any]: ...


DEFAULT_FALLBACK_ANSWER: str = (
    "Based on the available information, I cannot provide a complete answer to this question."
)
DEFAULT_FALLBACK_SIMILARITY_THRESHOLD: int = 70
DEFAULT_TRUSTWORTHINESS_THRESHOLD: float = 0.5

Query = str
Context = str
Prompt = str


class BadResponseDetectionConfig(BaseModel):
    """Configuration for bad response detection functions."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Fallback check config
    fallback_answer: str = Field(
        default=DEFAULT_FALLBACK_ANSWER, description="Known unhelpful response to compare against"
    )
    fallback_similarity_threshold: int = Field(
        default=DEFAULT_FALLBACK_SIMILARITY_THRESHOLD,
        description="Fuzzy matching similarity threshold (0-100). Higher values mean responses must be more similar to fallback_answer to be considered bad.",
    )

    # Untrustworthy check config
    trustworthiness_threshold: float = Field(
        default=DEFAULT_TRUSTWORTHINESS_THRESHOLD,
        description="Score threshold (0.0-1.0). Lower values allow less trustworthy responses.",
    )
    format_prompt: Callable[[Query, Context], Prompt] = Field(
        default=default_format_prompt,
        description="Function to format (query, context) into a prompt string.",
    )

    # Unhelpful check config
    unhelpfulness_confidence_threshold: Optional[float] = Field(
        default=None,
        description="Optional confidence threshold (0.0-1.0) for unhelpful classification.",
    )

    # Shared config (for untrustworthiness and unhelpfulness checks)
    tlm: Optional[TLM] = Field(
        default=None,
        description="TLM model to use for evaluation (required for untrustworthiness and unhelpfulness checks).",
    )


DEFAULT_CONFIG = BadResponseDetectionConfig()


def is_bad_response(
    response: str,
    *,
    context: Optional[str] = None,
    query: Optional[str] = None,
    config: Union[BadResponseDetectionConfig, Dict[str, Any]] = DEFAULT_CONFIG,
) -> bool:
    """Run a series of checks to determine if a response is bad.

    If any check detects an issue (i.e. fails), the function returns True, indicating the response is bad.

    This function runs three possible validation checks:
    1. **Fallback check**: Detects if the response is too similar to a known fallback answer.
    2. **Untrustworthy check**: Assesses response trustworthiness based on the given context and query.
    3. **Unhelpful check**: Predicts if the response adequately answers the query or not, in a useful way.

    Note:
        Each validation check runs conditionally based on whether the required arguments are provided.
        As soon as any validation check fails, the function returns True.

    Args:
        response: The response to check.
        context: Optional context/documents used for answering. Required for untrustworthy check.
        query: Optional user question. Required for untrustworthy and unhelpful checks.
        config: Optional configuration; accepts a BadResponseDetectionConfig instance or a dictionary of its fields. See BadResponseDetectionConfig for details.

    Returns:
        bool: True if any validation check fails, False if all pass.
    """
    config = BadResponseDetectionConfig.model_validate(config)

    validation_checks: list[Callable[[], bool]] = []

    # All required inputs are available for checking fallback responses
    validation_checks.append(
        lambda: is_fallback_response(
            response,
            config.fallback_answer,
            threshold=config.fallback_similarity_threshold,
        )
    )

    can_run_untrustworthy_check = query is not None and context is not None and config.tlm is not None
    if can_run_untrustworthy_check:
        # The if condition guarantees these are not None
        validation_checks.append(
            lambda: is_untrustworthy_response(
                response=response,
                context=cast(str, context),
                query=cast(str, query),
                tlm=cast(TLM, config.tlm),
                trustworthiness_threshold=config.trustworthiness_threshold,
                format_prompt=config.format_prompt,
            )
        )

    can_run_unhelpful_check = query is not None and config.tlm is not None
    if can_run_unhelpful_check:
        validation_checks.append(
            lambda: is_unhelpful_response(
                response=response,
                query=cast(str, query),
                tlm=cast(TLM, config.tlm),
                trustworthiness_score_threshold=cast(float, config.unhelpfulness_confidence_threshold),
            )
        )

    return any(check() for check in validation_checks)


def is_fallback_response(
    response: str,
    fallback_answer: str = DEFAULT_FALLBACK_ANSWER,
    threshold: int = DEFAULT_FALLBACK_SIMILARITY_THRESHOLD,
) -> bool:
    """Check if a response is too similar to a known fallback answer.

    Uses fuzzy string matching to compare the response against a known fallback answer.
    Returns True if the response is similar enough to be considered unhelpful.

    Args:
        response: The response to check.
        fallback_answer: A known unhelpful/fallback response to compare against.
        threshold: Similarity threshold (0-100). Higher values require more similarity.
            Default 70 means responses that are 70% or more similar are considered bad.

    Returns:
        bool: True if the response is too similar to the fallback answer, False otherwise
    """
    try:
        from thefuzz import fuzz  # type: ignore
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "thefuzz",
            package_url="https://github.com/seatgeek/thefuzz",
        ) from e

    partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower())
    return bool(partial_ratio >= threshold)


def is_untrustworthy_response(
    response: str,
    context: str,
    query: str,
    tlm: TLM,
    trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD,
    format_prompt: Callable[[str, str], str] = default_format_prompt,
) -> bool:
    """Check if a response is untrustworthy.

    Uses TLM to evaluate whether a response is trustworthy given the context and query.
    Returns True if TLM's trustworthiness score falls below the threshold, indicating
    the response may be incorrect or unreliable.

    Args:
        response: The response to check from the assistant
        context: The context information available for answering the query
        query: The user's question or request
        tlm: The TLM model to use for evaluation
        trustworthiness_threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses.
            Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy.
        format_prompt: Function that takes (query, context) and returns a formatted prompt string.
            Users should provide their RAG app's own prompt formatting function here
            to match how their LLM is prompted.

    Returns:
        bool: True if the response is deemed untrustworthy by TLM, False otherwise
    """
    try:
        from cleanlab_studio import Studio  # type: ignore[import-untyped] # noqa: F401
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "cleanlab_studio",
            package_name="cleanlab-studio",
            package_url="https://github.com/cleanlab/cleanlab-studio",
        ) from e

    prompt = format_prompt(query, context)
    result = tlm.get_trustworthiness_score(prompt, response)
    score: float = result["trustworthiness_score"]
    return score < trustworthiness_threshold


def is_unhelpful_response(
    response: str,
    query: str,
    tlm: TLM,
    trustworthiness_score_threshold: Optional[float] = None,
) -> bool:
    """Check if a response is unhelpful by asking TLM to evaluate it.

    Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment.
    The evaluation considers both the TLM's binary classification of helpfulness and its
    confidence score. Returns True only if TLM classifies the response as unhelpful AND
    is sufficiently confident in that assessment (if a threshold is provided).

    Args:
        response: The response to check
        query: User query that will be used to evaluate if the response is helpful
        tlm: The TLM model to use for evaluation
        trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0)
            If provided and the response is marked as unhelpful,
            the confidence score must exceed this threshold for
            the response to be considered truly unhelpful.

    Returns:
        bool: True if TLM determines the response is unhelpful with sufficient confidence,
        False otherwise
    """
    try:
        from cleanlab_studio import Studio  # noqa: F401
    except ImportError as e:
        raise MissingDependencyError(
            import_name=e.name or "cleanlab_studio",
            package_name="cleanlab-studio",
            package_url="https://github.com/cleanlab/cleanlab-studio",
        ) from e

    # If editing `question`, make sure `expected_unhelpful_response` is still correct:
    # - When asking "is helpful?" -> "no" means unhelpful
    # - When asking "is unhelpful?" -> "yes" means unhelpful
    question = (
        "Does the AI Assistant Response seem unhelpful? "
        "Things that are not helpful include answers that:\n"
        "- Are not useful, incomplete, incorrect, uncertain or unclear.\n"
        "- Abstain or refuse to answer the question\n"
        "- Statements which are similar to 'I don't know', 'Sorry', or 'No information available'.\n"
        "- Leave the original question unresolved\n"
        "- Are irrelevant to the question\n"
        "Answer Yes/No only."
    )
    expected_unhelpful_response = "yes"

    prompt = (
        "Consider the following User Query and AI Assistant Response.\n\n"
        f"User Query: {query}\n\n"
        f"AI Assistant Response: {response}\n\n"
        f"{question}"
    )

    output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"])
    response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response
    is_trustworthy = trustworthiness_score_threshold is None or (
        output["trustworthiness_score"] > trustworthiness_score_threshold
    )
    return response_marked_unhelpful and is_trustworthy
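
A quick usage sketch of the new helpers follows. It is illustrative and not part of this commit: the import path of the new module is assumed (the file's path is not shown on this page), and the TLM object is a local stub that satisfies the TLM Protocol above rather than a real Cleanlab Studio TLM client, so its scores are canned.

    from typing import Any, Dict, Sequence, Union

    from cleanlab_codex.validation import (  # assumed module path
        BadResponseDetectionConfig,
        is_bad_response,
        is_fallback_response,
    )


    class StubTLM:
        """Minimal stand-in satisfying the TLM Protocol; use a real Cleanlab Studio TLM client in practice."""

        def get_trustworthiness_score(
            self,
            prompt: Union[str, Sequence[str]],
            response: Union[str, Sequence[str]],
            **kwargs: Any,
        ) -> Dict[str, Any]:
            return {"trustworthiness_score": 0.9}

        def prompt(
            self,
            prompt: Union[str, Sequence[str]],
            /,
            **kwargs: Any,
        ) -> Dict[str, Any]:
            return {"response": "No", "trustworthiness_score": 0.9}


    config = BadResponseDetectionConfig(
        tlm=StubTLM(),
        trustworthiness_threshold=0.6,
        unhelpfulness_confidence_threshold=0.7,
    )

    # query, context, and tlm are all provided, so all three checks
    # (fallback, untrustworthy, unhelpful) can run.
    print(
        is_bad_response(
            "Based on the available information, I cannot provide a complete answer.",
            query="What is the capital of France?",
            context="France is a country in Europe. Its capital is Paris.",
            config=config,
        )
    )  # True: the response fuzzy-matches the default fallback answer

    # Each check can also be called on its own; the fallback check needs only thefuzz.
    print(is_fallback_response("Sorry, I am unable to answer that question.", threshold=70))

With a real TLM client, the untrustworthy and unhelpful checks would consult the model instead of returning the canned scores above; the fallback check never needs a TLM.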

src/cleanlab_codex/utils/prompt.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
"""
Helper functions for processing prompts in RAG applications.
"""


def default_format_prompt(query: str, context: str) -> str:
    """Default function for formatting RAG prompts.

    Args:
        query: The user's question
        context: The context/documents to use for answering

    Returns:
        str: A formatted prompt combining the query and context
    """
    template = (
        "Using only information from the following Context, answer the following Query.\n\n"
        "Context:\n{context}\n\n"
        "Query: {query}"
    )
    return template.format(context=context, query=query)
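
For reference, a small sketch of what this helper produces (the printed output in the comments follows directly from the template above). The same function is the default for BadResponseDetectionConfig.format_prompt and can be replaced with a RAG app's own prompt builder so that trustworthiness scoring sees the same prompt the LLM saw.

    from cleanlab_codex.utils.prompt import default_format_prompt

    prompt = default_format_prompt(
        query="What is the capital of France?",
        context="France is a country in Europe. Its capital is Paris.",
    )
    print(prompt)
    # Using only information from the following Context, answer the following Query.
    #
    # Context:
    # France is a country in Europe. Its capital is Paris.
    #
    # Query: What is the capital of France?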

0 commit comments
