|
17 | 17 | TurnExpectation, |
18 | 18 | ModelInfo, |
19 | 19 | ValidationResult, |
20 | | - TurnResult, |
21 | 20 | LLS_CORE_POD_FILTER, |
22 | 21 | ) |
23 | 22 |
|
24 | | -from llama_stack_client import Agent, AgentEventLogger |
25 | 23 | import tempfile |
26 | 24 | import requests |
27 | 25 |
|
@@ -156,161 +154,6 @@ def _response_fn(*, question: str) -> str: |
156 | 154 | return _response_fn |
157 | 155 |
|
158 | 156 |
|
159 | | -def extract_event_content(event: Any) -> str: |
160 | | - """Extract content from various event types.""" |
161 | | - for attr in ["content", "message", "text"]: |
162 | | - if hasattr(event, attr) and getattr(event, attr): |
163 | | - return str(getattr(event, attr)) |
164 | | - return "" |
165 | | - |
166 | | - |
167 | | -def validate_rag_agent_responses( |
168 | | - rag_agent: Agent, |
169 | | - session_id: str, |
170 | | - turns_with_expectations: List[TurnExpectation], |
171 | | - stream: bool = True, |
172 | | - verbose: bool = True, |
173 | | - min_keywords_required: int = 1, |
174 | | - print_events: bool = False, |
175 | | -) -> ValidationResult: |
176 | | - """ |
177 | | - Validate RAG agent responses against expected keywords. |
178 | | -
|
179 | | - Tests multiple questions and validates that responses contain expected keywords. |
180 | | - Returns validation results with success status and detailed results for each turn. |
181 | | - """ |
182 | | - |
183 | | - all_results = [] |
184 | | - total_turns = len(turns_with_expectations) |
185 | | - successful_turns = 0 |
186 | | - |
187 | | - for turn_idx, turn_data in enumerate(turns_with_expectations, 1): |
188 | | - question = turn_data["question"] |
189 | | - expected_keywords = turn_data["expected_keywords"] |
190 | | - description = turn_data.get("description", "") |
191 | | - |
192 | | - if verbose: |
193 | | - LOGGER.info(f"[{turn_idx}/{total_turns}] Processing: {question}") |
194 | | - if description: |
195 | | - LOGGER.info(f"Expected: {description}") |
196 | | - |
197 | | - # Collect response content for validation |
198 | | - response_content = "" |
199 | | - event_count = 0 |
200 | | - |
201 | | - try: |
202 | | - # Create turn with the agent |
203 | | - turn_response = rag_agent.create_turn( |
204 | | - messages=[{"role": "user", "content": question}], |
205 | | - session_id=session_id, |
206 | | - stream=stream, |
207 | | - ) |
208 | | - |
209 | | - if stream: |
210 | | - for event in AgentEventLogger().log(turn_response): |
211 | | - if print_events: |
212 | | - event.print() |
213 | | - event_count += 1 |
214 | | - |
215 | | - # Extract content from different event types |
216 | | - response_content += extract_event_content(event) |
217 | | - else: |
218 | | - response_content = turn_response.output_text |
219 | | - |
220 | | - # Validate response content |
221 | | - response_lower = response_content.lower() |
222 | | - found_keywords = [] |
223 | | - missing_keywords = [] |
224 | | - |
225 | | - for keyword in expected_keywords: |
226 | | - if keyword.lower() in response_lower: |
227 | | - found_keywords.append(keyword) |
228 | | - else: |
229 | | - missing_keywords.append(keyword) |
230 | | - |
231 | | - # Determine if this turn was successful |
232 | | - if stream: |
233 | | - turn_successful = ( |
234 | | - event_count > 0 and len(response_content) > 0 and len(found_keywords) >= min_keywords_required |
235 | | - ) |
236 | | - else: |
237 | | - turn_successful = len(response_content) > 0 and len(found_keywords) >= min_keywords_required |
238 | | - |
239 | | - if turn_successful: |
240 | | - successful_turns += 1 |
241 | | - |
242 | | - # Store results for this turn |
243 | | - turn_result = { |
244 | | - "question": question, |
245 | | - "description": description, |
246 | | - "expected_keywords": expected_keywords, |
247 | | - "found_keywords": found_keywords, |
248 | | - "missing_keywords": missing_keywords, |
249 | | - "response_content": response_content, |
250 | | - "response_length": len(response_content), |
251 | | - "event_count": event_count, |
252 | | - "success": turn_successful, |
253 | | - "error": None, |
254 | | - } |
255 | | - |
256 | | - all_results.append(turn_result) |
257 | | - |
258 | | - if verbose: |
259 | | - LOGGER.info(f"Response length: {len(response_content)}") |
260 | | - LOGGER.info(f"Events processed: {event_count}") |
261 | | - LOGGER.info(f"Found keywords: {found_keywords}") |
262 | | - |
263 | | - if missing_keywords: |
264 | | - LOGGER.warning(f"Missing expected keywords: {missing_keywords}") |
265 | | - |
266 | | - if turn_successful: |
267 | | - LOGGER.info(f"✓ Successfully validated response for: {question}") |
268 | | - else: |
269 | | - LOGGER.error(f"✗ Validation failed for: {question}") |
270 | | - |
271 | | - if turn_idx < total_turns: # Don't print separator after last turn |
272 | | - LOGGER.info("-" * 50) |
273 | | - |
274 | | - except Exception as exc: |
275 | | - LOGGER.exception("Error processing turn %s", question) |
276 | | - turn_result = { |
277 | | - "question": question, |
278 | | - "description": description, |
279 | | - "expected_keywords": expected_keywords, |
280 | | - "found_keywords": [], |
281 | | - "missing_keywords": expected_keywords, |
282 | | - "response_content": "", |
283 | | - "response_length": 0, |
284 | | - "event_count": 0, |
285 | | - "success": False, |
286 | | - "error": str(exc), |
287 | | - } |
288 | | - all_results.append(turn_result) |
289 | | - |
290 | | - # Generate summary |
291 | | - summary = { |
292 | | - "total_turns": total_turns, |
293 | | - "successful_turns": successful_turns, |
294 | | - "failed_turns": total_turns - successful_turns, |
295 | | - "success_rate": successful_turns / total_turns if total_turns > 0 else 0, |
296 | | - "total_events": sum(cast(TurnResult, result)["event_count"] for result in all_results), |
297 | | - "total_response_length": sum(cast(TurnResult, result)["response_length"] for result in all_results), |
298 | | - } |
299 | | - |
300 | | - overall_success = successful_turns == total_turns |
301 | | - |
302 | | - if verbose: |
303 | | - LOGGER.info("=" * 60) |
304 | | - LOGGER.info("VALIDATION SUMMARY:") |
305 | | - LOGGER.info(f"Total turns: {summary['total_turns']}") |
306 | | - LOGGER.info(f"Successful: {summary['successful_turns']}") |
307 | | - LOGGER.info(f"Failed: {summary['failed_turns']}") |
308 | | - LOGGER.info(f"Success rate: {summary['success_rate']:.1%}") |
309 | | - LOGGER.info(f"Overall result: {'✓ PASSED' if overall_success else '✗ FAILED'}") |
310 | | - |
311 | | - return cast(ValidationResult, {"success": overall_success, "results": all_results, "summary": summary}) |
312 | | - |
313 | | - |
314 | 157 | def validate_api_responses( |
315 | 158 | response_fn: Callable[..., str], |
316 | 159 | test_cases: List[TurnExpectation], |
|
0 commit comments