Skip to content

Commit f0a04a1

Browse files
committed
test: add new tests for embedding models
Signed-off-by: Debarati Basu-Nag <dbasunag@redhat.com>
1 parent abd0153 commit f0a04a1

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
from tests.model_registry.model_catalog.utils import get_models_from_catalog_api
6+
7+
8+
@pytest.fixture(scope="class")
9+
def embedding_models_response(
10+
model_catalog_rest_url: list[str],
11+
model_registry_rest_headers: dict[str, str],
12+
) -> dict[str, Any]:
13+
"""Fetch models filtered by tasks='text-embedding' via filterQuery"""
14+
return get_models_from_catalog_api(
15+
model_catalog_rest_url=model_catalog_rest_url,
16+
model_registry_rest_headers=model_registry_rest_headers,
17+
additional_params="&filterQuery=tasks='text-embedding'",
18+
)

tests/model_registry/model_catalog/search/test_model_search.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from simple_logger.logger import get_logger
88

99
from tests.model_registry.model_catalog.constants import (
10+
OTHER_MODELS_CATALOG_ID,
1011
REDHAT_AI_CATALOG_ID,
1112
REDHAT_AI_CATALOG_NAME,
1213
REDHAT_AI_VALIDATED_UNESCAPED_CATALOG_NAME,
@@ -438,3 +439,48 @@ def test_filter_query_advanced_model_search(
438439
LOGGER.info(
439440
f"Advanced {logic_type.upper()} filter validation completed for {len(models_from_filter_query)} models"
440441
)
442+
443+
444+
@pytest.mark.install
445+
@pytest.mark.post_upgrade
446+
class TestEmbeddingModelSearch:
447+
def test_filter_query_by_text_embedding_task(
448+
self: Self,
449+
embedding_models_response: dict[str, Any],
450+
):
451+
"""
452+
Validate filterQuery with tasks='text-embedding' returns models
453+
"""
454+
number_of_models = embedding_models_response.get("size", 0)
455+
LOGGER.info(f"Found number of embedding models: {number_of_models}")
456+
assert number_of_models > 0, "Expected at least one model with tasks='text-embedding'"
457+
458+
def test_embedding_models_source_id(
459+
self: Self,
460+
embedding_models_response: dict[str, Any],
461+
):
462+
"""
463+
Validate all embedding models belong to the Other Models source
464+
"""
465+
mismatched_models = [
466+
f"{model['name']} (source_id={model['source_id']})"
467+
for model in embedding_models_response.get("items", [])
468+
if model["source_id"] != OTHER_MODELS_CATALOG_ID
469+
]
470+
assert not mismatched_models, (
471+
f"Models with unexpected source_id (expected '{OTHER_MODELS_CATALOG_ID}'): {mismatched_models}"
472+
)
473+
474+
def test_embedding_models_have_text_embedding_task(
475+
self: Self,
476+
embedding_models_response: dict[str, Any],
477+
):
478+
"""
479+
Validate all returned models have 'text-embedding' in their tasks
480+
"""
481+
models_missing_task = [
482+
model["name"]
483+
for model in embedding_models_response.get("items", [])
484+
if "text-embedding" not in model.get("tasks", [])
485+
]
486+
assert not models_missing_task, f"Models missing 'text-embedding' task: {models_missing_task}"

0 commit comments

Comments
 (0)