
Commit ad5ec94

Added functionality for adding metadata using validate api (#72)
1 parent e57fd31 commit ad5ec94

File tree

CHANGELOG.md
src/cleanlab_codex/__about__.py
src/cleanlab_codex/internal/validator.py
src/cleanlab_codex/validator.py
tests/internal/test_validator.py
tests/test_validator.py

6 files changed: +175 additions, -10 deletions

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [1.0.12] - 2025-04-17
+
+- Support adding metadata in `validate()` method in Validator API.
+
 ## [1.0.11] - 2025-04-16
 
 - Update default thresholds for custom evals to 0.0 in `Validator` API.
@@ -59,7 +63,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Initial release of the `cleanlab-codex` client library.
 
-[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.11...HEAD
+[Unreleased]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.12...HEAD
+[1.0.12]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.11...v1.0.12
 [1.0.11]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.10...v1.0.11
 [1.0.10]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.9...v1.0.10
 [1.0.9]: https://github.com/cleanlab/cleanlab-codex/compare/v1.0.8...v1.0.9
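
As the new changelog entry notes, callers can now attach their own metadata when validating a response. A minimal usage sketch of the updated API, assuming a placeholder access key and made-up RAG inputs (none of these values come from this commit):

```python
# Hypothetical usage of the new metadata parameter; the access key and inputs are placeholders.
from cleanlab_codex.validator import Validator

validator = Validator(codex_access_key="YOUR_PROJECT_ACCESS_KEY")

result = validator.validate(
    query="What is the return policy?",
    context="Returns are accepted within 30 days of purchase.",
    response="All sales are final.",
    metadata={"conversation_id": "abc-123"},  # custom fields forwarded to Codex alongside the scores
)

# result["is_bad_response"] reports whether the response was flagged (triggering a Codex lookup),
# and result["expert_answer"] holds the SME answer from Codex, or None if none was found.
```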

src/cleanlab_codex/__about__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 # SPDX-License-Identifier: MIT
-__version__ = "1.0.11"
+__version__ = "1.0.12"

src/cleanlab_codex/internal/validator.py

Lines changed: 57 additions & 0 deletions
@@ -13,6 +13,14 @@
 """Evaluation metrics (excluding trustworthiness) that are used to determine if a response is bad."""
 DEFAULT_EVAL_METRICS = ["response_helpfulness"]
 
+# Simple mappings for is_bad keys
+_SCORE_TO_IS_BAD_KEY = {
+    "trustworthiness": "is_not_trustworthy",
+    "query_ease": "is_not_query_easy",
+    "response_helpfulness": "is_not_response_helpful",
+    "context_sufficiency": "is_not_context_sufficient",
+}
+
 
 def get_default_evaluations() -> list[Eval]:
     """Get the default evaluations for the TrustworthyRAG.
@@ -51,3 +59,52 @@ def is_bad(score: Optional[float], threshold: float) -> bool:
             "is_bad": is_bad(score_dict["score"], thresholds.get_threshold(eval_name)),
         }
     return cast(ThresholdedTrustworthyRAGScore, thresholded_scores)
+
+
+def process_score_metadata(scores: ThresholdedTrustworthyRAGScore, thresholds: BadResponseThresholds) -> dict[str, Any]:
+    """Process scores into metadata format with standardized keys.
+
+    Args:
+        scores: The ThresholdedTrustworthyRAGScore containing evaluation results
+        thresholds: The BadResponseThresholds configuration
+
+    Returns:
+        dict: A dictionary containing evaluation scores and their corresponding metadata
+    """
+    metadata: dict[str, Any] = {}
+
+    # Process scores and add to metadata
+    for metric, score_data in scores.items():
+        metadata[metric] = score_data["score"]
+
+        # Add is_bad flags with standardized naming
+        is_bad_key = _SCORE_TO_IS_BAD_KEY.get(metric, f"is_not_{metric}")
+        metadata[is_bad_key] = score_data["is_bad"]
+
+        # Special case for trustworthiness explanation
+        if metric == "trustworthiness" and "log" in score_data and "explanation" in score_data["log"]:
+            metadata["explanation"] = score_data["log"]["explanation"]
+
+    # Add thresholds to metadata
+    thresholds_dict = thresholds.model_dump()
+    for metric in {k for k in scores if k not in thresholds_dict}:
+        thresholds_dict[metric] = thresholds.get_threshold(metric)
+    metadata["thresholds"] = thresholds_dict
+
+    # TODO: Remove this as the backend can infer this from the is_bad flags
+    metadata["label"] = _get_label(metadata)
+
+    return metadata
+
+
+def _get_label(metadata: dict[str, Any]) -> str:
+    def is_bad(metric: str) -> bool:
+        return bool(metadata.get(_SCORE_TO_IS_BAD_KEY[metric], False))
+
+    if is_bad("context_sufficiency"):
+        return "search_failure"
+    if is_bad("response_helpfulness") or is_bad("query_ease"):
+        return "unhelpful"
+    if is_bad("trustworthiness"):
+        return "hallucination"
+    return "other_issues"
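
For orientation, here is a sketch of the metadata dictionary `process_score_metadata` would produce for a small, made-up score dictionary; the scores and the explanation string are illustrative only, and the exact behavior is pinned down by `tests/internal/test_validator.py` below.

```python
# Illustrative input/output for process_score_metadata (hypothetical values).
scores = {
    "trustworthiness": {"score": 0.42, "is_bad": True, "log": {"explanation": "Response contradicts the context."}},
    "response_helpfulness": {"score": 0.91, "is_bad": False},
}
# With thresholds of 0.7 for both metrics, process_score_metadata(scores, thresholds) would return roughly:
# {
#     "trustworthiness": 0.42,
#     "is_not_trustworthy": True,
#     "response_helpfulness": 0.91,
#     "is_not_response_helpful": False,
#     "explanation": "Response contradicts the context.",
#     "thresholds": {"trustworthiness": 0.7, "response_helpfulness": 0.7, ...},  # model_dump() plus any scored metric not already covered
#     "label": "hallucination",  # context sufficiency and helpfulness checks pass, trustworthiness fails
# }
```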

src/cleanlab_codex/validator.py

Lines changed: 27 additions & 5 deletions
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Callable, Optional, cast
 
 from cleanlab_tlm import TrustworthyRAG
@@ -13,6 +14,9 @@
     get_default_evaluations,
     get_default_trustworthyrag_config,
 )
+from cleanlab_codex.internal.validator import (
+    process_score_metadata as _process_score_metadata,
+)
 from cleanlab_codex.internal.validator import (
     update_scores_based_on_thresholds as _update_scores_based_on_thresholds,
 )
@@ -100,11 +104,14 @@ def __init__(
 
     def validate(
         self,
+        *,
         query: str,
         context: str,
         response: str,
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        log_results: bool = True,
     ) -> dict[str, Any]:
         """Evaluate whether the AI-generated response is bad, and if so, request an alternate expert answer.
         If no expert answer is available, this query is still logged for SMEs to answer.
@@ -122,10 +129,16 @@ def validate(
             - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
-        scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
+        scores, is_bad_response = self.detect(
+            query=query, context=context, response=response, prompt=prompt, form_prompt=form_prompt
+        )
         expert_answer = None
         if is_bad_response:
-            expert_answer = self._remediate(query)
+            final_metadata = deepcopy(metadata) if metadata else {}
+            if log_results:
+                processed_metadata = _process_score_metadata(scores, self._bad_response_thresholds)
+                final_metadata.update(processed_metadata)
+            expert_answer = self._remediate(query=query, metadata=final_metadata)
 
         return {
             "expert_answer": expert_answer,
@@ -135,11 +148,14 @@ def validate(
 
     async def validate_async(
         self,
+        *,
         query: str,
         context: str,
         response: str,
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
+        metadata: Optional[dict[str, Any]] = None,
+        log_results: bool = True,
     ) -> dict[str, Any]:
         """Evaluate whether the AI-generated response is bad, and if so, request an alternate expert answer.
         If no expert answer is available, this query is still logged for SMEs to answer.
@@ -158,9 +174,14 @@ async def validate_async(
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
         scores, is_bad_response = await self.detect_async(query, context, response, prompt, form_prompt)
+        final_metadata = metadata.copy() if metadata else {}
+        if log_results:
+            processed_metadata = _process_score_metadata(scores, self._bad_response_thresholds)
+            final_metadata.update(processed_metadata)
+
         expert_answer = None
         if is_bad_response:
-            expert_answer = self._remediate(query)
+            expert_answer = self._remediate(query=query, metadata=final_metadata)
 
         return {
             "expert_answer": expert_answer,
@@ -170,6 +191,7 @@ async def validate_async(
 
     def detect(
         self,
+        *,
         query: str,
         context: str,
         response: str,
@@ -258,7 +280,7 @@ async def detect_async(
         is_bad_response = any(score_dict["is_bad"] for score_dict in thresholded_scores.values())
         return thresholded_scores, is_bad_response
 
-    def _remediate(self, query: str) -> str | None:
+    def _remediate(self, *, query: str, metadata: dict[str, Any] | None = None) -> str | None:
         """Request a SME-provided answer for this query, if one is available in Codex.
 
         Args:
@@ -267,7 +289,7 @@ def _remediate(self, query: str) -> str | None:
         Returns:
             str | None: The SME-provided answer from Codex, or None if no answer could be found in the Codex Project.
         """
-        codex_answer, _ = self._project.query(question=query)
+        codex_answer, _ = self._project.query(question=query, metadata=metadata)
         return codex_answer
 
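
Two behaviors are worth noting from the hunks above: the caller's metadata dict is copied before the processed score metadata is merged in (so the original is not mutated), and the merged metadata is forwarded to Codex only when the response is flagged as bad. Passing `log_results=False` keeps the score-derived fields out of what gets logged. A hedged sketch, reusing the hypothetical `validator` instance from the earlier example:

```python
# Sketch: attach only caller-supplied metadata, suppressing the processed score fields.
# All values are placeholders, not taken from this commit.
result = validator.validate(
    query="What is the return policy?",
    context="Returns are accepted within 30 days of purchase.",
    response="All sales are final.",
    metadata={"source": "faq-bot"},
    log_results=False,  # final_metadata stays {"source": "faq-bot"}; processed scores are not merged in
)
```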

tests/internal/test_validator.py

Lines changed: 82 additions & 1 deletion
@@ -2,7 +2,12 @@
 
 from cleanlab_tlm.utils.rag import TrustworthyRAGScore
 
-from cleanlab_codex.internal.validator import get_default_evaluations
+from cleanlab_codex.internal.validator import (
+    get_default_evaluations,
+    process_score_metadata,
+    update_scores_based_on_thresholds,
+)
+from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
 from cleanlab_codex.validator import BadResponseThresholds
 
 
@@ -27,3 +32,79 @@ def make_is_bad_response_config(trustworthiness: float, response_helpfulness: fl
 
 def test_get_default_evaluations() -> None:
     assert {evaluation.name for evaluation in get_default_evaluations()} == {"response_helpfulness"}
+
+
+def test_process_score_metadata() -> None:
+    # Create test scores with various metrics
+    thresholded_scores = {
+        "trustworthiness": {"score": 0.8, "is_bad": False, "log": {"explanation": "Test explanation"}},
+        "response_helpfulness": {"score": 0.6, "is_bad": True},
+        "query_ease": {"score": 0.9, "is_bad": False},
+    }
+
+    thresholds = BadResponseThresholds(trustworthiness=0.7, response_helpfulness=0.7)
+
+    metadata = process_score_metadata(cast(ThresholdedTrustworthyRAGScore, thresholded_scores), thresholds)
+
+    # Check scores and flags
+    expected_metadata = {
+        "trustworthiness": 0.8,
+        "response_helpfulness": 0.6,
+        "query_ease": 0.9,
+        "is_not_trustworthy": False,
+        "is_not_response_helpful": True,
+        "is_not_query_easy": False,
+        "explanation": "Test explanation",
+        "thresholds": {"trustworthiness": 0.7, "response_helpfulness": 0.7, "query_ease": 0.0},
+        "label": "unhelpful",
+    }
+
+    assert metadata == expected_metadata
+
+
+def test_process_score_metadata_edge_cases() -> None:
+    """Test edge cases for process_score_metadata."""
+    thresholds = BadResponseThresholds()
+
+    # Test empty scores
+    metadata_for_empty_scores = process_score_metadata(cast(ThresholdedTrustworthyRAGScore, {}), thresholds)
+    assert {"thresholds", "label"} == set(metadata_for_empty_scores.keys())
+
+    # Test missing explanation
+    scores = cast(ThresholdedTrustworthyRAGScore, {"trustworthiness": {"score": 0.6, "is_bad": True}})
+    metadata_missing_explanation = process_score_metadata(scores, thresholds)
+    assert "explanation" not in metadata_missing_explanation
+
+    # Test custom metric
+    scores = cast(ThresholdedTrustworthyRAGScore, {"my_metric": {"score": 0.3, "is_bad": True}})
+    metadata_custom_metric = process_score_metadata(scores, thresholds)
+    assert metadata_custom_metric["my_metric"] == 0.3
+    assert metadata_custom_metric["is_not_my_metric"] is True
+
+
+def test_update_scores_based_on_thresholds() -> None:
+    """Test that update_scores_based_on_thresholds correctly flags scores based on thresholds."""
+    raw_scores = cast(
+        TrustworthyRAGScore,
+        {
+            "trustworthiness": {"score": 0.6},  # Below threshold
+            "response_helpfulness": {"score": 0.8},  # Above threshold
+            "custom_metric": {"score": 0.4},  # Below custom threshold
+            "another_metric": {"score": 0.6},  # Uses default threshold
+        },
+    )
+
+    thresholds = BadResponseThresholds(trustworthiness=0.7, response_helpfulness=0.7, custom_metric=0.45)  # type: ignore[call-arg]
+
+    scores = update_scores_based_on_thresholds(raw_scores, thresholds)
+
+    expected_is_bad = {
+        "trustworthiness": True,
+        "response_helpfulness": False,
+        "custom_metric": True,
+        "another_metric": False,
+    }
+
+    for metric, expected in expected_is_bad.items():
+        assert scores[metric]["is_bad"] is expected
+    assert all(scores[k]["score"] == raw_scores[k]["score"] for k in raw_scores)

tests/test_validator.py

Lines changed: 2 additions & 2 deletions
@@ -137,10 +137,10 @@ def test_remediate(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None
         mock_project.from_access_key.return_value.query.return_value = ("expert answer", None)
 
         validator = Validator(codex_access_key="test")
-        result = validator._remediate("test query")  # noqa: SLF001
+        result = validator._remediate(query="test query")  # noqa: SLF001
 
         # Verify project.query was called
-        mock_project.from_access_key.return_value.query.assert_called_once_with(question="test query")
+        mock_project.from_access_key.return_value.query.assert_called_once_with(question="test query", metadata=None)
         assert result == "expert answer"
 
     def test_user_provided_thresholds(self, mock_project: Mock, mock_trustworthy_rag: Mock) -> None:  # noqa: ARG002
