diff --git a/CHANGELOG.md b/CHANGELOG.md
index 17e7ab3..7e26398 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.0.5] - 2025-03-27
+
+- Add `Validator` API
+- Deprecate `response_validation.py` module.
+
 ## [1.0.4] - 2025-03-14
 
 - Pass analytics metadata in headers for all Codex API requests.
@@ -29,7 +34,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Initial release of the `cleanlab-codex` client library.
 
-[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.4...HEAD
+[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.5...HEAD
+[1.0.5]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.4...v1.0.5
 [1.0.4]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.3...v1.0.4
 [1.0.3]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.2...v1.0.3
 [1.0.2]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.1...v1.0.2
diff --git a/pyproject.toml b/pyproject.toml
index 8fc930e..ecba729 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
 dependencies = [
+  "cleanlab-tlm~=1.0.12",
   "codex-sdk==0.1.0a12",
   "pydantic>=2.0.0, <3",
 ]
diff --git a/src/cleanlab_codex/__init__.py b/src/cleanlab_codex/__init__.py
index d1b8ef6..572a626 100644
--- a/src/cleanlab_codex/__init__.py
+++ b/src/cleanlab_codex/__init__.py
@@ -2,5 +2,6 @@
 from cleanlab_codex.client import Client
 from cleanlab_codex.codex_tool import CodexTool
 from cleanlab_codex.project import Project
+from cleanlab_codex.validator import Validator
 
-__all__ = ["Client", "CodexTool", "Project"]
+__all__ = ["Client", "CodexTool", "Project", "Validator"]
diff --git a/src/cleanlab_codex/internal/validator.py b/src/cleanlab_codex/internal/validator.py
new file mode 100644
index 0000000..0914c02
--- /dev/null
+++ b/src/cleanlab_codex/internal/validator.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Optional, Sequence, cast
+
+from cleanlab_tlm.utils.rag import Eval, TrustworthyRAGScore, get_default_evals
+
+from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
+
+if TYPE_CHECKING:
+    from cleanlab_codex.validator import BadResponseThresholds
+
+
+"""Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
+DEFAULT_EVAL_METRICS = ["response_helpfulness"]
+
+
+def get_default_evaluations() -> list[Eval]:
+    """Get the default evaluations for the TrustworthyRAG.
+
+    Note:
+        This excludes trustworthiness, which is automatically computed by TrustworthyRAG.
+ """ + return [evaluation for evaluation in get_default_evals() if evaluation.name in DEFAULT_EVAL_METRICS] + + +def get_default_trustworthyrag_config() -> dict[str, Any]: + """Get the default configuration for the TrustworthyRAG.""" + return { + "options": { + "log": ["explanation"], + }, + } + + +def update_scores_based_on_thresholds( + scores: TrustworthyRAGScore | Sequence[TrustworthyRAGScore], thresholds: BadResponseThresholds +) -> ThresholdedTrustworthyRAGScore: + """Adds a `is_bad` flag to the scores dictionaries based on the thresholds.""" + + # Helper function to check if a score is bad + def is_bad(score: Optional[float], threshold: float) -> bool: + return score is not None and score < threshold + + if isinstance(scores, Sequence): + raise NotImplementedError("Batching is not supported yet.") + + thresholded_scores = {} + for eval_name, score_dict in scores.items(): + thresholded_scores[eval_name] = { + **score_dict, + "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)), + } + return cast(ThresholdedTrustworthyRAGScore, thresholded_scores) diff --git a/src/cleanlab_codex/response_validation.py b/src/cleanlab_codex/response_validation.py index e3bf78a..dca2877 100644 --- a/src/cleanlab_codex/response_validation.py +++ b/src/cleanlab_codex/response_validation.py @@ -1,5 +1,7 @@ """ -Validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives. +This module is now superseded by this [Validator API](/codex/api/validator/). + +Deprecated validation functions for evaluating LLM responses and determining if they should be replaced with Codex-generated alternatives. """ from __future__ import annotations diff --git a/src/cleanlab_codex/types/response_validation.py b/src/cleanlab_codex/types/response_validation.py index 1479e2d..e0e5b26 100644 --- a/src/cleanlab_codex/types/response_validation.py +++ b/src/cleanlab_codex/types/response_validation.py @@ -1,4 +1,7 @@ -"""Types for response validation.""" +""" +This module is now superseded by this [Validator API](/codex/api/validator/). + +Deprecated types for response validation.""" from abc import ABC, abstractmethod from collections import OrderedDict diff --git a/src/cleanlab_codex/types/validator.py b/src/cleanlab_codex/types/validator.py new file mode 100644 index 0000000..930273f --- /dev/null +++ b/src/cleanlab_codex/types/validator.py @@ -0,0 +1,35 @@ +from cleanlab_tlm.utils.rag import EvalMetric + + +class ThresholdedEvalMetric(EvalMetric): + is_bad: bool + + +ThresholdedEvalMetric.__doc__ = f""" +{EvalMetric.__doc__} + +is_bad: bool + Whether the score is a certain threshold. +""" + + +class ThresholdedTrustworthyRAGScore(dict[str, ThresholdedEvalMetric]): + """Object returned by `Validator.detect` containing evaluation scores from [TrustworthyRAGScore](/tlm/api/python/utils.rag/#class-trustworthyragscore) + along with a boolean flag, `is_bad`, indicating whether the score is below the threshold. + + Example: + ```python + { + "trustworthiness": { + "score": 0.92, + "log": {"explanation": "Did not find a reason to doubt trustworthiness."}, + "is_bad": False + }, + "response_helpfulness": { + "score": 0.35, + "is_bad": True + }, + ... + } + ``` + """ diff --git a/src/cleanlab_codex/validator.py b/src/cleanlab_codex/validator.py new file mode 100644 index 0000000..81365b3 --- /dev/null +++ b/src/cleanlab_codex/validator.py @@ -0,0 +1,241 @@ +""" +Detect and remediate bad responses in RAG applications, by integrating Codex as-a-Backup. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Optional, cast + +from cleanlab_tlm import TrustworthyRAG +from pydantic import BaseModel, Field, field_validator + +from cleanlab_codex.internal.validator import ( + get_default_evaluations, + get_default_trustworthyrag_config, +) +from cleanlab_codex.internal.validator import update_scores_based_on_thresholds as _update_scores_based_on_thresholds +from cleanlab_codex.project import Project + +if TYPE_CHECKING: + from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore + + +class Validator: + def __init__( + self, + codex_access_key: str, + tlm_api_key: Optional[str] = None, + trustworthy_rag_config: Optional[dict[str, Any]] = None, + bad_response_thresholds: Optional[dict[str, float]] = None, + ): + """Real-time detection and remediation of bad responses in RAG applications, powered by Cleanlab's TrustworthyRAG and Codex. + + This object combines Cleanlab's TrustworthyRAG evaluation scores with configurable thresholds to detect potentially bad responses + in your RAG application. When a bad response is detected, this Validator automatically attempts to remediate by retrieving an expert-provided + answer from the Codex Project you've integrated with your RAG app. If no expert answer is available, + the corresponding query is logged in the Codex Project for SMEs to answer. + + For production, use the `validate()` method which provides a complete validation workflow including both detection and remediation. + A `detect()` method is separately available for you to test/tune detection configurations like score thresholds and TrustworthyRAG settings + without triggering any Codex lookups that otherwise could affect the state of the corresponding Codex Project. + + Args: + codex_access_key (str): The [access key](/codex/web_tutorials/create_project/#access-keys) for a Codex project. Used to retrieve expert-provided answers + when bad responses are detected, or otherwise log the corresponding queries for SMEs to answer. + + tlm_api_key (str, optional): API key for accessing [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag). If not provided, this must be specified + in `trustworthy_rag_config`. + + trustworthy_rag_config (dict[str, Any], optional): Optional initialization arguments for [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag), + which is used to detect response issues. If not provided, a default configuration will be used. + By default, this Validator uses the same default configurations as [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag), except: + - Explanations are returned in logs for better debugging + - Only the `response_helpfulness` eval is run + + bad_response_thresholds (dict[str, float], optional): Detection score thresholds used to flag whether + a response is bad or not. Each key corresponds to an Eval from [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag), + and the value indicates a threshold (between 0 and 1) below which Eval scores are treated as detected issues. A response + is flagged as bad if any issues are detected. If not provided, default thresholds will be used. See + [`BadResponseThresholds`](/codex/api/python/validator/#class-badresponsethresholds) for more details. + + Raises: + ValueError: If both tlm_api_key and api_key in trustworthy_rag_config are provided. + ValueError: If bad_response_thresholds contains thresholds for non-existent evaluation metrics. 
+            TypeError: If any threshold value is not a number.
+            ValueError: If any threshold value is not between 0 and 1.
+        """
+        trustworthy_rag_config = trustworthy_rag_config or get_default_trustworthyrag_config()
+        if tlm_api_key is not None and "api_key" in trustworthy_rag_config:
+            error_msg = "Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"
+            raise ValueError(error_msg)
+        if tlm_api_key is not None:
+            trustworthy_rag_config["api_key"] = tlm_api_key
+
+        self._project: Project = Project.from_access_key(access_key=codex_access_key)
+
+        trustworthy_rag_config.setdefault("evals", get_default_evaluations())
+        self._tlm_rag = TrustworthyRAG(**trustworthy_rag_config)
+
+        # Validate that all the necessary thresholds are present in the TrustworthyRAG.
+        _evals = [e.name for e in self._tlm_rag.get_evals()] + ["trustworthiness"]
+
+        self._bad_response_thresholds = BadResponseThresholds.model_validate(bad_response_thresholds or {})
+
+        _threshold_keys = self._bad_response_thresholds.model_dump().keys()
+
+        # Check if there are any thresholds without corresponding evals (this is an error)
+        _extra_thresholds = set(_threshold_keys) - set(_evals)
+        if _extra_thresholds:
+            error_msg = f"Found thresholds for non-existent evaluation metrics: {_extra_thresholds}"
+            raise ValueError(error_msg)
+
+    def validate(
+        self,
+        query: str,
+        context: str,
+        response: str,
+        prompt: Optional[str] = None,
+        form_prompt: Optional[Callable[[str, str], str]] = None,
+    ) -> dict[str, Any]:
+        """Evaluate whether the AI-generated response is bad, and if so, request an alternate expert answer.
+        If no expert answer is available, this query is still logged for SMEs to answer.
+
+        Args:
+            query (str): The user query that was used to generate the response.
+            context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
+            response (str): A response from your LLM/RAG system.
+            prompt (str, optional): Optional prompt representing the actual inputs (combining query, context, and system instructions into one string) to the LLM that generated the response.
+            form_prompt (Callable[[str, str], str], optional): Optional function to format the prompt based on query and context. Cannot be provided together with prompt; provide one or the other. This function should take query and context as parameters and return a formatted prompt string. If not provided, a default prompt formatter will be used. To include a system prompt or any other special instructions for your LLM, incorporate them directly in your custom form_prompt() function definition.
+
+        Returns:
+            dict[str, Any]: A dictionary containing:
+                - 'expert_answer': Alternate SME-provided answer from Codex if the response was flagged as bad and an answer was found in the Codex Project, or None otherwise.
+                - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
+                - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
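+
+        Example:
+            An illustrative (hypothetical) return value, assuming the default evals and thresholds:
+
+            ```python
+            {
+                "expert_answer": None,
+                "is_bad_response": False,
+                "trustworthiness": {"score": 0.92, "is_bad": False},
+                "response_helpfulness": {"score": 0.81, "is_bad": False},
+            }
+            ```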
+ """ + scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt) + expert_answer = None + if is_bad_response: + expert_answer = self._remediate(query) + + return { + "expert_answer": expert_answer, + "is_bad_response": is_bad_response, + **scores, + } + + def detect( + self, + query: str, + context: str, + response: str, + prompt: Optional[str] = None, + form_prompt: Optional[Callable[[str, str], str]] = None, + ) -> tuple[ThresholdedTrustworthyRAGScore, bool]: + """Score response quality using TrustworthyRAG and flag bad responses based on configured thresholds. + + Note: + Use this method instead of `validate()` to test/tune detection configurations like score thresholds and TrustworthyRAG settings. + This `detect()` method will not affect your Codex Project, whereas `validate()` will log queries whose response was detected as bad into the Codex Project and is thus only suitable for production, not testing. + Both this method and `validate()` rely on this same detection logic, so you can use this method to first optimize detections and then switch to using `validate()`. + + Args: + query (str): The user query that was used to generate the response. + context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response. + response (str): A reponse from your LLM/RAG system. + prompt (str, optional): Optional prompt representing the actual inputs (combining query, context, and system instructions into one string) to the LLM that generated the response. + form_prompt (Callable[[str, str], str], optional): Optional function to format the prompt based on query and context. Cannot be provided together with prompt, provide one or the other. This function should take query and context as parameters and return a formatted prompt string. If not provided, a default prompt formatter will be used. To include a system prompt or any other special instructions for your LLM, incorporate them directly in your custom form_prompt() function definition. + + Returns: + tuple[ThresholdedTrustworthyRAGScore, bool]: A tuple containing: + - ThresholdedTrustworthyRAGScore: Quality scores for different evaluation metrics like trustworthiness + and response helpfulness. Each metric has a score between 0-1. It also has a boolean flag, `is_bad` indicating whether the score is below a given threshold. + - bool: True if the response is determined to be bad based on the evaluation scores + and configured thresholds, False otherwise. + """ + scores = self._tlm_rag.score( + response=response, + query=query, + context=context, + prompt=prompt, + form_prompt=form_prompt, + ) + + thresholded_scores = _update_scores_based_on_thresholds( + scores=scores, + thresholds=self._bad_response_thresholds, + ) + + is_bad_response = any(score_dict["is_bad"] for score_dict in thresholded_scores.values()) + return thresholded_scores, is_bad_response + + def _remediate(self, query: str) -> str | None: + """Request a SME-provided answer for this query, if one is available in Codex. + + Args: + query (str): The user's original query to get SME-provided answer for. + + Returns: + str | None: The SME-provided answer from Codex, or None if no answer could be found in the Codex Project. + """ + codex_answer, _ = self._project.query(question=query) + return codex_answer + + +class BadResponseThresholds(BaseModel): + """Config for determining if a response is bad. 
+    Each key is an evaluation metric and the value is a threshold such that a response is considered bad whenever the corresponding evaluation score falls below the threshold.
+
+    Default Thresholds:
+        - trustworthiness: 0.5
+        - response_helpfulness: 0.5
+        - Any custom eval: 0.5 (if not explicitly specified in bad_response_thresholds)
+    """
+
+    trustworthiness: float = Field(
+        description="Threshold for trustworthiness.",
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+    )
+    response_helpfulness: float = Field(
+        description="Threshold for response helpfulness.",
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+    )
+
+    @property
+    def default_threshold(self) -> float:
+        """The default threshold to use when an evaluation metric's threshold is not specified. This threshold is set to 0.5."""
+        return 0.5
+
+    def get_threshold(self, eval_name: str) -> float:
+        """Get threshold for an eval, if it exists.
+
+        For fields defined in the model, returns their value (which may be the field's default).
+        For custom evals not defined in the model, returns the default threshold value (see `default_threshold`).
+        """
+
+        # For fields defined in the model, use their value (which may be the field's default)
+        if eval_name in self.model_fields:
+            return cast(float, getattr(self, eval_name))
+
+        # For custom evals, use the default threshold
+        return getattr(self, eval_name, self.default_threshold)
+
+    @field_validator("*")
+    @classmethod
+    def validate_threshold(cls, v: Any) -> float:
+        """Validate that all fields (including dynamic ones) are floats between 0 and 1."""
+        if not isinstance(v, (int, float)):
+            error_msg = f"Threshold must be a number, got {type(v)}"
+            raise TypeError(error_msg)
+        if not 0 <= float(v) <= 1:
+            error_msg = f"Threshold must be between 0 and 1, got {v}"
+            raise ValueError(error_msg)
+        return float(v)
+
+    model_config = {
+        "extra": "allow"  # Allow additional fields for custom eval thresholds
+    }
diff --git a/tests/internal/test_validator.py b/tests/internal/test_validator.py
new file mode 100644
index 0000000..b2d059e
--- /dev/null
+++ b/tests/internal/test_validator.py
@@ -0,0 +1,29 @@
+from typing import cast
+
+from cleanlab_tlm.utils.rag import TrustworthyRAGScore
+
+from cleanlab_codex.internal.validator import get_default_evaluations
+from cleanlab_codex.validator import BadResponseThresholds
+
+
+def make_scores(trustworthiness: float, response_helpfulness: float) -> TrustworthyRAGScore:
+    scores = {
+        "trustworthiness": {
+            "score": trustworthiness,
+        },
+        "response_helpfulness": {
+            "score": response_helpfulness,
+        },
+    }
+    return cast(TrustworthyRAGScore, scores)
+
+
+def make_is_bad_response_config(trustworthiness: float, response_helpfulness: float) -> BadResponseThresholds:
+    return BadResponseThresholds(
+        trustworthiness=trustworthiness,
+        response_helpfulness=response_helpfulness,
+    )
+
+
+def test_get_default_evaluations() -> None:
+    assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}
diff --git a/tests/test_validator.py b/tests/test_validator.py
new file mode 100644
index 0000000..3193dbf
--- /dev/null
+++ b/tests/test_validator.py
@@ -0,0 +1,140 @@
+from typing import Generator
+from unittest.mock import Mock, patch
+
+import pytest
+from pydantic import ValidationError
+
+from cleanlab_codex.validator import BadResponseThresholds, Validator
+
+
+class TestBadResponseThresholds:
+    def test_get_threshold(self) -> None:
+        thresholds = BadResponseThresholds(
+            trustworthiness=0.5,
+            response_helpfulness=0.5,
+        )
+        assert thresholds.get_threshold("trustworthiness") == 0.5
+        assert thresholds.get_threshold("response_helpfulness") == 0.5
+
+    def test_default_threshold(self) -> None:
+        thresholds = BadResponseThresholds()
+        assert thresholds.get_threshold("trustworthiness") == 0.5
+        assert thresholds.get_threshold("response_helpfulness") == 0.5
+
+    def test_unspecified_threshold(self) -> None:
+        thresholds = BadResponseThresholds()
+        assert thresholds.get_threshold("unspecified_threshold") == 0.5
+
+    def test_threshold_value(self) -> None:
+        thresholds = BadResponseThresholds(valid_threshold=0.3)  # type: ignore
+        assert thresholds.get_threshold("valid_threshold") == 0.3
+        assert thresholds.valid_threshold == 0.3  # type: ignore
+
+    def test_invalid_threshold_value(self) -> None:
+        with pytest.raises(ValidationError):
+            BadResponseThresholds(trustworthiness=1.1)
+
+        with pytest.raises(ValidationError):
+            BadResponseThresholds(response_helpfulness=-0.1)
+
+    def test_invalid_threshold_type(self) -> None:
+        with pytest.raises(ValidationError):
+            BadResponseThresholds(trustworthiness="not a number")  # type: ignore
+
+
+@pytest.fixture
+def mock_project() -> Generator[Mock, None, None]:
+    with patch("cleanlab_codex.validator.Project") as mock:
+        mock.from_access_key.return_value = Mock()
+        yield mock
+
+
+@pytest.fixture
+def mock_trustworthy_rag() -> Generator[Mock, None, None]:
+    mock = Mock()
+    mock.score.return_value = {
+        "trustworthiness": {"score": 0.8, "is_bad": False},
+        "response_helpfulness": {"score": 0.7, "is_bad": False},
+    }
+    eval_mock = Mock()
+    eval_mock.name = "response_helpfulness"
+    mock.get_evals.return_value = [eval_mock]
+    with patch("cleanlab_codex.validator.TrustworthyRAG") as mock_class:
+        mock_class.return_value = mock
+        yield mock_class
+
+
+class TestValidator:
+    def test_init(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:
+        Validator(codex_access_key="test")
+
+        # Verify Project was initialized with access key
+        mock_project.from_access_key.assert_called_once_with(access_key="test")
+
+        # Verify TrustworthyRAG was initialized with default config
+        mock_trustworthy_rag.assert_called_once()
+
+    def test_init_with_tlm_api_key(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:  # noqa: ARG002
+        Validator(codex_access_key="test", tlm_api_key="tlm-key")
+
+        # Verify TrustworthyRAG was initialized with API key
+        config = mock_trustworthy_rag.call_args[1]
+        assert config["api_key"] == "tlm-key"
+
+    def test_init_with_config_conflict(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:  # noqa: ARG002
+        with pytest.raises(ValueError, match="Cannot specify both tlm_api_key and api_key in trustworthy_rag_config"):
+            Validator(codex_access_key="test", tlm_api_key="tlm-key", trustworthy_rag_config={"api_key": "config-key"})
+
+    def test_validate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:  # noqa: ARG002
+        validator = Validator(codex_access_key="test")
+
+        result = validator.validate(query="test query", context="test context", response="test response")
+
+        # Verify TrustworthyRAG.score was called
+        mock_trustworthy_rag.return_value.score.assert_called_once_with(
+            response="test response", query="test query", context="test context", prompt=None, form_prompt=None
+        )
+
+        # Verify expected result structure
+        assert result["is_bad_response"] is False
+        assert result["expert_answer"] is None
+
+        eval_metrics = ["trustworthiness", "response_helpfulness"]
+        for metric in eval_metrics:
+            assert metric in result
+            assert "score" in result[metric]
assert "is_bad" in result[metric] + + def test_validate_expert_answer(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None: # noqa: ARG002 + # Setup mock project query response + mock_project.from_access_key.return_value.query.return_value = ("expert answer", None) + + # Basically any response will be flagged as untrustworthy + validator = Validator(codex_access_key="test", bad_response_thresholds={"trustworthiness": 1.0}) + result = validator.validate(query="test query", context="test context", response="test response") + assert result["expert_answer"] == "expert answer" + + mock_project.from_access_key.return_value.query.return_value = (None, None) + result = validator.validate(query="test query", context="test context", response="test response") + assert result["expert_answer"] is None + + def test_detect(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None: # noqa: ARG002 + validator = Validator(codex_access_key="test") + + scores, is_bad = validator.detect(query="test query", context="test context", response="test response") + + # Verify scores match mock return value + assert scores["trustworthiness"]["score"] == 0.8 + assert scores["response_helpfulness"]["score"] == 0.7 + assert not is_bad # Since mock scores are above default thresholds + + def test_remediate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None: # noqa: ARG002 + # Setup mock project query response + mock_project.from_access_key.return_value.query.return_value = ("expert answer", None) + + validator = Validator(codex_access_key="test") + result = validator._remediate("test query") # noqa: SLF001 + + # Verify project.query was called + mock_project.from_access_key.return_value.query.assert_called_once_with(question="test query") + assert result == "expert answer"