Skip to content

Commit 031315f

Browse files
authored
Updates for llama stack 0.3.0 (#763)
* feat: Skip llama-stack tests if OpenShift < 4.17
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: ensure ClusterVersion exists
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: use status.history[0].version when obtaining OpenShift version
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: Add bounds checking for history array access
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: use the openai embeddings api instead of deprecated api
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: update agents tests for compatibility with llama-stack 0.3.0
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: delete deprecated vector_io tests
  the vector_db is no longer available in llama-stack 0.3.0
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: update dependency to llama-stack-client to 0.3.x
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: fix pyproject
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
* fix: delete unused function
  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
---------
Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
1 parent 896ba74 commit 031315f

File tree

7 files changed

+113
-236
lines changed

7 files changed

+113
-236
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ dependencies = [
6969
"marshmallow==3.26.1,<4", # this version is needed for pytest-jira
7070
"pytest-html>=4.1.1",
7171
"fire",
72-
"llama_stack_client==0.2.23",
72+
"llama_stack_client>=0.3.0,<0.4",
7373
"pytest-xdist==3.8.0",
7474
"dictdiffer>=0.9.0",
7575
]
Lines changed: 40 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import uuid
22
import pytest
3-
from llama_stack_client import Agent, LlamaStackClient, RAGDocument
3+
from llama_stack_client import Agent, LlamaStackClient
4+
from llama_stack_client.types.vector_store import VectorStore
45
from simple_logger.logger import get_logger
56
from tests.llama_stack.constants import ModelInfo
67
from tests.llama_stack.utils import get_torchtune_test_expectations, validate_rag_agent_responses
@@ -56,7 +57,7 @@ def test_agents_simple_agent(
5657
session_id=s_id,
5758
stream=False,
5859
)
59-
content = response.output_message.content
60+
content = response.output_text
6061
text = str(content or "")
6162
assert text, "LLM response content is empty"
6263
assert "model" in text.lower(), "The LLM didn't provide the expected answer to the prompt"
@@ -67,7 +68,7 @@ def test_agents_simple_agent(
6768
session_id=s_id,
6869
stream=False,
6970
)
70-
content = response.output_message.content
71+
content = response.output_text
7172
text = str(content or "")
7273
assert text, "LLM response content is empty"
7374
assert "answer" in text.lower(), "The LLM didn't provide the expected answer to the prompt"
@@ -77,6 +78,7 @@ def test_agents_rag_agent(
7778
self,
7879
unprivileged_llama_stack_client: LlamaStackClient,
7980
llama_stack_models: ModelInfo,
81+
vector_store_with_example_docs: VectorStore,
8082
) -> None:
8183
"""
8284
Test RAG agent that can answer questions about the Torchtune project using the documents
@@ -92,83 +94,40 @@ def test_agents_rag_agent(
9294
# NOTE: this example uses the vector_store API through the file_search tool
9395
"""
9496

95-
vector_db_id: str | None = None
96-
try:
97-
vector_db = f"my-test-vector_db-{uuid.uuid4().hex}"
98-
res = unprivileged_llama_stack_client.vector_dbs.register(
99-
vector_db_id=vector_db,
100-
embedding_model=llama_stack_models.embedding_model.identifier,
101-
embedding_dimension=llama_stack_models.embedding_dimension,
102-
provider_id="milvus",
103-
)
104-
vector_db_id = res.identifier
105-
106-
# Create the RAG agent connected to the vector database
107-
rag_agent = Agent(
108-
client=unprivileged_llama_stack_client,
109-
model=llama_stack_models.model_id,
110-
instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
111-
tools=[
112-
{
113-
"name": "builtin::rag/knowledge_search",
114-
"args": {"vector_db_ids": [vector_db_id]},
115-
}
116-
],
117-
)
118-
session_id = rag_agent.create_session(session_name=f"s{uuid.uuid4().hex}")
119-
120-
# Insert into the vector database example documents about torchtune
121-
urls = [
122-
"llama3.rst",
123-
"chat.rst",
124-
"lora_finetune.rst",
125-
"qat_finetune.rst",
126-
"memory_optimizations.rst",
127-
]
128-
documents = [
129-
RAGDocument(
130-
document_id=f"num-{index}",
131-
content=f"https://raw.githubusercontent.com/pytorch/torchtune/refs/tags/v0.6.1/docs/source/tutorials/{url}", # noqa
132-
mime_type="text/plain",
133-
metadata={},
134-
)
135-
for index, url in enumerate(urls)
136-
]
137-
138-
unprivileged_llama_stack_client.tool_runtime.rag_tool.insert(
139-
documents=documents,
140-
vector_db_id=vector_db_id,
141-
chunk_size_in_tokens=512,
142-
)
97+
# Create the RAG agent connected to the vector database
98+
rag_agent = Agent(
99+
client=unprivileged_llama_stack_client,
100+
model=llama_stack_models.model_id,
101+
instructions="You are a helpful assistant. Use the available tools to answer questions as needed.",
102+
tools=[
103+
{
104+
"type": "file_search",
105+
"vector_store_ids": [vector_store_with_example_docs.id],
106+
}
107+
],
108+
)
109+
session_id = rag_agent.create_session(session_name=f"s{uuid.uuid4().hex}")
110+
111+
turns_with_expectations = get_torchtune_test_expectations()
112+
113+
# Ask the agent about the inserted documents and validate responses
114+
validation_result = validate_rag_agent_responses(
115+
rag_agent=rag_agent,
116+
session_id=session_id,
117+
turns_with_expectations=turns_with_expectations,
118+
stream=True,
119+
verbose=True,
120+
min_keywords_required=1,
121+
print_events=False,
122+
)
143123

144-
turns_with_expectations = get_torchtune_test_expectations()
145-
146-
# Ask the agent about the inserted documents and validate responses
147-
validation_result = validate_rag_agent_responses(
148-
rag_agent=rag_agent,
149-
session_id=session_id,
150-
turns_with_expectations=turns_with_expectations,
151-
stream=True,
152-
verbose=True,
153-
min_keywords_required=1,
154-
print_events=False,
155-
)
124+
# Assert that validation was successful
125+
assert validation_result["success"], f"RAG agent validation failed. Summary: {validation_result['summary']}"
156126

157-
# Assert that validation was successful
158-
assert validation_result["success"], f"RAG agent validation failed. Summary: {validation_result['summary']}"
159-
160-
# Additional assertions for specific requirements
161-
for result in validation_result["results"]:
162-
assert result["event_count"] > 0, f"No events generated for question: {result['question']}"
163-
assert result["response_length"] > 0, f"No response content for question: {result['question']}"
164-
assert len(result["found_keywords"]) > 0, (
165-
f"No expected keywords found in response for: {result['question']}"
166-
)
167-
168-
finally:
169-
# Cleanup: unregister the vector database to prevent resource leaks
170-
if vector_db_id:
171-
try:
172-
unprivileged_llama_stack_client.vector_dbs.unregister(vector_db_id)
173-
except Exception as exc:
174-
LOGGER.warning("Failed to unregister vector database %s: %s", vector_db_id, exc)
127+
# Additional assertions for specific requirements
128+
for result in validation_result["results"]:
129+
assert result["event_count"] > 0, f"No events generated for question: {result['question']}"
130+
assert result["response_length"] > 0, f"No response content for question: {result['question']}"
131+
assert len(result["found_keywords"]) > 0, (
132+
f"No expected keywords found in response for: {result['question']}"
133+
)

tests/llama_stack/conftest.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from tests.llama_stack.utils import (
1717
create_llama_stack_distribution,
1818
wait_for_llama_stack_client_ready,
19+
vector_store_create_file_from_url,
1920
)
2021
from utilities.constants import DscComponents, Timeout
2122
from utilities.data_science_cluster_utils import update_components_in_dsc
@@ -378,9 +379,11 @@ def vector_store(
378379

379380
vector_store = unprivileged_llama_stack_client.vector_stores.create(
380381
name="test_vector_store",
381-
embedding_model=llama_stack_models.embedding_model.identifier,
382-
embedding_dimension=llama_stack_models.embedding_dimension,
383-
provider_id=vector_io_provider,
382+
extra_body={
383+
"embedding_model": llama_stack_models.embedding_model.identifier,
384+
"embedding_dimension": llama_stack_models.embedding_dimension,
385+
"provider_id": vector_io_provider,
386+
},
384387
)
385388
LOGGER.info(f"vector_store successfully created (provider_id={vector_io_provider}, id={vector_store.id})")
386389

@@ -391,3 +394,41 @@ def vector_store(
391394
LOGGER.info(f"Deleted vector store {vector_store.id}")
392395
except Exception as e:
393396
LOGGER.warning(f"Failed to delete vector store {vector_store.id}: {e}")
397+
398+
399+
@pytest.fixture(scope="class")
400+
def vector_store_with_example_docs(
401+
unprivileged_llama_stack_client: LlamaStackClient, vector_store: VectorStore
402+
) -> Generator[VectorStore, None, None]:
403+
"""
404+
Creates a vector store with TorchTune documentation files uploaded.
405+
406+
This fixture depends on the vector_store fixture and uploads the TorchTune
407+
documentation files to the vector store for testing purposes. The files
408+
are not removed by this fixture itself; cleanup relies on the vector_store fixture deleting the store at teardown.
409+
410+
Args:
411+
unprivileged_llama_stack_client: The configured LlamaStackClient
412+
vector_store: The vector store fixture to upload files to
413+
414+
Yields:
415+
Vector store object with uploaded TorchTune documentation files
416+
"""
417+
# TorchTune documentation files to upload to the vector store
418+
urls = [
419+
"llama3.rst",
420+
"chat.rst",
421+
"lora_finetune.rst",
422+
"qat_finetune.rst",
423+
"memory_optimizations.rst",
424+
]
425+
426+
base_url = "https://raw.githubusercontent.com/pytorch/torchtune/refs/tags/v0.6.1/docs/source/tutorials/"
427+
428+
for file_name in urls:
429+
url = f"{base_url}{file_name}"
430+
vector_store_create_file_from_url(
431+
url=url, llama_stack_client=unprivileged_llama_stack_client, vector_store=vector_store
432+
)
433+
434+
yield vector_store

tests/llama_stack/inference/test_inference.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22
from llama_stack_client import LlamaStackClient
3-
from llama_stack_client.types import EmbeddingsResponse
3+
from llama_stack_client.types import CreateEmbeddingsResponse
44
from tests.llama_stack.constants import ModelInfo
55

66

@@ -73,12 +73,27 @@ def test_inference_embeddings(
7373
Validates that the server can generate properly formatted embedding vectors
7474
for text input with correct dimensions as specified in model metadata.
7575
"""
76-
embeddings_response = unprivileged_llama_stack_client.inference.embeddings(
77-
model_id=llama_stack_models.embedding_model.identifier,
78-
contents=["First chunk of text"],
79-
output_dimension=llama_stack_models.embedding_dimension,
76+
77+
embeddings_response = unprivileged_llama_stack_client.embeddings.create(
78+
model=llama_stack_models.embedding_model.identifier,
79+
input="The food was delicious and the waiter...",
80+
encoding_format="float",
8081
)
81-
assert isinstance(embeddings_response, EmbeddingsResponse)
82-
assert len(embeddings_response.embeddings) == 1
83-
assert isinstance(embeddings_response.embeddings[0], list)
84-
assert isinstance(embeddings_response.embeddings[0][0], float)
82+
83+
assert isinstance(embeddings_response, CreateEmbeddingsResponse)
84+
assert len(embeddings_response.data) == 1
85+
assert isinstance(embeddings_response.data[0].embedding, list)
86+
assert llama_stack_models.embedding_dimension == len(embeddings_response.data[0].embedding)
87+
assert isinstance(embeddings_response.data[0].embedding[0], float)
88+
89+
input_list = ["Input text 1", "Input text 1", "Input text 1"]
90+
embeddings_response = unprivileged_llama_stack_client.embeddings.create(
91+
model=llama_stack_models.embedding_model.identifier, input=input_list, encoding_format="float"
92+
)
93+
94+
assert isinstance(embeddings_response, CreateEmbeddingsResponse)
95+
assert len(embeddings_response.data) == len(input_list)
96+
for item in range(len(input_list)):
97+
assert isinstance(embeddings_response.data[item].embedding, list)
98+
assert llama_stack_models.embedding_dimension == len(embeddings_response.data[item].embedding)
99+
assert isinstance(embeddings_response.data[item].embedding[0], float)

tests/llama_stack/vector_io/test_vector_io_deprecated.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments

Comments
 (0)