Skip to content

Commit 2003b72

Browse files
committed
merge
1 parent 3139461 commit 2003b72

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

src/cleanlab_codex/validator.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,39 @@ def __init__(
149149
error_msg = f"Found thresholds for non-existent evaluation metrics: {_extra_thresholds}"
150150
raise ValueError(error_msg)
151151

152+
async def validate_async(
153+
self,
154+
query: str,
155+
context: str,
156+
response: str,
157+
prompt: Optional[str] = None,
158+
form_prompt: Optional[Callable[[str, str], str]] = None,
159+
) -> dict[str, Any]:
160+
"""Async version of validate"""
161+
expert_task = asyncio.create_task(self.remediate_async(query))
162+
detect_task = asyncio.get_running_loop().run_in_executor(
163+
None,
164+
self.detect,
165+
query, context, response, prompt, form_prompt
166+
)
167+
168+
# Use gather to run tasks concurrently
169+
expert_result, detect_result = await asyncio.gather(expert_task, detect_task)
170+
expert_answer, maybe_entry = expert_result
171+
scores, is_bad_response = detect_result
172+
173+
# Rest of your existing logic
174+
if is_bad_response and not expert_answer:
175+
self._project._sdk_client.projects.entries.add_question(
176+
self._project._id, question=query,
177+
).model_dump()
178+
179+
return {
180+
"expert_answer": expert_answer if is_bad_response else None,
181+
"is_bad_response": is_bad_response,
182+
**scores,
183+
}
184+
152185
def validate(
153186
self,
154187
query: str,
@@ -174,30 +207,18 @@ def validate(
174207
- Additional keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
175208
"""
176209
try:
210+
# Try to use existing event loop
177211
loop = asyncio.get_running_loop()
178-
except RuntimeError: # No running loop
179-
loop = asyncio.new_event_loop()
180-
asyncio.set_event_loop(loop)
181-
expert_task = loop.create_task(self.remediate_async(query))
182-
detect_task = loop.run_in_executor(None, self.detect, query, context, response, prompt, form_prompt)
183-
expert_answer, maybe_entry = loop.run_until_complete(expert_task)
184-
scores, is_bad_response = loop.run_until_complete(detect_task)
185-
if not loop.is_running():
186-
loop.close()
187-
if is_bad_response:
188-
if expert_answer == None:
189-
# TODO: Make this async as well in the future (only if add_question takes nontrivial amt of time on the client)
190-
self._project._sdk_client.projects.entries.add_question(
191-
self._project._id, question=query,
192-
).model_dump()
193-
else:
194-
expert_answer = None
195-
196-
return {
197-
"expert_answer": expert_answer,
198-
"is_bad_response": is_bad_response,
199-
**scores,
200-
}
212+
future = asyncio.run_coroutine_threadsafe(
213+
self.validate_async(query, context, response, prompt, form_prompt),
214+
loop
215+
)
216+
return future.result()
217+
except RuntimeError:
218+
# No existing loop - create new one
219+
return asyncio.run(
220+
self.validate_async(query, context, response, prompt, form_prompt)
221+
)
201222

202223
def detect(
203224
self,

0 commit comments

Comments
 (0)