|
1 | 1 | import os |
| 2 | +import tempfile |
2 | 3 | from typing import Generator, Any, Dict |
3 | 4 |
|
4 | 5 | import portforward |
|
10 | 11 | from ocp_resources.deployment import Deployment |
11 | 12 | from ocp_resources.llama_stack_distribution import LlamaStackDistribution |
12 | 13 | from ocp_resources.namespace import Namespace |
| 14 | +from ocp_resources.config_map import ConfigMap |
13 | 15 | from simple_logger.logger import get_logger |
| 16 | +from timeout_sampler import retry |
14 | 17 |
|
15 | 18 | from tests.llama_stack.utils import create_llama_stack_distribution, wait_for_llama_stack_client_ready |
16 | 19 | from utilities.constants import DscComponents, Timeout |
17 | 20 | from utilities.data_science_cluster_utils import update_components_in_dsc |
| 21 | +from utilities.rag_utils import ModelInfo |
18 | 22 |
|
19 | 23 |
|
20 | 24 | LOGGER = get_logger(name=__name__) |
@@ -81,22 +85,187 @@ def llama_stack_server_config( |
81 | 85 | }, |
82 | 86 | {"name": "FMS_ORCHESTRATOR_URL", "value": fms_orchestrator_url}, |
83 | 87 | ], |
| 88 | + "command": ["/bin/sh", "-c", "llama stack run /etc/llama-stack/run.yaml"], |
84 | 89 | "name": "llama-stack", |
85 | 90 | "port": 8321, |
86 | 91 | }, |
87 | 92 | "distribution": {"name": "rh-dev"}, |
| 93 | + "userConfig": {"configMapName": "rag-llama-stack-config-map"}, |
88 | 94 | "storage": { |
89 | 95 | "size": "20Gi", |
90 | 96 | }, |
91 | 97 | } |
92 | 98 |
|
93 | 99 |
|
@pytest.fixture(scope="class")
def llama_stack_config_map(
    admin_client: DynamicClient,
    model_namespace: Namespace,
) -> Generator[ConfigMap, Any, Any]:
    """
    Create the ConfigMap holding the llama-stack run.yaml user configuration.

    The ConfigMap name ("rag-llama-stack-config-map") must match the
    `userConfig.configMapName` referenced by the LlamaStackDistribution spec,
    and the server port in the YAML (8321) matches the container port there.

    The embedded run.yaml declares the server's providers (vllm remote
    inference, sentence-transformers embeddings, milvus vector_io, trustyai
    safety/eval, local sqlite stores, web-search tool runtimes), the model
    registry, tool groups, and the server port.

    Args:
        admin_client: Dynamic client used to create the resource.
        model_namespace: Namespace to create the ConfigMap in.

    Yields:
        ConfigMap: The created resource; deleted when the context exits.
    """
    # NOTE: ${env.NAME:=default} placeholders are resolved by the llama-stack
    # server at startup from the container environment, not by this fixture.
    with ConfigMap(
        client=admin_client,
        namespace=model_namespace.name,
        name="rag-llama-stack-config-map",
        data={
            "run.yaml": """version: 2
image_name: rh
apis:
- agents
- datasetio
- eval
- inference
- safety
- files
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: vllm-inference
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_URL:=http://localhost:8000/v1}
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: milvus
    provider_type: inline::milvus
    config:
      db_path: /opt/app-root/src/.llama/distributions/rh/milvus.db
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/milvus_registry.db
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: /opt/app-root/src/.llama/distributions/rh/files
      metadata_store:
        type: sqlite
        db_path: /opt/app-root/src/.llama/distributions/rh/files_metadata.db
  safety:
  - provider_id: trustyai_fms
    provider_type: remote::trustyai_fms
    config:
      orchestrator_url: ${env.FMS_ORCHESTRATOR_URL:=}
      ssl_cert_path: ${env.FMS_SSL_CERT_PATH:=}
      shields: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/agents_store.db
      responses_store:
        type: sqlite
        db_path: /opt/app-root/src/.llama/distributions/rh/responses_store.db
  eval:
  - provider_id: trustyai_lmeval
    provider_type: remote::trustyai_lmeval
    config:
      use_k8s: True
      base_url: ${env.VLLM_URL:=http://localhost:8000/v1}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:=}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:=\u200b}"
      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
      sqlite_db_path: /opt/app-root/src/.llama/distributions/rh/trace_store.db
      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: /opt/app-root/src/.llama/distributions/rh/registry.db
inference_store:
  type: sqlite
  db_path: /opt/app-root/src/.llama/distributions/rh/inference_store.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: vllm-inference
  model_type: llm
- metadata:
    embedding_dimension: 768
  model_id: granite-embedding-125m
  provider_id: sentence-transformers
  provider_model_id: ibm-granite/granite-embedding-125m-english
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
server:
  port: 8321
external_providers_dir: /opt/app-root/src/.llama/providers.d"""
        },
    ) as config_map:
        yield config_map
| 260 | + |
| 261 | + |
94 | 262 | @pytest.fixture(scope="class") |
95 | 263 | def llama_stack_distribution( |
96 | 264 | admin_client: DynamicClient, |
97 | 265 | model_namespace: Namespace, |
98 | 266 | enabled_llama_stack_operator: DataScienceCluster, |
99 | 267 | llama_stack_server_config: Dict[str, Any], |
| 268 | + llama_stack_config_map: ConfigMap, |
100 | 269 | ) -> Generator[LlamaStackDistribution, None, None]: |
101 | 270 | with create_llama_stack_distribution( |
102 | 271 | client=admin_client, |
@@ -157,3 +326,135 @@ def llama_stack_client( |
157 | 326 | except Exception as e: |
158 | 327 | LOGGER.error(f"Failed to set up port forwarding: {e}") |
159 | 328 | raise |
| 329 | + |
| 330 | + |
@pytest.fixture(scope="class")
def llama_stack_models(llama_stack_client: LlamaStackClient) -> ModelInfo:
    """
    Returns model information from the LlamaStack client.

    Provides:
        - model_id: The identifier of the LLM model
        - embedding_model: The embedding model object
        - embedding_dimension: The dimension of the embedding model

    Args:
        llama_stack_client: The configured LlamaStackClient

    Returns:
        ModelInfo: NamedTuple containing model information

    Raises:
        ValueError: If the server has no registered LLM or embedding model.
    """
    models = llama_stack_client.models.list()

    # next(..., None) + explicit error instead of a bare next() so a missing
    # model surfaces as a clear ValueError rather than an opaque StopIteration.
    llm_model = next((m for m in models if m.api_model_type == "llm"), None)
    if llm_model is None:
        raise ValueError("No model with api_model_type 'llm' registered on the LlamaStack server")

    embedding_model = next((m for m in models if m.api_model_type == "embedding"), None)
    if embedding_model is None:
        raise ValueError("No model with api_model_type 'embedding' registered on the LlamaStack server")

    return ModelInfo(
        model_id=llm_model.identifier,
        embedding_model=embedding_model,
        embedding_dimension=embedding_model.metadata["embedding_dimension"],
    )
| 354 | + |
| 355 | + |
@pytest.fixture(scope="class")
def vector_store(llama_stack_client: LlamaStackClient, llama_stack_models: ModelInfo) -> Generator[Any, None, None]:
    """
    Provide a test vector store, deleting it when the fixture is torn down.

    The store is created with the embedding model discovered by the
    llama_stack_models fixture; deletion failures are logged as warnings
    rather than raised, so teardown never masks a test failure.

    Args:
        llama_stack_client: The configured LlamaStackClient
        llama_stack_models: Model information including embedding model details

    Yields:
        Vector store object that can be used in tests
    """
    store = llama_stack_client.vector_stores.create(
        name="test_vector_store",
        embedding_model=llama_stack_models.embedding_model.identifier,
        embedding_dimension=llama_stack_models.embedding_dimension,
    )

    yield store

    # Teardown: best-effort deletion of the store created above.
    try:
        llama_stack_client.vector_stores.delete(id=store.id)
    except Exception as exc:
        LOGGER.warning(f"Failed to delete vector store {store.id}: {exc}")
    else:
        LOGGER.info(f"Deleted vector store {store.id}")
| 385 | + |
| 386 | + |
@retry(wait_timeout=Timeout.TIMEOUT_1MIN, sleep=20)
def _download_and_upload_file(url: str, llama_stack_client: LlamaStackClient, vector_store: Any) -> bool:
    """
    Downloads a file from URL and uploads it to the vector store.

    Any failure is logged and re-raised so the @retry decorator can attempt
    again (up to Timeout.TIMEOUT_1MIN, sleeping 20s between attempts).

    Args:
        url: The URL to download the file from
        llama_stack_client: The configured LlamaStackClient
        vector_store: The vector store to upload the file to

    Returns:
        bool: True if successful, raises exception if failed
    """
    import requests

    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Save the payload locally with a .txt suffix (instead of .rst) so the
        # LlamaStack files API treats it as a plain-text document.
        file_name = url.split("/")[-1]
        local_file_name = file_name.replace(".rst", ".txt")

        # delete=False: the default delete-on-close would remove the file
        # before it could be re-opened for upload below; we clean up manually.
        temp_file = tempfile.NamedTemporaryFile(mode="wb", suffix=f"_{local_file_name}", delete=False)
        try:
            temp_file.write(response.content)
            temp_file.close()

            # Upload saved file to LlamaStack
            with open(temp_file.name, "rb") as file_to_upload:
                uploaded_file = llama_stack_client.files.create(file=file_to_upload, purpose="assistants")

            # Add file to vector store
            llama_stack_client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded_file.id)
        finally:
            os.unlink(temp_file.name)

        return True

    # A single `except Exception` replaces the original redundant
    # `(requests.exceptions.RequestException, Exception)` tuple.
    except Exception as e:
        LOGGER.warning(f"Failed to download and upload file {url}: {e}")
        raise
| 427 | + |
| 428 | + |
@pytest.fixture(scope="class")
def vector_store_with_docs(llama_stack_client: LlamaStackClient, vector_store: Any) -> Generator[Any, None, None]:
    """
    Provide the shared vector store pre-populated with TorchTune tutorial docs.

    Builds on the vector_store fixture by downloading a fixed set of TorchTune
    documentation files (pinned to tag v0.6.1) and uploading each one into the
    store. Deletion of the store itself is handled by the vector_store
    fixture's teardown.

    Args:
        llama_stack_client: The configured LlamaStackClient
        vector_store: The vector store fixture to upload files to

    Yields:
        Vector store object with uploaded TorchTune documentation files
    """
    base_url = "https://raw.githubusercontent.com/pytorch/torchtune/refs/tags/v0.6.1/docs/source/tutorials/"

    doc_names = (
        "llama3.rst",
        "chat.rst",
        "lora_finetune.rst",
        "qat_finetune.rst",
        "memory_optimizations.rst",
    )

    # Download each tutorial and attach it to the store; the helper retries
    # transient failures internally.
    for doc_name in doc_names:
        _download_and_upload_file(
            url=f"{base_url}{doc_name}",
            llama_stack_client=llama_stack_client,
            vector_store=vector_store,
        )

    yield vector_store
0 commit comments