 
 import pytest
 from llama_stack_client import Agent, LlamaStackClient, RAGDocument
+from llama_stack_client.types import EmbeddingsResponse, QueryChunksResponse
+from llama_stack_client.types.vector_io_insert_params import Chunk
 from ocp_resources.deployment import Deployment
 from simple_logger.logger import get_logger
 
@@ -50,7 +52,7 @@ def test_llama_stack_server(
         assert embedding_dimension is not None, "No embedding_dimension set in embedding model"
 
     @pytest.mark.smoke
-    def test_rag_basic_inference(self, rag_lls_client: LlamaStackClient) -> None:
+    def test_rag_chat_completions(self, rag_lls_client: LlamaStackClient) -> None:
         """
         Test basic chat completion inference through LlamaStack client.
 
@@ -77,6 +79,87 @@ def test_rag_basic_inference(self, rag_lls_client: LlamaStackClient) -> None:
         assert content is not None, "LLM response content is None"
         assert "Paris" in content, "The LLM didn't provide the expected answer to the prompt"
 
+    @pytest.mark.smoke
+    def test_rag_inference_embeddings(self, rag_lls_client: LlamaStackClient) -> None:
+        """
+        Test embedding model functionality and vector generation.
+
+        Validates that the server generates properly formatted embedding vectors
+        for text input, with the dimension specified in the model metadata.
+        """
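+        # Discover the registered embedding model and the output dimension declared in its metadata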
+        models = rag_lls_client.models.list()
+        embedding_model = next(m for m in models if m.api_model_type == "embedding")
+        embedding_dimension = embedding_model.metadata["embedding_dimension"]
+
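+        # Request an embedding for a single text input at the model's declared dimension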
+        embeddings_response = rag_lls_client.inference.embeddings(
+            model_id=embedding_model.identifier,
+            contents=["First chunk of text"],
+            output_dimension=embedding_dimension,  # type: ignore
+        )
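+        # The response should contain exactly one embedding: a list of floats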
+        assert isinstance(embeddings_response, EmbeddingsResponse)
+        assert len(embeddings_response.embeddings) == 1
+        assert isinstance(embeddings_response.embeddings[0], list)
+        assert isinstance(embeddings_response.embeddings[0][0], float)
+
+    @pytest.mark.smoke
+    def test_rag_vector_io_ingestion_retrieval(self, rag_lls_client: LlamaStackClient) -> None:
+        """
+        Validates the basic vector_db API in llama-stack using Milvus.
+
+        Tests registering a Milvus vector database, inserting chunks, and retrieving them.
+
+        Based on the example available at
+        https://llama-stack.readthedocs.io/en/latest/building_applications/rag.html
+        """
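+        # Look up the registered embedding model so the vector db is registered with a matching dimension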
+        models = rag_lls_client.models.list()
+        embedding_model = next(m for m in models if m.api_model_type == "embedding")
+        embedding_dimension = embedding_model.metadata["embedding_dimension"]
+
+        # Create a vector database instance
+        vector_db_id = f"v{uuid.uuid4().hex}"
+
+        try:
+            rag_lls_client.vector_dbs.register(
+                vector_db_id=vector_db_id,
+                embedding_model=embedding_model.identifier,
+                embedding_dimension=embedding_dimension,  # type: ignore
+                provider_id="milvus",
+            )
+
+            # Calculate embeddings
+            embeddings_response = rag_lls_client.inference.embeddings(
+                model_id=embedding_model.identifier,
+                contents=["First chunk of text"],
+                output_dimension=embedding_dimension,  # type: ignore
+            )
+
+            # Insert chunk into the vector db
+            chunks_with_embeddings = [
+                Chunk(
+                    content="First chunk of text",
+                    mime_type="text/plain",
+                    metadata={"document_id": "doc1", "source": "precomputed"},
+                    embedding=embeddings_response.embeddings[0],
+                ),
+            ]
+            rag_lls_client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
+
+            # Query the vector db to find the chunk
+            chunks_response = rag_lls_client.vector_io.query(
+                vector_db_id=vector_db_id, query="What do you know about..."
+            )
+            assert isinstance(chunks_response, QueryChunksResponse)
+            assert len(chunks_response.chunks) > 0
+            assert chunks_response.chunks[0].metadata["document_id"] == "doc1"
+            assert chunks_response.chunks[0].metadata["source"] == "precomputed"
+
+        finally:
+            # Cleanup: unregister the vector database to prevent resource leaks
+            try:
+                rag_lls_client.vector_dbs.unregister(vector_db_id)
+            except Exception as e:
+                LOGGER.warning(f"Failed to unregister vector database {vector_db_id}: {e}")
+
     @pytest.mark.smoke
     def test_rag_simple_agent(self, rag_lls_client: LlamaStackClient) -> None:
         """