@@ -149,6 +149,39 @@ def __init__(
             error_msg = f"Found thresholds for non-existent evaluation metrics: {_extra_thresholds}"
             raise ValueError(error_msg)
 
+    async def validate_async(
+        self,
+        query: str,
+        context: str,
+        response: str,
+        prompt: Optional[str] = None,
+        form_prompt: Optional[Callable[[str, str], str]] = None,
+    ) -> dict[str, Any]:
+        """Async version of validate()."""
+        expert_task = asyncio.create_task(self.remediate_async(query))
+        detect_task = asyncio.get_running_loop().run_in_executor(
+            None,
+            self.detect,
+            query, context, response, prompt, form_prompt
+        )
+
+        # Run remediation and detection concurrently
+        expert_result, detect_result = await asyncio.gather(expert_task, detect_task)
+        expert_answer, maybe_entry = expert_result
+        scores, is_bad_response = detect_result
+
+        # Log the query for SME review when the response is bad and no expert answer exists
+        if is_bad_response and not expert_answer:
+            self._project._sdk_client.projects.entries.add_question(
+                self._project._id, question=query,
+            ).model_dump()
+
+        return {
+            "expert_answer": expert_answer if is_bad_response else None,
+            "is_bad_response": is_bad_response,
+            **scores,
+        }
+
     def validate(
         self,
         query: str,
@@ -174,30 +207,18 @@ def validate(
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/cleanlab_codex/types/validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
         try:
+            # Check whether this thread already has a running event loop
             loop = asyncio.get_running_loop()
-        except RuntimeError:  # No running loop
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-        expert_task = loop.create_task(self.remediate_async(query))
-        detect_task = loop.run_in_executor(None, self.detect, query, context, response, prompt, form_prompt)
-        expert_answer, maybe_entry = loop.run_until_complete(expert_task)
-        scores, is_bad_response = loop.run_until_complete(detect_task)
-        if not loop.is_running():
-            loop.close()
-        if is_bad_response:
-            if expert_answer == None:
-                # TODO: Make this async as well in the future (only if add_question takes nontrivial amt of time on the client)
-                self._project._sdk_client.projects.entries.add_question(
-                    self._project._id, question=query,
-                ).model_dump()
-            else:
-                expert_answer = None
-
-        return {
-            "expert_answer": expert_answer,
-            "is_bad_response": is_bad_response,
-            **scores,
-        }
+            # Can't block inside the running loop's thread without deadlocking,
+            from concurrent.futures import ThreadPoolExecutor
+            with ThreadPoolExecutor(max_workers=1) as executor:
+                coro = self.validate_async(query, context, response, prompt, form_prompt)
+                return executor.submit(asyncio.run, coro).result()
+        except RuntimeError:
+            # No running loop - safe to start one with asyncio.run()
+            return asyncio.run(
+                self.validate_async(query, context, response, prompt, form_prompt)
+            )
 
     def detect(
         self,
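
For reviewers, a minimal usage sketch of both entry points after this change. The `Validator` import path, constructor argument, and the query/context/response strings are illustrative assumptions, not part of this diff:

```python
import asyncio

from cleanlab_codex import Validator  # import path assumed

validator = Validator("<project-access-key>")  # constructor args illustrative

# In async code, await validate_async() directly; remediation and
# detection run concurrently via asyncio.gather.
async def main() -> None:
    result = await validator.validate_async(
        query="What is the return policy?",
        context="Returns are accepted within 30 days of purchase.",
        response="You can return items within 30 days.",
    )
    if result["is_bad_response"]:
        print(result["expert_answer"])

asyncio.run(main())

# In sync code, validate() dispatches to validate_async(), starting an
# event loop (or a helper thread if a loop is already running).
result = validator.validate(
    query="What is the return policy?",
    context="Returns are accepted within 30 days of purchase.",
    response="You can return items within 30 days.",
)
```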