|
1 | 1 | import os |
2 | 2 | import tempfile |
3 | | -from collections.abc import Callable, Generator |
| 3 | +from collections.abc import Generator |
4 | 4 | from contextlib import contextmanager |
5 | | -from typing import Any, cast |
| 5 | +from typing import Any |
6 | 6 |
|
7 | 7 | import requests |
8 | 8 | from kubernetes.dynamic import DynamicClient |
9 | 9 | from kubernetes.dynamic.exceptions import ResourceNotFoundError |
10 | 10 | from llama_stack_client import APIConnectionError, InternalServerError, LlamaStackClient |
11 | | -from llama_stack_client.types.vector_store import VectorStore |
12 | 11 | from ocp_resources.pod import Pod |
13 | 12 | from simple_logger.logger import get_logger |
14 | 13 | from timeout_sampler import retry |
15 | 14 |
|
16 | 15 | from tests.llama_stack.constants import ( |
17 | 16 | LLS_CORE_POD_FILTER, |
18 | | - TORCHTUNE_TEST_EXPECTATIONS, |
19 | | - ModelInfo, |
20 | | - TurnExpectation, |
21 | | - ValidationResult, |
22 | 17 | ) |
23 | 18 | from utilities.exceptions import UnexpectedResourceCountError |
24 | 19 | from utilities.resources.llama_stack_distribution import LlamaStackDistribution |
@@ -115,142 +110,6 @@ def wait_for_llama_stack_client_ready(client: LlamaStackClient) -> bool: |
115 | 110 | return False |
116 | 111 |
|
117 | 112 |
|
118 | | -def get_torchtune_test_expectations() -> list[TurnExpectation]: |
119 | | - """ |
120 | | - Helper function to get the test expectations for TorchTune documentation questions. |
121 | | -
|
122 | | - Returns: |
123 | | - List of TurnExpectation objects for testing RAG responses |
124 | | - """ |
125 | | - return [ |
126 | | - { |
127 | | - "question": expectation.question, |
128 | | - "expected_keywords": expectation.expected_keywords, |
129 | | - "description": expectation.description, |
130 | | - } |
131 | | - for expectation in TORCHTUNE_TEST_EXPECTATIONS |
132 | | - ] |
133 | | - |
134 | | - |
135 | | -def create_response_function( |
136 | | - llama_stack_client: LlamaStackClient, llama_stack_models: ModelInfo, vector_store: VectorStore |
137 | | -) -> Callable: |
138 | | - """ |
139 | | - Helper function to create a response function for testing with vector store integration. |
140 | | -
|
141 | | - Args: |
142 | | - llama_stack_client: The LlamaStack client instance |
143 | | - llama_stack_models: The model configuration |
144 | | - vector_store: The vector store instance |
145 | | -
|
146 | | - Returns: |
147 | | - A callable function that takes a question and returns a response |
148 | | - """ |
149 | | - |
150 | | - def _response_fn(*, question: str) -> str: |
151 | | - response = llama_stack_client.responses.create( |
152 | | - input=question, |
153 | | - model=llama_stack_models.model_id, |
154 | | - stream=False, |
155 | | - tools=[ |
156 | | - { |
157 | | - "type": "file_search", |
158 | | - "vector_store_ids": [vector_store.id], |
159 | | - } |
160 | | - ], |
161 | | - ) |
162 | | - return response.output_text |
163 | | - |
164 | | - return _response_fn |
165 | | - |
166 | | - |
167 | | -def validate_api_responses( |
168 | | - response_fn: Callable[..., str], |
169 | | - test_cases: list[TurnExpectation], |
170 | | - min_keywords_required: int = 1, |
171 | | -) -> ValidationResult: |
172 | | - """ |
173 | | - Validate API responses against expected keywords. |
174 | | -
|
175 | | - Tests multiple questions and validates that responses contain expected keywords. |
176 | | - Returns validation results with success status and detailed results for each turn. |
177 | | - """ |
178 | | - all_results = [] |
179 | | - successful = 0 |
180 | | - |
181 | | - for idx, test in enumerate(test_cases, 1): |
182 | | - question = test["question"] |
183 | | - expected_keywords = test["expected_keywords"] |
184 | | - description = test.get("description", "") |
185 | | - |
186 | | - LOGGER.debug(f"\n[{idx}] Question: {question}") |
187 | | - if description: |
188 | | - LOGGER.debug(f" Expectation: {description}") |
189 | | - |
190 | | - try: |
191 | | - response = response_fn(question=question) |
192 | | - response_lower = response.lower() |
193 | | - |
194 | | - found = [kw for kw in expected_keywords if kw.lower() in response_lower] |
195 | | - missing = [kw for kw in expected_keywords if kw.lower() not in response_lower] |
196 | | - success = len(found) >= min_keywords_required |
197 | | - |
198 | | - if success: |
199 | | - successful += 1 |
200 | | - |
201 | | - result = { |
202 | | - "question": question, |
203 | | - "description": description, |
204 | | - "expected_keywords": expected_keywords, |
205 | | - "found_keywords": found, |
206 | | - "missing_keywords": missing, |
207 | | - "response_content": response, |
208 | | - "response_length": len(response) if isinstance(response, str) else 0, |
209 | | - "event_count": len(response.events) if hasattr(response, "events") else 0, |
210 | | - "success": success, |
211 | | - "error": None, |
212 | | - } |
213 | | - |
214 | | - all_results.append(result) |
215 | | - |
216 | | - LOGGER.debug(f"✓ Found: {found}") |
217 | | - if missing: |
218 | | - LOGGER.debug(f"✗ Missing: {missing}") |
219 | | - LOGGER.info(f"[{idx}] Result: {'PASS' if success else 'FAIL'}") |
220 | | - |
221 | | - except Exception as e: |
222 | | - all_results.append({ |
223 | | - "question": question, |
224 | | - "description": description, |
225 | | - "expected_keywords": expected_keywords, |
226 | | - "found_keywords": [], |
227 | | - "missing_keywords": expected_keywords, |
228 | | - "response_content": "", |
229 | | - "response_length": 0, |
230 | | - "event_count": 0, |
231 | | - "success": False, |
232 | | - "error": str(e), |
233 | | - }) |
234 | | - LOGGER.exception(f"[{idx}] ERROR") |
235 | | - |
236 | | - total = len(test_cases) |
237 | | - summary = { |
238 | | - "total": total, |
239 | | - "passed": successful, |
240 | | - "failed": total - successful, |
241 | | - "success_rate": successful / total if total > 0 else 0, |
242 | | - } |
243 | | - |
244 | | - LOGGER.info("\n" + "=" * 40) |
245 | | - LOGGER.info("Validation Summary:") |
246 | | - LOGGER.info(f"Total: {summary['total']}") |
247 | | - LOGGER.info(f"Passed: {summary['passed']}") |
248 | | - LOGGER.info(f"Failed: {summary['failed']}") |
249 | | - LOGGER.info(f"Success rate: {summary['success_rate']:.1%}") |
250 | | - |
251 | | - return cast("ValidationResult", {"success": successful == total, "results": all_results, "summary": summary}) |
252 | | - |
253 | | - |
254 | 113 | @retry( |
255 | 114 | wait_timeout=240, |
256 | 115 | sleep=15, |
|
0 commit comments