Skip to content

Commit 50cd309

Browse files
authored
Add embeddings and vector_io tests (#488)
* change: Add results folder to .gitignore Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com> * test: add embeddings and vector_io tests Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com> --------- Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
1 parent ee2becb commit 50cd309

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ coverage.xml
5050
.hypothesis/
5151
.pytest_cache/
5252
cover/
53+
results/
5354

5455
# Translations
5556
*.mo

tests/rag/test_rag.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import pytest
55
from llama_stack_client import Agent, LlamaStackClient, RAGDocument
6+
from llama_stack_client.types import EmbeddingsResponse, QueryChunksResponse
7+
from llama_stack_client.types.vector_io_insert_params import Chunk
68
from ocp_resources.deployment import Deployment
79
from simple_logger.logger import get_logger
810

@@ -50,7 +52,7 @@ def test_llama_stack_server(
5052
assert embedding_dimension is not None, "No embedding_dimension set in embedding model"
5153

5254
@pytest.mark.smoke
53-
def test_rag_basic_inference(self, rag_lls_client: LlamaStackClient) -> None:
55+
def test_rag_chat_completions(self, rag_lls_client: LlamaStackClient) -> None:
5456
"""
5557
Test basic chat completion inference through LlamaStack client.
5658
@@ -77,6 +79,87 @@ def test_rag_basic_inference(self, rag_lls_client: LlamaStackClient) -> None:
7779
assert content is not None, "LLM response content is None"
7880
assert "Paris" in content, "The LLM didn't provide the expected answer to the prompt"
7981

82+
@pytest.mark.smoke
def test_rag_inference_embeddings(self, rag_lls_client: LlamaStackClient) -> None:
    """
    Test embedding model functionality and vector generation.

    Validates that the server can generate properly formatted embedding
    vectors for text input, with the dimensionality advertised in the
    embedding model's metadata.
    """
    # Pick the first registered model whose API type is "embedding".
    available_models = rag_lls_client.models.list()
    emb_model = next(model for model in available_models if model.api_model_type == "embedding")
    emb_dimension = emb_model.metadata["embedding_dimension"]

    response = rag_lls_client.inference.embeddings(
        model_id=emb_model.identifier,
        contents=["First chunk of text"],
        output_dimension=emb_dimension,  # type: ignore
    )

    # One input string should yield exactly one vector of floats.
    assert isinstance(response, EmbeddingsResponse)
    assert len(response.embeddings) == 1
    assert isinstance(response.embeddings[0], list)
    assert isinstance(response.embeddings[0][0], float)
103+
104+
@pytest.mark.smoke
def test_rag_vector_io_ingestion_retrieval(self, rag_lls_client: LlamaStackClient) -> None:
    """
    Validates basic vector_db API in llama-stack using milvus

    Tests registering, inserting and retrieving information from a milvus vector db database

    Based on the example available at
    https://llama-stack.readthedocs.io/en/latest/building_applications/rag.html
    """
    # Locate the embedding model and its advertised vector dimension.
    emb_model = next(m for m in rag_lls_client.models.list() if m.api_model_type == "embedding")
    emb_dimension = emb_model.metadata["embedding_dimension"]

    # Unique database id per test run; leading "v" keeps it a valid identifier.
    vector_db_id = f"v{uuid.uuid4().hex}"

    try:
        rag_lls_client.vector_dbs.register(
            vector_db_id=vector_db_id,
            embedding_model=emb_model.identifier,
            embedding_dimension=emb_dimension,  # type: ignore
            provider_id="milvus",
        )

        # Pre-compute the embedding for the chunk we are about to store.
        emb_response = rag_lls_client.inference.embeddings(
            model_id=emb_model.identifier,
            contents=["First chunk of text"],
            output_dimension=emb_dimension,  # type: ignore
        )

        # Insert a single chunk that carries its precomputed embedding.
        rag_lls_client.vector_io.insert(
            vector_db_id=vector_db_id,
            chunks=[
                Chunk(
                    content="First chunk of text",
                    mime_type="text/plain",
                    metadata={"document_id": "doc1", "source": "precomputed"},
                    embedding=emb_response.embeddings[0],
                ),
            ],
        )

        # Query the database and confirm the stored chunk is returned
        # together with its metadata.
        query_result = rag_lls_client.vector_io.query(
            vector_db_id=vector_db_id, query="What do you know about..."
        )
        assert isinstance(query_result, QueryChunksResponse)
        assert len(query_result.chunks) > 0
        assert query_result.chunks[0].metadata["document_id"] == "doc1"
        assert query_result.chunks[0].metadata["source"] == "precomputed"

    finally:
        # Cleanup: unregister the vector database to prevent resource leaks
        try:
            rag_lls_client.vector_dbs.unregister(vector_db_id)
        except Exception as e:
            LOGGER.warning(f"Failed to unregister vector database {vector_db_id}: {e}")
162+
80163
@pytest.mark.smoke
81164
def test_rag_simple_agent(self, rag_lls_client: LlamaStackClient) -> None:
82165
"""

0 commit comments

Comments
 (0)