Skip to content

Commit d0d41ec

Browse files
authored
feat: add support for llama-stack remote embeddings. Fix model_explainability tests (#973)
* feat: add support for remote and local embeddings in llama-stack

  - Added new test suite for LlamaStack Inference API covering chat completions and text completions.
  - Implemented tests for embedding functionality, validating response structure and dimensions for both single and multiple inputs.
  - Parameterized tests to support different embedding providers (vllm-embedding and sentence-transformers).
  - Improved environment variable management for embedding models in the configuration.

  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>

* fix: fix llama-stack model explainability tests for recent changes

  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>

* feat: document new required env vars for llama-stack

  Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>

---------

Signed-off-by: Jorge Garcia Oncins <jgarciao@redhat.com>
1 parent beb40c8 commit d0d41ec

File tree

7 files changed

+250
-68
lines changed

7 files changed

+250
-68
lines changed

tests/llama_stack/README.md

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,43 @@ To add support for testing new LlamaStack API providers (e.g., a new vector_io p
4242

4343
### Required environment variables
4444

45-
LlamaStack tests require setting the following environment variables (for example in a .env file at the root folder):
45+
LlamaStack tests require setting the following environment variables (for example in a `.env` file at the root folder).
46+
47+
> **Note:** Most of these environment variables are added as `env_vars` in the LlamaStackDistribution CR, as they are required to configure the Red Hat LlamaStack Distribution's [run.yaml](https://github.com/opendatahub-io/llama-stack-distribution/blob/main/distribution/run.yaml).
4648
```bash
4749
OC_BINARY_PATH=/usr/local/sbin/oc # Optional
4850
LLS_CLIENT_VERIFY_SSL=false # Optional
51+
52+
# Core Inference Configuration
4953
LLS_CORE_VLLM_URL=<LLAMA-3.2-3b-ENDPOINT>/v1 (ends with /v1)
5054
LLS_CORE_INFERENCE_MODEL=<LLAMA-3.2-3b-MODEL_NAME>
5155
LLS_CORE_VLLM_API_TOKEN=<LLAMA-3.2-3b-TOKEN>
56+
LLS_CORE_VLLM_MAX_TOKENS=16384 # Optional
57+
LLS_CORE_VLLM_TLS_VERIFY=true # Optional
58+
59+
# Core Embedding Configuration
60+
LLS_CORE_EMBEDDING_MODEL=nomic-embed-text-v1-5 # Optional
61+
LLS_CORE_EMBEDDING_PROVIDER_MODEL_ID=nomic-embed-text-v1-5 # Optional
62+
LLS_CORE_VLLM_EMBEDDING_URL=<EMBEDDING-ENDPOINT>/v1 # Optional
63+
LLS_CORE_VLLM_EMBEDDING_API_TOKEN=<EMBEDDING-TOKEN> # Optional
64+
LLS_CORE_VLLM_EMBEDDING_MAX_TOKENS=8192 # Optional
65+
LLS_CORE_VLLM_EMBEDDING_TLS_VERIFY=true # Optional
66+
67+
# Vector I/O Configuration
5268
LLS_VECTOR_IO_MILVUS_IMAGE=<CUSTOM-MILVUS-IMAGE> # Optional
5369
LLS_VECTOR_IO_MILVUS_TOKEN=<CUSTOM-MILVUS-TOKEN> # Optional
5470
LLS_VECTOR_IO_ETCD_IMAGE=<CUSTOM-ETCD-IMAGE> # Optional
5571
LLS_VECTOR_IO_PGVECTOR_IMAGE=<CUSTOM-PGVECTOR-IMAGE> # Optional
5672
LLS_VECTOR_IO_PGVECTOR_USER=<CUSTOM-PGVECTOR-USER> # Optional
5773
LLS_VECTOR_IO_PGVECTOR_PASSWORD=<CUSTOM-PGVECTOR-PASSWORD> # Optional
74+
75+
# Red Hat Llama Stack Distribution requires PostgreSQL (replacing SQLite)
76+
LLS_VECTOR_IO_POSTGRES_IMAGE=<CUSTOM-POSTGRES-IMAGE> # Optional
77+
LLS_VECTOR_IO_POSTGRESQL_USER=ps_user # Optional
78+
LLS_VECTOR_IO_POSTGRESQL_PASSWORD=ps_password # Optional
79+
80+
# Files Provider Configuration
81+
LLS_FILES_S3_AUTO_CREATE_BUCKET=true # Optional
5882
```
5983

6084
### Run All Llama Stack Tests

tests/llama_stack/conftest.py

Lines changed: 117 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@
4242
POSTGRESQL_USER = os.getenv("LLS_VECTOR_IO_POSTGRESQL_USER", "ps_user")
4343
POSTGRESQL_PASSWORD = os.getenv("LLS_VECTOR_IO_POSTGRESQL_PASSWORD", "ps_password")
4444

45+
LLS_CORE_INFERENCE_MODEL = os.getenv("LLS_CORE_INFERENCE_MODEL", "")
46+
LLS_CORE_VLLM_URL = os.getenv("LLS_CORE_VLLM_URL", "")
47+
LLS_CORE_VLLM_API_TOKEN = os.getenv("LLS_CORE_VLLM_API_TOKEN", "")
48+
LLS_CORE_VLLM_MAX_TOKENS = os.getenv("LLS_CORE_VLLM_MAX_TOKENS", "16384")
49+
LLS_CORE_VLLM_TLS_VERIFY = os.getenv("LLS_CORE_VLLM_TLS_VERIFY", "true")
50+
51+
LLS_CORE_EMBEDDING_MODEL = os.getenv("LLS_CORE_EMBEDDING_MODEL", "nomic-embed-text-v1-5")
52+
LLS_CORE_EMBEDDING_PROVIDER_MODEL_ID = os.getenv("LLS_CORE_EMBEDDING_PROVIDER_MODEL_ID", "nomic-embed-text-v1-5")
53+
LLS_CORE_VLLM_EMBEDDING_URL = os.getenv(
54+
"LLS_CORE_VLLM_EMBEDDING_URL", "https://nomic-embed-text-v1-5.example.com:443/v1"
55+
)
56+
LLS_CORE_VLLM_EMBEDDING_API_TOKEN = os.getenv("LLS_CORE_VLLM_EMBEDDING_API_TOKEN", "fake")
57+
LLS_CORE_VLLM_EMBEDDING_MAX_TOKENS = os.getenv("LLS_CORE_VLLM_EMBEDDING_MAX_TOKENS", "8192")
58+
LLS_CORE_VLLM_EMBEDDING_TLS_VERIFY = os.getenv("LLS_CORE_VLLM_EMBEDDING_TLS_VERIFY", "true")
59+
4560
distribution_name = generate_random_name(prefix="llama-stack-distribution")
4661

4762

@@ -113,8 +128,6 @@ def enabled_llama_stack_operator(dsc_resource: DataScienceCluster) -> Generator[
113128
@pytest.fixture(scope="class")
114129
def llama_stack_server_config(
115130
request: FixtureRequest,
116-
postgres_deployment: Deployment,
117-
postgres_service: Service,
118131
vector_io_provider_deployment_config_factory: Callable[[str], list[Dict[str, str]]],
119132
files_provider_config_factory: Callable[[str], list[Dict[str, str]]],
120133
) -> Dict[str, Any]:
@@ -186,25 +199,23 @@ def test_with_remote_milvus(llama_stack_server_config):
186199
if params.get("inference_model"):
187200
inference_model = str(params.get("inference_model"))
188201
else:
189-
inference_model = os.getenv("LLS_CORE_INFERENCE_MODEL", "")
202+
inference_model = LLS_CORE_INFERENCE_MODEL
190203
env_vars.append({"name": "INFERENCE_MODEL", "value": inference_model})
191204

192-
# VLLM_API_TOKEN
193205
if params.get("vllm_api_token"):
194206
vllm_api_token = str(params.get("vllm_api_token"))
195207
else:
196-
vllm_api_token = os.getenv("LLS_CORE_VLLM_API_TOKEN", "")
208+
vllm_api_token = LLS_CORE_VLLM_API_TOKEN
197209
env_vars.append({"name": "VLLM_API_TOKEN", "value": vllm_api_token})
198210

199-
# LLS_CORE_VLLM_URL
200211
if params.get("vllm_url_fixture"):
201212
vllm_url = str(request.getfixturevalue(argname=params.get("vllm_url_fixture")))
202213
else:
203-
vllm_url = os.getenv("LLS_CORE_VLLM_URL", "")
214+
vllm_url = LLS_CORE_VLLM_URL
204215
env_vars.append({"name": "VLLM_URL", "value": vllm_url})
205216

206-
# VLLM_TLS_VERIFY
207-
env_vars.append({"name": "VLLM_TLS_VERIFY", "value": "false"})
217+
env_vars.append({"name": "VLLM_TLS_VERIFY", "value": LLS_CORE_VLLM_TLS_VERIFY})
218+
env_vars.append({"name": "VLLM_MAX_TOKENS", "value": LLS_CORE_VLLM_MAX_TOKENS})
208219

209220
# FMS_ORCHESTRATOR_URL
210221
if params.get("fms_orchestrator_url_fixture"):
@@ -214,13 +225,25 @@ def test_with_remote_milvus(llama_stack_server_config):
214225
env_vars.append({"name": "FMS_ORCHESTRATOR_URL", "value": fms_orchestrator_url})
215226

216227
# EMBEDDING_MODEL
217-
embedding_model = params.get("embedding_model")
218-
if embedding_model:
219-
env_vars.append({"name": "EMBEDDING_MODEL", "value": embedding_model})
228+
embedding_provider = params.get("embedding_provider") or "vllm-embedding"
229+
230+
if embedding_provider == "vllm-embedding":
231+
env_vars.append({"name": "EMBEDDING_MODEL", "value": LLS_CORE_EMBEDDING_MODEL})
232+
env_vars.append({"name": "EMBEDDING_PROVIDER_MODEL_ID", "value": LLS_CORE_EMBEDDING_PROVIDER_MODEL_ID})
233+
env_vars.append({"name": "VLLM_EMBEDDING_URL", "value": LLS_CORE_VLLM_EMBEDDING_URL})
234+
env_vars.append({"name": "VLLM_EMBEDDING_API_TOKEN", "value": LLS_CORE_VLLM_EMBEDDING_API_TOKEN})
235+
env_vars.append({"name": "VLLM_EMBEDDING_MAX_TOKENS", "value": LLS_CORE_VLLM_EMBEDDING_MAX_TOKENS})
236+
env_vars.append({"name": "VLLM_EMBEDDING_TLS_VERIFY", "value": LLS_CORE_VLLM_EMBEDDING_TLS_VERIFY})
237+
elif embedding_provider == "sentence-transformers":
238+
env_vars.append({"name": "ENABLE_SENTENCE_TRANSFORMERS", "value": "true"})
239+
env_vars.append({"name": "EMBEDDING_PROVIDER", "value": "sentence-transformers"})
240+
else:
241+
raise ValueError(f"Unsupported embeddings provider: {embedding_provider}")
220242

221-
# Use inline::sentence-transformers embeddings provider
222-
env_vars.append({"name": "ENABLE_SENTENCE_TRANSFORMERS", "value": "true"})
223-
env_vars.append({"name": "EMBEDDING_PROVIDER", "value": "sentence-transformers"})
243+
# TRUSTYAI_EMBEDDING_MODEL
244+
trustyai_embedding_model = params.get("trustyai_embedding_model")
245+
if trustyai_embedding_model:
246+
env_vars.append({"name": "TRUSTYAI_EMBEDDING_MODEL", "value": trustyai_embedding_model})
224247

225248
# Kubeflow-related environment variables
226249
if params.get("enable_ragas_remote"):
@@ -314,6 +337,8 @@ def unprivileged_llama_stack_distribution(
314337
ci_s3_bucket_region: str,
315338
aws_access_key_id: str,
316339
aws_secret_access_key: str,
340+
unprivileged_postgres_deployment: Deployment,
341+
unprivileged_postgres_service: Service,
317342
) -> Generator[LlamaStackDistribution, None, None]:
318343
# Distribution name needs a random substring due to bug RHAIENG-999 / RHAIENG-1139
319344
distribution_name = generate_random_name(prefix="llama-stack-distribution")
@@ -359,6 +384,8 @@ def llama_stack_distribution(
359384
ci_s3_bucket_region: str,
360385
aws_access_key_id: str,
361386
aws_secret_access_key: str,
387+
postgres_deployment: Deployment,
388+
postgres_service: Service,
362389
) -> Generator[LlamaStackDistribution, None, None]:
363390
# Distribution name needs a random substring due to bug RHAIENG-999 / RHAIENG-1139
364391
with create_llama_stack_distribution(
@@ -604,22 +631,45 @@ def llama_stack_models(unprivileged_llama_stack_client: LlamaStackClient) -> Mod
604631
"""
605632
Returns model information from the LlamaStack client.
606633
634+
Selects the embedding model based on available providers with the following priority:
635+
1. sentence-transformers provider (if present)
636+
2. vllm-embedding provider (if present)
637+
607638
Provides:
608639
- model_id: The identifier of the LLM model
609-
- embedding_model: The embedding model object
640+
- embedding_model: The embedding model object from the selected provider
610641
- embedding_dimension: The dimension of the embedding model
611642
612643
Args:
613644
unprivileged_llama_stack_client: The configured LlamaStackClient
614645
615646
Returns:
616647
ModelInfo: NamedTuple containing model information
648+
649+
Raises:
650+
ValueError: If no embedding provider (sentence-transformers or vllm-embedding) is found
651+
617652
"""
618653
models = unprivileged_llama_stack_client.models.list()
654+
619655
model_id = next(m for m in models if m.api_model_type == "llm").identifier
620656

621-
embedding_model = next(m for m in models if m.api_model_type == "embedding")
622-
embedding_dimension = embedding_model.metadata["embedding_dimension"]
657+
# Ensure getting the right embedding model depending on the available providers
658+
providers = unprivileged_llama_stack_client.providers.list()
659+
provider_ids = [p.provider_id for p in providers]
660+
if "sentence-transformers" in provider_ids:
661+
target_provider_id = "sentence-transformers"
662+
elif "vllm-embedding" in provider_ids:
663+
target_provider_id = "vllm-embedding"
664+
else:
665+
raise ValueError("No embedding provider found")
666+
667+
embedding_model = next(m for m in models if m.api_model_type == "embedding" and m.provider_id == target_provider_id)
668+
embedding_dimension = float(embedding_model.metadata["embedding_dimension"])
669+
670+
LOGGER.info(f"Detected model: {model_id}")
671+
LOGGER.info(f"Detected embedding_model: {embedding_model.identifier}")
672+
LOGGER.info(f"Detected embedding_dimension: {embedding_dimension}")
623673

624674
return ModelInfo(model_id=model_id, embedding_model=embedding_model, embedding_dimension=embedding_dimension)
625675

@@ -705,12 +755,12 @@ def vector_store_with_example_docs(
705755

706756

707757
@pytest.fixture(scope="class")
708-
def postgres_service(
758+
def unprivileged_postgres_service(
709759
unprivileged_client: DynamicClient,
710760
unprivileged_model_namespace: Namespace,
711-
postgres_deployment: Deployment,
761+
unprivileged_postgres_deployment: Deployment,
712762
) -> Generator[Service, Any, Any]:
713-
"""Create a service for the postgres deployment."""
763+
"""Create a service for the unprivileged postgres deployment."""
714764
with Service(
715765
client=unprivileged_client,
716766
namespace=unprivileged_model_namespace.name,
@@ -728,11 +778,11 @@ def postgres_service(
728778

729779

730780
@pytest.fixture(scope="class")
731-
def postgres_deployment(
781+
def unprivileged_postgres_deployment(
732782
unprivileged_client: DynamicClient,
733783
unprivileged_model_namespace: Namespace,
734784
) -> Generator[Deployment, Any, Any]:
735-
"""Deploy a Postgres instance for vector I/O provider testing."""
785+
"""Deploy a Postgres instance for vector I/O provider testing with unprivileged client."""
736786
with Deployment(
737787
client=unprivileged_client,
738788
namespace=unprivileged_model_namespace.name,
@@ -748,6 +798,50 @@ def postgres_deployment(
748798
yield deployment
749799

750800

801+
@pytest.fixture(scope="class")
802+
def postgres_service(
803+
admin_client: DynamicClient,
804+
model_namespace: Namespace,
805+
postgres_deployment: Deployment,
806+
) -> Generator[Service, Any, Any]:
807+
"""Create a service for the postgres deployment."""
808+
with Service(
809+
client=admin_client,
810+
namespace=model_namespace.name,
811+
name="vector-io-postgres-service",
812+
ports=[
813+
{
814+
"port": 5432,
815+
"targetPort": 5432,
816+
}
817+
],
818+
selector={"app": "postgres"},
819+
wait_for_resource=True,
820+
) as service:
821+
yield service
822+
823+
824+
@pytest.fixture(scope="class")
825+
def postgres_deployment(
826+
admin_client: DynamicClient,
827+
model_namespace: Namespace,
828+
) -> Generator[Deployment, Any, Any]:
829+
"""Deploy a Postgres instance for vector I/O provider testing."""
830+
with Deployment(
831+
client=admin_client,
832+
namespace=model_namespace.name,
833+
name="vector-io-postgres-deployment",
834+
min_ready_seconds=5,
835+
replicas=1,
836+
selector={"matchLabels": {"app": "postgres"}},
837+
strategy={"type": "Recreate"},
838+
template=get_postgres_deployment_template(),
839+
teardown=True,
840+
) as deployment:
841+
deployment.wait_for_replicas(deployed=True, timeout=240)
842+
yield deployment
843+
844+
751845
def get_postgres_deployment_template() -> Dict[str, Any]:
752846
"""Return a Kubernetes deployment for PostgreSQL"""
753847
return {

tests/llama_stack/eval/test_lmeval_provider.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
{
2020
"vllm_url_fixture": "qwen_isvc_url",
2121
"inference_model": QWEN_MODEL_NAME,
22+
"embedding_provider": "sentence-transformers",
2223
},
2324
)
2425
],
@@ -82,7 +83,11 @@ def test_llamastack_run_eval(
8283
{"name": "test-llamastack-lmeval-custom"},
8384
MinIo.PodConfig.QWEN_HAP_BPIV2_MINIO_CONFIG,
8485
{"bucket": "llms"},
85-
{"vllm_url_fixture": "qwen_isvc_url", "inference_model": QWEN_MODEL_NAME},
86+
{
87+
"vllm_url_fixture": "qwen_isvc_url",
88+
"inference_model": QWEN_MODEL_NAME,
89+
"embedding_provider": "sentence-transformers",
90+
},
8691
)
8792
],
8893
indirect=True,

tests/llama_stack/eval/test_ragas_provider.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
{
3131
"vllm_url_fixture": "qwen_isvc_url",
3232
"inference_model": QWEN_MODEL_NAME,
33-
"embedding_model": "granite-embedding-125m",
33+
"embedding_provider": "sentence-transformers",
34+
"trustyai_embedding_model": "granite-embedding-125m-english",
3435
},
3536
)
3637
],
@@ -105,7 +106,8 @@ def test_ragas_inline_run_eval(self, minio_pod, minio_data_connection, llama_sta
105106
{
106107
"vllm_url_fixture": "qwen_isvc_url",
107108
"inference_model": QWEN_MODEL_NAME,
108-
"embedding_model": "granite-embedding-125m",
109+
"embedding_provider": "sentence-transformers",
110+
"trustyai_embedding_model": "granite-embedding-125m-english",
109111
"enable_ragas_remote": True,
110112
},
111113
)
Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
import pytest
22
from llama_stack_client import LlamaStackClient
3-
from llama_stack_client.types import CreateEmbeddingsResponse
43
from tests.llama_stack.constants import ModelInfo
54

65

76
@pytest.mark.parametrize(
87
"unprivileged_model_namespace",
98
[
109
pytest.param(
11-
{"name": "test-llamastack-inference", "randomize_name": True},
10+
{"name": "test-llamastack-infer-completions", "randomize_name": True},
1211
),
1312
],
1413
indirect=True,
1514
)
1615
@pytest.mark.llama_stack
17-
class TestLlamaStackInference:
18-
"""Test class for LlamaStack Inference API (chat_completion, completion and embeddings)
16+
class TestLlamaStackInferenceCompletions:
17+
"""Test class for LlamaStack Inference API for Chat Completions and Completions
1918
2019
For more information about this API, see:
2120
- https://llamastack.github.io/docs/references/python_sdk_reference#inference
@@ -60,40 +59,3 @@ def test_inference_completion(
6059
content = response.choices[0].text.lower()
6160
assert content is not None, "LLM response content is None"
6261
assert "barcelona" in content, "The LLM didn't provide the expected answer to the prompt"
63-
64-
@pytest.mark.smoke
65-
def test_inference_embeddings(
66-
self,
67-
unprivileged_llama_stack_client: LlamaStackClient,
68-
llama_stack_models: ModelInfo,
69-
) -> None:
70-
"""
71-
Test embedding model functionality and vector generation.
72-
73-
Validates that the server can generate properly formatted embedding vectors
74-
for text input with correct dimensions as specified in model metadata.
75-
"""
76-
77-
embeddings_response = unprivileged_llama_stack_client.embeddings.create(
78-
model=llama_stack_models.embedding_model.identifier,
79-
input="The food was delicious and the waiter...",
80-
encoding_format="float",
81-
)
82-
83-
assert isinstance(embeddings_response, CreateEmbeddingsResponse)
84-
assert len(embeddings_response.data) == 1
85-
assert isinstance(embeddings_response.data[0].embedding, list)
86-
assert llama_stack_models.embedding_dimension == len(embeddings_response.data[0].embedding)
87-
assert isinstance(embeddings_response.data[0].embedding[0], float)
88-
89-
input_list = ["Input text 1", "Input text 1", "Input text 1"]
90-
embeddings_response = unprivileged_llama_stack_client.embeddings.create(
91-
model=llama_stack_models.embedding_model.identifier, input=input_list, encoding_format="float"
92-
)
93-
94-
assert isinstance(embeddings_response, CreateEmbeddingsResponse)
95-
assert len(embeddings_response.data) == len(input_list)
96-
for item in range(len(input_list)):
97-
assert isinstance(embeddings_response.data[item].embedding, list)
98-
assert llama_stack_models.embedding_dimension == len(embeddings_response.data[item].embedding)
99-
assert isinstance(embeddings_response.data[item].embedding[0], float)

0 commit comments

Comments (0)