Skip to content

Commit f51d180

Browse files
committed
test: add openai compatible tests for responses, vector stores and files api
1 parent 3983e69 commit f51d180

File tree

4 files changed

+618
-72
lines changed

4 files changed

+618
-72
lines changed

tests/llama_stack/conftest.py

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import tempfile
23
from typing import Generator, Any, Dict
34

45
import portforward
@@ -10,11 +11,14 @@
1011
from ocp_resources.deployment import Deployment
1112
from ocp_resources.llama_stack_distribution import LlamaStackDistribution
1213
from ocp_resources.namespace import Namespace
14+
from ocp_resources.config_map import ConfigMap
1315
from simple_logger.logger import get_logger
16+
from timeout_sampler import retry
1417

1518
from tests.llama_stack.utils import create_llama_stack_distribution, wait_for_llama_stack_client_ready
1619
from utilities.constants import DscComponents, Timeout
1720
from utilities.data_science_cluster_utils import update_components_in_dsc
21+
from utilities.rag_utils import ModelInfo
1822

1923

2024
LOGGER = get_logger(name=__name__)
@@ -81,22 +85,187 @@ def llama_stack_server_config(
8185
},
8286
{"name": "FMS_ORCHESTRATOR_URL", "value": fms_orchestrator_url},
8387
],
88+
"command": ["/bin/sh", "-c", "llama stack run /etc/llama-stack/run.yaml"],
8489
"name": "llama-stack",
8590
"port": 8321,
8691
},
8792
"distribution": {"name": "rh-dev"},
93+
"userConfig": {"configMapName": "rag-llama-stack-config-map"},
8894
"storage": {
8995
"size": "20Gi",
9096
},
9197
}
9298

9399

100+
@pytest.fixture(scope="class")
def llama_stack_config_map(
    admin_client: DynamicClient,
    model_namespace: Namespace,
) -> Generator[ConfigMap, Any, Any]:
    """
    Create the ConfigMap that carries the llama-stack ``run.yaml`` for the test distribution.

    The name "rag-llama-stack-config-map" must stay in sync with the
    ``userConfig.configMapName`` value set in the ``llama_stack_server_config``
    fixture; the distribution container starts with
    ``llama stack run /etc/llama-stack/run.yaml`` and reads this content.

    Args:
        admin_client: Cluster-admin dynamic client used to create the resource.
        model_namespace: Namespace the ConfigMap is created in.

    Yields:
        ConfigMap: The created resource; removed when the context manager exits.
    """
    with ConfigMap(
        client=admin_client,
        namespace=model_namespace.name,
        name="rag-llama-stack-config-map",
        data={
            # Full llama-stack run configuration: remote vLLM + local
            # sentence-transformers inference, inline milvus vector store,
            # local files provider, trustyai safety/eval, telemetry, tool
            # runtimes, model registrations and the server port (8321).
            "run.yaml": """version: 2
image_name: rh
apis:
- agents
- datasetio
- eval
- inference
- safety
- files
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: vllm-inference
    provider_type: remote::vllm
    config:
      url: ${env.VLLM_URL:=http://localhost:8000/v1}
      max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
      tls_verify: ${env.VLLM_TLS_VERIFY:=true}
  - provider_id: sentence-transformers
    provider_type: inline::sentence-transformers
    config: {}
  vector_io:
  - provider_id: milvus
    provider_type: inline::milvus
    config:
      db_path: /opt/app-root/src/.llama/distributions/rh/milvus.db
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/milvus_registry.db
  files:
  - provider_id: meta-reference-files
    provider_type: inline::localfs
    config:
      storage_dir: /opt/app-root/src/.llama/distributions/rh/files
      metadata_store:
        type: sqlite
        db_path: /opt/app-root/src/.llama/distributions/rh/files_metadata.db
  safety:
  - provider_id: trustyai_fms
    provider_type: remote::trustyai_fms
    config:
      orchestrator_url: ${env.FMS_ORCHESTRATOR_URL:=}
      ssl_cert_path: ${env.FMS_SSL_CERT_PATH:=}
      shields: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/agents_store.db
      responses_store:
        type: sqlite
        db_path: /opt/app-root/src/.llama/distributions/rh/responses_store.db
  eval:
  - provider_id: trustyai_lmeval
    provider_type: remote::trustyai_lmeval
    config:
      use_k8s: True
      base_url: ${env.VLLM_URL:=http://localhost:8000/v1}
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: /opt/app-root/src/.llama/distributions/rh/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:=}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:=\u200b}"
      sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
      sqlite_db_path: /opt/app-root/src/.llama/distributions/rh/trace_store.db
      otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:=}
      max_results: 3
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: /opt/app-root/src/.llama/distributions/rh/registry.db
inference_store:
  type: sqlite
  db_path: /opt/app-root/src/.llama/distributions/rh/inference_store.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: vllm-inference
  model_type: llm
- metadata:
    embedding_dimension: 768
  model_id: granite-embedding-125m
  provider_id: sentence-transformers
  provider_model_id: ibm-granite/granite-embedding-125m-english
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
server:
  port: 8321
external_providers_dir: /opt/app-root/src/.llama/providers.d"""
        },
    ) as config_map:
        yield config_map
260+
261+
94262
@pytest.fixture(scope="class")
95263
def llama_stack_distribution(
96264
admin_client: DynamicClient,
97265
model_namespace: Namespace,
98266
enabled_llama_stack_operator: DataScienceCluster,
99267
llama_stack_server_config: Dict[str, Any],
268+
llama_stack_config_map: ConfigMap,
100269
) -> Generator[LlamaStackDistribution, None, None]:
101270
with create_llama_stack_distribution(
102271
client=admin_client,
@@ -157,3 +326,135 @@ def llama_stack_client(
157326
except Exception as e:
158327
LOGGER.error(f"Failed to set up port forwarding: {e}")
159328
raise
329+
330+
331+
@pytest.fixture(scope="class")
def llama_stack_models(llama_stack_client: LlamaStackClient) -> ModelInfo:
    """
    Collect model information from the running LlamaStack deployment.

    Provides:
        - model_id: identifier of the first registered LLM model
        - embedding_model: the first registered embedding model object
        - embedding_dimension: that model's embedding dimension

    Args:
        llama_stack_client: The configured LlamaStackClient.

    Returns:
        ModelInfo: NamedTuple bundling the values above.
    """
    available_models = llama_stack_client.models.list()

    def _first_of(model_type: str) -> Any:
        # First registered model of the requested API type;
        # raises StopIteration if none is registered.
        return next(m for m in available_models if m.api_model_type == model_type)

    llm = _first_of("llm")
    embedding = _first_of("embedding")

    return ModelInfo(
        model_id=llm.identifier,
        embedding_model=embedding,
        embedding_dimension=embedding.metadata["embedding_dimension"],
    )
354+
355+
356+
@pytest.fixture(scope="class")
def vector_store(llama_stack_client: LlamaStackClient, llama_stack_models: ModelInfo) -> Generator[Any, None, None]:
    """
    Provide a fresh vector store for a test class and clean it up afterwards.

    The store is created before the first test that requests it and deleted
    once the class finishes, regardless of test outcome; a failed deletion is
    logged as a warning rather than failing the teardown.

    Args:
        llama_stack_client: The configured LlamaStackClient.
        llama_stack_models: Model information with embedding model details.

    Yields:
        The created vector store object.
    """
    store = llama_stack_client.vector_stores.create(
        name="test_vector_store",
        embedding_model=llama_stack_models.embedding_model.identifier,
        embedding_dimension=llama_stack_models.embedding_dimension,
    )

    yield store

    # Best-effort cleanup: never let teardown mask the test result.
    try:
        llama_stack_client.vector_stores.delete(id=store.id)
        LOGGER.info(f"Deleted vector store {store.id}")
    except Exception as e:
        LOGGER.warning(f"Failed to delete vector store {store.id}: {e}")
385+
386+
387+
@retry(wait_timeout=Timeout.TIMEOUT_1MIN, sleep=20)
def _download_and_upload_file(url: str, llama_stack_client: LlamaStackClient, vector_store: Any) -> bool:
    """
    Download a file from *url*, upload it to LlamaStack and attach it to *vector_store*.

    Retried (every 20s, up to 1 minute) because both the download and the
    upload go over the network and may fail transiently.

    Args:
        url: The URL to download the file from.
        llama_stack_client: The configured LlamaStackClient.
        vector_store: The vector store to attach the uploaded file to.

    Returns:
        bool: True on success; any failure is logged and re-raised so the
        retry decorator can trigger another attempt.
    """
    import requests

    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()

        # Upload the .rst content under a .txt name — the files API appears to
        # reject the .rst extension here. NOTE(review): newer llama-stack
        # versions accept .rst directly; confirm whether this rename is still
        # needed when the distribution image is upgraded.
        file_name = url.split("/")[-1]
        local_file_name = file_name.replace(".rst", ".txt")

        with tempfile.NamedTemporaryFile(mode="wb", suffix=f"_{local_file_name}") as temp_file:
            temp_file.write(response.content)
            # Flush buffered bytes so the re-open below sees the full payload.
            # The re-open by name must happen while the NamedTemporaryFile is
            # still open: with the default delete=True the file is removed as
            # soon as the context exits.
            temp_file.flush()

            # Upload saved file to LlamaStack.
            with open(temp_file.name, "rb") as file_to_upload:
                uploaded_file = llama_stack_client.files.create(file=file_to_upload, purpose="assistants")

        # Index the uploaded file into the vector store.
        llama_stack_client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded_file.id)

        return True

    # A single `except Exception` suffices: requests.RequestException is an
    # Exception subclass, so the original two-element tuple was redundant.
    except Exception as e:
        LOGGER.warning(f"Failed to download and upload file {url}: {e}")
        raise
427+
428+
429+
@pytest.fixture(scope="class")
def vector_store_with_docs(llama_stack_client: LlamaStackClient, vector_store: Any) -> Generator[Any, None, None]:
    """
    Provide the class-scoped vector store pre-populated with TorchTune docs.

    Builds on the ``vector_store`` fixture: each TorchTune tutorial file is
    downloaded and uploaded into the store before the store is handed to the
    test. Cleanup is inherited from the underlying ``vector_store`` fixture.

    Args:
        llama_stack_client: The configured LlamaStackClient.
        vector_store: The vector store fixture to upload files to.

    Yields:
        The vector store object with the uploaded TorchTune documentation files.
    """
    base_url = "https://raw.githubusercontent.com/pytorch/torchtune/refs/tags/v0.6.1/docs/source/tutorials/"
    doc_files = (
        "llama3.rst",
        "chat.rst",
        "lora_finetune.rst",
        "qat_finetune.rst",
        "memory_optimizations.rst",
    )

    # Fetch and index every tutorial before yielding the populated store.
    for doc_file in doc_files:
        _download_and_upload_file(
            url=f"{base_url}{doc_file}",
            llama_stack_client=llama_stack_client,
            vector_store=vector_store,
        )

    yield vector_store

0 commit comments

Comments
 (0)