|
1 | 1 | import os |
| 2 | +import tempfile |
2 | 3 | from typing import Generator, Any, Dict |
3 | 4 |
|
4 | 5 | import portforward |
5 | 6 | import pytest |
| 7 | +import requests |
6 | 8 | from _pytest.fixtures import FixtureRequest |
7 | 9 | from kubernetes.dynamic import DynamicClient |
8 | 10 | from llama_stack_client import LlamaStackClient |
| 11 | +from llama_stack_client.types.vector_store import VectorStore |
9 | 12 | from ocp_resources.data_science_cluster import DataScienceCluster |
10 | 13 | from ocp_resources.deployment import Deployment |
11 | 14 | from ocp_resources.llama_stack_distribution import LlamaStackDistribution |
12 | 15 | from ocp_resources.namespace import Namespace |
13 | 16 | from simple_logger.logger import get_logger |
| 17 | +from timeout_sampler import retry |
14 | 18 |
|
15 | 19 | from tests.llama_stack.utils import create_llama_stack_distribution, wait_for_llama_stack_client_ready |
16 | 20 | from utilities.constants import DscComponents, Timeout |
17 | 21 | from utilities.data_science_cluster_utils import update_components_in_dsc |
| 22 | +from utilities.rag_utils import ModelInfo |
18 | 23 |
|
19 | 24 |
|
20 | 25 | LOGGER = get_logger(name=__name__) |
@@ -43,19 +48,18 @@ def llama_stack_server_config( |
43 | 48 | vllm_api_token = os.getenv("LLS_CORE_VLLM_API_TOKEN", "") |
44 | 49 | vllm_url = os.getenv("LLS_CORE_VLLM_URL", "") |
45 | 50 |
|
46 | | - if hasattr(request, "param"): |
47 | | - if request.param.get("fms_orchestrator_url_fixture"): |
48 | | - fms_orchestrator_url = request.getfixturevalue(argname=request.param.get("fms_orchestrator_url_fixture")) |
| 51 | + # Override env vars with request parameters if provided |
| 52 | + params = getattr(request, "param", {}) or {} |
| 53 | + if params.get("fms_orchestrator_url_fixture"): |
| 54 | + fms_orchestrator_url = request.getfixturevalue(argname=params.get("fms_orchestrator_url_fixture")) |
| 55 | + if params.get("inference_model"): |
| 56 | + inference_model = params.get("inference_model") # type: ignore |
| 57 | + if params.get("vllm_api_token"): |
| 58 | + vllm_api_token = params.get("vllm_api_token") # type: ignore |
| 59 | + if params.get("vllm_url_fixture"): |
| 60 | + vllm_url = request.getfixturevalue(argname=params.get("vllm_url_fixture")) |
49 | 61 |
|
50 | | - # Override env vars with request parameters if provided |
51 | | - if request.param.get("inference_model"): |
52 | | - inference_model = request.param.get("inference_model") |
53 | | - if request.param.get("vllm_api_token"): |
54 | | - vllm_api_token = request.param.get("vllm_api_token") |
55 | | - if request.param.get("vllm_url_fixture"): |
56 | | - vllm_url = request.getfixturevalue(argname=request.param.get("vllm_url_fixture")) |
57 | | - |
58 | | - return { |
| 62 | + server_config: Dict[str, Any] = { |
59 | 63 | "containerSpec": { |
60 | 64 | "resources": { |
61 | 65 | "requests": {"cpu": "250m", "memory": "500Mi"}, |
@@ -85,11 +89,14 @@ def llama_stack_server_config( |
85 | 89 | "port": 8321, |
86 | 90 | }, |
87 | 91 | "distribution": {"name": "rh-dev"}, |
88 | | - "storage": { |
89 | | - "size": "20Gi", |
90 | | - }, |
91 | 92 | } |
92 | 93 |
|
| 94 | + if params.get("llama_stack_storage_size"): |
| 95 | + storage_size = params.get("llama_stack_storage_size") |
| 96 | + server_config["storage"] = {"size": storage_size} |
| 97 | + |
| 98 | + return server_config |
| 99 | + |
93 | 100 |
|
94 | 101 | @pytest.fixture(scope="class") |
95 | 102 | def llama_stack_distribution( |
@@ -157,3 +164,139 @@ def llama_stack_client( |
157 | 164 | except Exception as e: |
158 | 165 | LOGGER.error(f"Failed to set up port forwarding: {e}") |
159 | 166 | raise |
| 167 | + |
| 168 | + |
| 169 | +@pytest.fixture(scope="class") |
| 170 | +def llama_stack_models(llama_stack_client: LlamaStackClient) -> ModelInfo: |
| 171 | + """ |
| 172 | + Returns model information from the LlamaStack client. |
| 173 | +
|
| 174 | + Provides: |
| 175 | + - model_id: The identifier of the LLM model |
| 176 | + - embedding_model: The embedding model object |
| 177 | + - embedding_dimension: The dimension of the embedding model |
| 178 | +
|
| 179 | + Args: |
| 180 | + llama_stack_client: The configured LlamaStackClient |
| 181 | +
|
| 182 | + Returns: |
| 183 | + ModelInfo: NamedTuple containing model information |
| 184 | + """ |
| 185 | + models = llama_stack_client.models.list() |
| 186 | + model_id = next(m for m in models if m.api_model_type == "llm").identifier |
| 187 | + |
| 188 | + embedding_model = next(m for m in models if m.api_model_type == "embedding") |
| 189 | + embedding_dimension = embedding_model.metadata["embedding_dimension"] |
| 190 | + |
| 191 | + return ModelInfo(model_id=model_id, embedding_model=embedding_model, embedding_dimension=embedding_dimension) |
| 192 | + |
| 193 | + |
@pytest.fixture(scope="class")
def vector_store(
    llama_stack_client: LlamaStackClient, llama_stack_models: ModelInfo
) -> Generator[VectorStore, None, None]:
    """
    Create a vector store for testing and clean it up afterwards.

    The store is created with the embedding model discovered by the
    llama_stack_models fixture, yielded to the test, and deleted on
    teardown regardless of test outcome. Deletion failures are logged
    as warnings rather than raised, so teardown never masks a test result.

    Args:
        llama_stack_client: The configured LlamaStackClient
        llama_stack_models: Model information including embedding model details

    Yields:
        Vector store object that can be used in tests
    """
    store = llama_stack_client.vector_stores.create(
        name="test_vector_store",
        embedding_model=llama_stack_models.embedding_model.identifier,  # type: ignore
        embedding_dimension=llama_stack_models.embedding_dimension,
    )

    yield store

    # Teardown: best-effort delete; a leaked store is logged, not fatal.
    try:
        llama_stack_client.vector_stores.delete(vector_store_id=store.id)
        LOGGER.info(f"Deleted vector store {store.id}")
    except Exception as e:
        LOGGER.warning(f"Failed to delete vector store {store.id}: {e}")
| 226 | + |
@retry(
    wait_timeout=Timeout.TIMEOUT_1MIN,
    sleep=5,
    exceptions_dict={requests.exceptions.RequestException: [], Exception: []},
)
def _download_and_upload_file(url: str, llama_stack_client: LlamaStackClient, vector_store: Any) -> bool:
    """
    Downloads a file from URL and uploads it to the vector store.

    Retried (via @retry) for up to one minute on any exception, so
    transient network or service hiccups do not fail the caller.

    Args:
        url: The URL to download the file from
        llama_stack_client: The configured LlamaStackClient
        vector_store: The vector store to upload the file to

    Returns:
        bool: True if successful, raises exception if failed
    """
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Save file locally first and pretend it's a txt file, not sure why this is needed
        # but it works locally without it,
        # though llama stack version is the newer one.
        file_name = url.split("/")[-1]
        local_file_name = file_name.replace(".rst", ".txt")
        with tempfile.NamedTemporaryFile(mode="wb", suffix=f"_{local_file_name}") as temp_file:
            temp_file.write(response.content)
            # Flush buffered bytes to disk before re-opening the same path for
            # reading; otherwise the upload may see a truncated or empty file.
            temp_file.flush()

            # Upload saved file to LlamaStack (while the temp file still exists;
            # NamedTemporaryFile deletes it when the `with` block exits).
            with open(temp_file.name, "rb") as file_to_upload:
                uploaded_file = llama_stack_client.files.create(file=file_to_upload, purpose="assistants")

            # Add file to vector store
            llama_stack_client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded_file.id)

        return True

    except Exception as e:
        # `Exception` already covers requests.exceptions.RequestException,
        # so a single narrow-enough-for-a-retry-helper clause suffices.
        LOGGER.warning(f"Failed to download and upload file {url}: {e}")
        raise
| 269 | + |
| 270 | + |
@pytest.fixture(scope="class")
def vector_store_with_docs(llama_stack_client: LlamaStackClient, vector_store: Any) -> Generator[Any, None, None]:
    """
    Provide the shared vector store pre-loaded with TorchTune documentation.

    Builds on the vector_store fixture: each TorchTune tutorial file is
    downloaded and uploaded via _download_and_upload_file before the store
    is yielded. Cleanup of the store itself is handled by the underlying
    vector_store fixture.

    Args:
        llama_stack_client: The configured LlamaStackClient
        vector_store: The vector store fixture to upload files to

    Yields:
        Vector store object with uploaded TorchTune documentation files
    """
    base_url = "https://raw.githubusercontent.com/pytorch/torchtune/refs/tags/v0.6.1/docs/source/tutorials/"
    doc_files = (
        "llama3.rst",
        "chat.rst",
        "lora_finetune.rst",
        "qat_finetune.rst",
        "memory_optimizations.rst",
    )

    for doc_file in doc_files:
        _download_and_upload_file(
            url=f"{base_url}{doc_file}",
            llama_stack_client=llama_stack_client,
            vector_store=vector_store,
        )

    yield vector_store
0 commit comments