Add validator module #61
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from cleanlab_codex.utils.errors import MissingDependencyError

try:
    from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
except ImportError as e:
    raise MissingDependencyError(
        import_name=e.name or "cleanlab-tlm",
        package_url="https://github.com/cleanlab/cleanlab-tlm",
    ) from e

if TYPE_CHECKING:
    from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
    from cleanlab_codex.validator import BadResponseThresholds


"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
DEFAULT_EVAL_METRICS = ["response_helpfulness"]


def get_default_evaluations() -> list[Eval]:
    """Get the default evaluations for the TrustworthyRAG.

    Note:
        This excludes trustworthiness, which is automatically computed by TrustworthyRAG.
    """
    return [evaluation for evaluation in get_default_evals() if evaluation.name in DEFAULT_EVAL_METRICS]


def get_default_trustworthyrag_config() -> dict[str, Any]:
    """Get the default configuration for the TrustworthyRAG."""
    return {
        "options": {
            "log": ["explanation"],
        },
    }


def update_scores_based_on_thresholds(
    scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds
) -> None:
    """Adds an `is_bad` flag to the score dictionaries based on the thresholds."""
    for eval_name, score_dict in scores.items():
        score_dict.setdefault("is_bad", False)
        if (score := score_dict["score"]) is not None:
            score_dict["is_bad"] = score < thresholds.get_threshold(eval_name)


def is_bad_response(
    scores: TrustworthyRAGScore | ThresholdedTrustworthyRAGScore,
    thresholds: BadResponseThresholds,
) -> bool:
    """
    Check if the response is bad based on the scores computed by TrustworthyRAG and the config containing thresholds.
    """
    for eval_metric, score_dict in scores.items():
        score = score_dict["score"]
        if score is None:
            continue
        if score < thresholds.get_threshold(eval_metric):
            return True
    return False
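To make the two helpers above concrete, here is a minimal usage sketch (not part of this PR). The score values are invented for illustration; the imports only use modules that appear in this diff.

```python
from typing import cast

from cleanlab_codex.internal.validator import is_bad_response, update_scores_based_on_thresholds
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
from cleanlab_codex.validator import BadResponseThresholds

thresholds = BadResponseThresholds()  # defaults: trustworthiness=0.5, response_helpfulness=0.5

# Hypothetical raw scores shaped like TrustworthyRAG output (values are illustrative).
scores = cast(
    ThresholdedTrustworthyRAGScore,
    {
        "trustworthiness": {"score": 0.92},
        "response_helpfulness": {"score": 0.35},
    },
)

# Annotate each metric in place with an `is_bad` flag based on its threshold.
update_scores_based_on_thresholds(scores, thresholds=thresholds)
assert scores["response_helpfulness"]["is_bad"]      # 0.35 < 0.5
assert not scores["trustworthiness"]["is_bad"]       # 0.92 >= 0.5

# A response is flagged as bad if any metric scores below its threshold.
assert is_bad_response(scores, thresholds)
```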
@@ -0,0 +1,35 @@
from cleanlab_tlm.utils.rag import EvalMetric


class ThresholdedEvalMetric(EvalMetric):
    is_bad: bool


ThresholdedEvalMetric.__doc__ = f"""
{EvalMetric.__doc__}

is_bad: bool
    Whether the score is below a certain threshold.
"""


class ThresholdedTrustworthyRAGScore(dict[str, ThresholdedEvalMetric]):
    """Object returned by `Validator.detect` containing evaluation scores from [TrustworthyRAGScore](/tlm/api/python/utils.rag/#class-trustworthyragscore)
    along with a boolean flag, `is_bad`, indicating whether the score is below the threshold.

    Example:
        ```python
        {
            "trustworthiness": {
                "score": 0.92,
                "log": {"explanation": "Did not find a reason to doubt trustworthiness."},
                "is_bad": False
            },
            "response_helpfulness": {
                "score": 0.35,
                "is_bad": True
            },
            ...
        }
        ```
    """
@@ -0,0 +1,207 @@
""" | ||
Leverage Cleanlab's Evals and Codex to detect and remediate bad responses in RAG applications. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, Callable, Optional, cast | ||
|
||
from pydantic import BaseModel, Field, field_validator | ||
|
||
from cleanlab_codex.internal.validator import ( | ||
get_default_evaluations, | ||
get_default_trustworthyrag_config, | ||
) | ||
from cleanlab_codex.internal.validator import is_bad_response as _is_bad_response | ||
from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds | ||
from cleanlab_codex.project import Project | ||
from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore | ||
from cleanlab_codex.utils.errors import MissingDependencyError | ||
|
||
try: | ||
from cleanlab_tlm import TrustworthyRAG | ||
elisno marked this conversation as resolved.
Show resolved
Hide resolved
|
||
except ImportError as e: | ||
raise MissingDependencyError( | ||
import_name=e.name or "cleanlab-tlm", | ||
package_url="https://github.com/cleanlab/cleanlab-tlm", | ||
) from e | ||
|
||
|
||
class BadResponseThresholds(BaseModel): | ||
"""Config for determining if a response is bad. | ||
Each key is an evaluation metric and the value is a threshold such that if the score is below the threshold, the response is bad. | ||
""" | ||
|
||
trustworthiness: float = Field( | ||
description="Threshold for trustworthiness.", | ||
default=0.5, | ||
ge=0.0, | ||
le=1.0, | ||
) | ||
response_helpfulness: float = Field( | ||
description="Threshold for response helpfulness.", | ||
default=0.5, | ||
ge=0.0, | ||
le=1.0, | ||
) | ||
|
||
@property | ||
def default_threshold(self) -> float: | ||
"""The default threshold to use when a specific evaluation metric's threshold is not set. This threshold is set to 0.5.""" | ||
return 0.5 | ||
|
||
def get_threshold(self, eval_name: str) -> float: | ||
"""Get threshold for an eval if it exists. | ||
|
||
For fields defined in the model, returns their value (which may be the field's default). | ||
For custom evals not defined in the model, returns the default threshold value (see `default_threshold`). | ||
""" | ||
|
||
# For fields defined in the model, use their value (which may be the field's default) | ||
if eval_name in self.model_fields: | ||
return cast(float, getattr(self, eval_name)) | ||
|
||
# For custom evals, use the default threshold | ||
return getattr(self, eval_name, self.default_threshold) | ||
|
||
@field_validator("*") | ||
@classmethod | ||
def validate_threshold(cls, v: Any) -> float: | ||
"""Validate that all fields (including dynamic ones) are floats between 0 and 1.""" | ||
if not isinstance(v, (int, float)): | ||
error_msg = f"Threshold must be a number, got {type(v)}" | ||
raise TypeError(error_msg) | ||
if not 0 <= float(v) <= 1: | ||
error_msg = f"Threshold must be between 0 and 1, got {v}" | ||
raise ValueError(error_msg) | ||
return float(v) | ||
|
||
model_config = { | ||
"extra": "allow" # Allow additional fields for custom eval thresholds | ||
} | ||
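A brief sketch of how `extra="allow"` and `get_threshold` interact (illustration only, not part of the diff; `context_sufficiency` is just an arbitrary custom eval name here):

```python
from cleanlab_codex.validator import BadResponseThresholds

# Custom eval thresholds can be passed as extra fields thanks to extra="allow".
thresholds = BadResponseThresholds(trustworthiness=0.7, context_sufficiency=0.4)

assert thresholds.get_threshold("trustworthiness") == 0.7       # declared field
assert thresholds.get_threshold("context_sufficiency") == 0.4   # extra field
assert thresholds.get_threshold("response_helpfulness") == 0.5  # field default
assert thresholds.get_threshold("some_other_eval") == 0.5       # falls back to default_threshold
```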
class Validator:
    def __init__(
        self,
        codex_access_key: str,
        tlm_api_key: Optional[str] = None,
        trustworthy_rag_config: Optional[dict[str, Any]] = None,
        bad_response_thresholds: Optional[dict[str, float]] = None,
    ):
        """Evaluates the quality of responses generated in RAG applications and remediates them if needed.

        This object combines Cleanlab's various Evals with thresholding to detect bad responses and remediate them with Codex.

        Args:
            codex_access_key (str): The [access key](/codex/web_tutorials/create_project/#access-keys) for a Codex project.
            tlm_api_key (Optional[str]): The API key for [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag).
            trustworthy_rag_config (Optional[dict[str, Any]]): Optional initialization arguments for [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag), which is used to detect response issues.
            bad_response_thresholds (Optional[dict[str, float]]): Detection score thresholds used to flag whether or not a response is considered bad. Each key in this dict corresponds to an Eval from TrustworthyRAG, and the value indicates a threshold below which scores from this Eval are considered detected issues. A response is flagged as bad if any issues are detected for it.
        """
        trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
        if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
            error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
            raise ValueError(error_msg)
        if tlm_api_key is not None:
            trustworthy_rag_config["api_key"] = tlm_api_key

        self._project: Project = Project.from_access_key(access_key=codex_access_key)

        trustworthy_rag_config.setdefault("evals", get_default_evaluations())
        self._tlm_rag = TrustworthyRAG(**trustworthy_rag_config)

        # Validate that all the necessary thresholds are present in the TrustworthyRAG.
        _evals = [e.name for e in self._tlm_rag.get_evals()] + ["trustworthiness"]

        self._bad_response_thresholds = BadResponseThresholds.model_validate(bad_response_thresholds or {})

        _threshold_keys = self._bad_response_thresholds.model_dump().keys()

        # Check if there are any thresholds without corresponding evals (this is an error)
        _extra_thresholds = set(_threshold_keys) - set(_evals)
        if _extra_thresholds:
            error_msg = f"Found thresholds for non-existent evaluation metrics: {_extra_thresholds}"
            raise ValueError(error_msg)

    def validate(
        self,
        query: str,
        context: str,
        response: str,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
    ) -> dict[str, Any]:
        """Evaluate whether the AI-generated response is bad, and if so, request an alternate expert response.

        Args:
            query (str): The user query that was used to generate the response.
            context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
            response (str): A response from your LLM/RAG system.

        Returns:
            dict[str, Any]: A dictionary containing:
                - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise.
                - 'expert_answer': Alternate SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
                - Additional keys: Various keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary, with raw scores from [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) for each evaluation metric, plus an `is_bad` flag indicating whether that score is below the corresponding threshold.
        """
        scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
        expert_answer = None
        if is_bad_response:
            expert_answer = self.remediate(query)

        return {
            "is_bad_response": is_bad_response,
            "expert_answer": expert_answer,
            **scores,
        }

Review thread on `detect` (from the PR discussion; elided spans left as […]):

- If a user just wants to detect bad responses, should they use TrustworthyRAG or Validator.detect? How is a user supposed to understand how these two relate to each other?
- The idea (which the docstring should be updated to reflect) was that Validator is just a version of TrustworthyRAG with different default evals & predetermined thresholds. The practical impact of those thresholds is that they determine when we look things up in Codex (what is logged in the Project for an SME to answer, what gets answered by Codex instead of the RAG app). But that impact is primarily realized in […], so we could make […].
- That solution sounds fine to me: making it private, and updating the instructions to indicate that detect -> TrustworthyRAG, detect + remediate -> Validator.
- sgtm. @elisno can you also add an optional flag to […]? This flag could be something like […].
- On second thought, we should keep the […]. No need for another optional flag in […].
- also include screenshot of tutorial where you show that it's clearly explained when to use validate() vs. detect()
- I pushed more docstring changes to clearly distinguish these, so review those
- https://github.com/cleanlab/cleanlab-studio-docs/pull/868#issuecomment-2756947611
- that screenshot does not explain the main reason to use detect(), which is to test/tune detection configurations like the evaluation score thresholds and TrustworthyRAG settings

    def detect(
        self,
        query: str,
        context: str,
        response: str,
        prompt: Optional[str] = None,
        form_prompt: Optional[Callable[[str, str], str]] = None,
    ) -> tuple[ThresholdedTrustworthyRAGScore, bool]:
        """Evaluate the response quality using TrustworthyRAG and determine if it is a bad response.

        Args:
            query (str): The user query that was used to generate the response.
            context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
            response (str): A response from your LLM/RAG system.

        Returns:
            tuple[ThresholdedTrustworthyRAGScore, bool]: A tuple containing:
                - ThresholdedTrustworthyRAGScore: Quality scores for different evaluation metrics like trustworthiness
                  and response helpfulness. Each metric has a score between 0-1, along with a boolean flag, `is_bad`, indicating whether the score is below a given threshold.
                - bool: True if the response is determined to be bad based on the evaluation scores
                  and configured thresholds, False otherwise.
        """
        scores = cast(
            ThresholdedTrustworthyRAGScore,
            self._tlm_rag.score(
                response=response,
                query=query,
                context=context,
                prompt=prompt,
                form_prompt=form_prompt,
            ),
        )

        _update_scores_based_on_thresholds(scores, thresholds=self._bad_response_thresholds)

        is_bad_response = _is_bad_response(scores, self._bad_response_thresholds)
        return scores, is_bad_response

    def remediate(self, query: str) -> str | None:
        """Request an SME-provided answer for this query, if one is available in Codex.

        Args:
            query (str): The user's original query to get an SME-provided answer for.

        Returns:
            str | None: The SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
        """
        codex_answer, _ = self._project.query(question=query)
        return codex_answer
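An end-to-end usage sketch of the class above (not part of the diff). The access key, query, context, and response strings are placeholders, and a TLM API key is assumed to be available via `tlm_api_key` or the environment:

```python
from cleanlab_codex.validator import Validator

# Placeholder credentials and RAG inputs for illustration only.
validator = Validator(
    codex_access_key="<your-codex-access-key>",
    bad_response_thresholds={"trustworthiness": 0.6, "response_helpfulness": 0.5},
)

result = validator.validate(
    query="How do I reset my password?",
    context="<retrieved context from your knowledge base>",
    response="<draft answer from your LLM/RAG system>",
)
if result["is_bad_response"] and result["expert_answer"] is not None:
    final_answer = result["expert_answer"]  # serve the SME-provided answer instead
else:
    final_answer = "<draft answer from your LLM/RAG system>"

# To inspect or tune detection behavior without querying Codex, call detect() directly:
scores, is_bad = validator.detect(
    query="How do I reset my password?",
    context="<retrieved context from your knowledge base>",
    response="<draft answer from your LLM/RAG system>",
)
```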
@@ -0,0 +1,53 @@
from typing import cast

import pytest
from cleanlab_tlm.utils.rag import TrustworthyRAGScore

from cleanlab_codex.internal.validator import get_default_evaluations, is_bad_response
from cleanlab_codex.validator import BadResponseThresholds


def make_scores(trustworthiness: float, response_helpfulness: float) -> TrustworthyRAGScore:
    scores = {
        "trustworthiness": {
            "score": trustworthiness,
        },
        "response_helpfulness": {
            "score": response_helpfulness,
        },
    }
    return cast(TrustworthyRAGScore, scores)


def make_is_bad_response_config(trustworthiness: float, response_helpfulness: float) -> BadResponseThresholds:
    return BadResponseThresholds(
        trustworthiness=trustworthiness,
        response_helpfulness=response_helpfulness,
    )


def test_get_default_evaluations() -> None:
    assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}


class TestIsBadResponse:
    @pytest.fixture
    def scores(self) -> TrustworthyRAGScore:
        return make_scores(0.92, 0.75)

    @pytest.fixture
    def custom_is_bad_response_config(self) -> BadResponseThresholds:
        return make_is_bad_response_config(0.6, 0.7)

    def test_thresholds(self, scores: TrustworthyRAGScore) -> None:
        # High trustworthiness_threshold
        is_bad_response_config = make_is_bad_response_config(0.921, 0.5)
        assert is_bad_response(scores, is_bad_response_config)

        # High response_helpfulness_threshold
        is_bad_response_config = make_is_bad_response_config(0.5, 0.751)
        assert is_bad_response(scores, is_bad_response_config)

    def test_scores(self, custom_is_bad_response_config: BadResponseThresholds) -> None:
        scores = make_scores(0.59, 0.7)
        assert is_bad_response(scores, custom_is_bad_response_config)
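One boundary case the tests above do not cover: flagging uses a strict `score < threshold` comparison, so a score exactly at its threshold is not flagged. A hypothetical extra test (not part of this PR) could pin that behavior down:

```python
def test_score_at_threshold_is_not_flagged() -> None:
    # Scores exactly equal to their thresholds should not be flagged (strict `<` comparison).
    scores = make_scores(0.5, 0.5)
    thresholds = make_is_bad_response_config(0.5, 0.5)
    assert not is_bad_response(scores, thresholds)
```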