Skip to content

Commit c2ad21f

Browse files
dbasunagssaleem-rh
authored andcommitted
test: add new tests for embedding models (opendatahub-io#1254)
* test: add new tests for embedding models Signed-off-by: Debarati Basu-Nag <dbasunag@redhat.com> * fix: address review comments Signed-off-by: Debarati Basu-Nag <dbasunag@redhat.com> --------- Signed-off-by: Debarati Basu-Nag <dbasunag@redhat.com> Signed-off-by: Shehan Saleem <ssaleem@redhat.com>
1 parent b1c69a4 commit c2ad21f

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
6+
7+
8+
@pytest.fixture(scope="class")
9+
def embedding_models_response(
10+
model_catalog_rest_url: list[str],
11+
model_registry_rest_headers: dict[str, str],
12+
) -> dict[str, Any]:
13+
"""Fetch models filtered by tasks='text-embedding' via filterQuery"""
14+
return get_models_from_catalog_api(
15+
model_catalog_rest_url=model_catalog_rest_url,
16+
model_registry_rest_headers=model_registry_rest_headers,
17+
additional_params="&filterQuery=tasks='text-embedding'",
18+
)

tests/model_registry/model_catalog/search/test_model_search.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from simple_logger.logger import get_logger
88

99
from tests.model_registry.model_catalog.constants import (
10+
OTHER_MODELS_CATALOG_ID,
1011
REDHAT_AI_CATALOG_ID,
1112
REDHAT_AI_CATALOG_NAME,
1213
REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME,
@@ -438,3 +439,51 @@ def test_filter_query_advanced_model_search(
438439
LOGGER.info(
439440
f"Advanced {logic_type.upper()} filter validation completed for {len(models_from_filter_query)} models"
440441
)
442+
443+
444+
@pytest.mark.install
445+
@pytest.mark.post_upgrade
446+
class TestEmbeddingModelSearch:
447+
@pytest.mark.dependency(name="test_filter_query_by_text_embedding_task")
448+
def test_filter_query_by_text_embedding_task(
449+
self: Self,
450+
embedding_models_response: dict[str, Any],
451+
):
452+
"""
453+
Validate filterQuery with tasks='text-embedding' returns models
454+
"""
455+
number_of_models = embedding_models_response.get("size", 0)
456+
LOGGER.info(f"Found number of embedding models: {number_of_models}")
457+
assert number_of_models > 0, "Expected at least one model with tasks='text-embedding'"
458+
459+
@pytest.mark.dependency(depends=["test_filter_query_by_text_embedding_task"])
460+
def test_embedding_models_source_id(
461+
self: Self,
462+
embedding_models_response: dict[str, Any],
463+
):
464+
"""
465+
Validate all embedding models belong to the Other Models source
466+
"""
467+
mismatched_models = [
468+
f"{model['name']} (source_id={model['source_id']})"
469+
for model in embedding_models_response.get("items", [])
470+
if model["source_id"] != OTHER_MODELS_CATALOG_ID
471+
]
472+
assert not mismatched_models, (
473+
f"Models with unexpected source_id (expected '{OTHER_MODELS_CATALOG_ID}'): {mismatched_models}"
474+
)
475+
476+
@pytest.mark.dependency(depends=["test_filter_query_by_text_embedding_task"])
477+
def test_embedding_models_have_text_embedding_task(
478+
self: Self,
479+
embedding_models_response: dict[str, Any],
480+
):
481+
"""
482+
Validate all returned models have 'text-embedding' in their tasks
483+
"""
484+
models_missing_task = [
485+
model["name"]
486+
for model in embedding_models_response.get("items", [])
487+
if "text-embedding" not in model.get("tasks", [])
488+
]
489+
assert not models_missing_task, f"Models missing 'text-embedding' task: {models_missing_task}"

0 commit comments

Comments
 (0)