Skip to content

Commit aba0370

Browse files
authored
fix: simplify model search using helper function (#743)
* fix: simplify model search using helper function Replace direct API calls with reusable `get_models_from_api` helper function to reduce code duplication in test_model_search.py. Also replace dynamic catalog ID lookups with explicit string constants for better readability. * fix: fix typos and type * fix: revert lock file
1 parent c8db8ac commit aba0370

File tree

3 files changed

+73
-52
lines changed

3 files changed

+73
-52
lines changed

tests/model_registry/model_catalog/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@
3030
"labels": [REDHAT_AI_VALIDATED_CATALOG_NAME],
3131
},
3232
}
33-
REDHAT_AI_CATALOG_ID: str = next(iter(DEFAULT_CATALOGS))
33+
REDHAT_AI_CATALOG_ID: str = "redhat_ai_models"
3434
DEFAULT_CATALOG_FILE: str = DEFAULT_CATALOGS[REDHAT_AI_CATALOG_ID]["properties"]["yamlCatalogPath"]
35-
VALIDATED_CATALOG_ID: str = tuple(DEFAULT_CATALOGS.keys())[1]
35+
VALIDATED_CATALOG_ID: str = "redhat_ai_validated_models"
3636

3737
REDHAT_AI_FILTER: str = "Red+Hat+AI"
3838
REDHAT_AI_VALIDATED_FILTER = "Red+Hat+AI+Validated"

tests/model_registry/model_catalog/test_model_search.py

Lines changed: 39 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
REDHAT_AI_CATALOG_ID,
1010
VALIDATED_CATALOG_ID,
1111
)
12-
from tests.model_registry.utils import (
13-
execute_get_command,
12+
from tests.model_registry.model_catalog.utils import (
13+
get_models_from_api,
1414
)
1515

1616
LOGGER = get_logger(name=__name__)
@@ -22,59 +22,53 @@
2222
class TestSearchModelCatalog:
2323
@pytest.mark.smoke
2424
def test_search_model_catalog_source_label(
25-
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: str
25+
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str]
2626
):
2727
"""
2828
RHOAIENG-33656: Validate search model catalog by source label
2929
"""
3030

31-
result = execute_get_command(
32-
url=f"{model_catalog_rest_url[0]}models?sourceLabel={REDHAT_AI_FILTER}&pageSize=100",
33-
headers=model_registry_rest_headers,
34-
)
35-
redhai_ai_filter_moldels_size = result["size"]
36-
37-
result = execute_get_command(
38-
url=f"{model_catalog_rest_url[0]}models?sourceLabel={REDHAT_AI_VALIDATED_FILTER}&pageSize=100",
39-
headers=model_registry_rest_headers,
40-
)
41-
redhai_ai_validated_filter_models_size = result["size"]
42-
43-
result = execute_get_command(
44-
url=f"{model_catalog_rest_url[0]}models?pageSize=100", headers=model_registry_rest_headers
45-
)
46-
no_filtered_models_size = result["size"]
47-
48-
result = execute_get_command(
49-
url=(
50-
f"{model_catalog_rest_url[0]}models?"
51-
f"sourceLabel={REDHAT_AI_VALIDATED_FILTER},{REDHAT_AI_FILTER}&pageSize=100"
52-
),
53-
headers=model_registry_rest_headers,
54-
)
55-
both_filtered_models_size = result["size"]
31+
redhat_ai_filter_moldels_size = get_models_from_api(
32+
model_catalog_rest_url=model_catalog_rest_url,
33+
model_registry_rest_headers=model_registry_rest_headers,
34+
source_label=REDHAT_AI_FILTER,
35+
)["size"]
36+
redhat_ai_validated_filter_models_size = get_models_from_api(
37+
model_catalog_rest_url=model_catalog_rest_url,
38+
model_registry_rest_headers=model_registry_rest_headers,
39+
source_label=REDHAT_AI_VALIDATED_FILTER,
40+
)["size"]
41+
no_filtered_models_size = get_models_from_api(
42+
model_catalog_rest_url=model_catalog_rest_url, model_registry_rest_headers=model_registry_rest_headers
43+
)["size"]
44+
both_filtered_models_size = get_models_from_api(
45+
model_catalog_rest_url=model_catalog_rest_url,
46+
model_registry_rest_headers=model_registry_rest_headers,
47+
source_label=f"{REDHAT_AI_VALIDATED_FILTER},{REDHAT_AI_FILTER}",
48+
)["size"]
5649

5750
assert no_filtered_models_size == both_filtered_models_size
58-
assert redhai_ai_filter_moldels_size + redhai_ai_validated_filter_models_size == both_filtered_models_size
51+
assert redhat_ai_filter_moldels_size + redhat_ai_validated_filter_models_size == both_filtered_models_size
5952

6053
def test_search_model_catalog_invalid_source_label(
61-
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: str
54+
self: Self, model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str]
6255
):
6356
"""
6457
RHOAIENG-33656:
6558
Validate search model catalog by invalid source label
6659
"""
6760

68-
result = execute_get_command(
69-
url=f"{model_catalog_rest_url[0]}/models?sourceLabel=null&pageSize=100", headers=model_registry_rest_headers
70-
)
71-
null_size = result["size"]
61+
null_size = get_models_from_api(
62+
model_catalog_rest_url=model_catalog_rest_url,
63+
model_registry_rest_headers=model_registry_rest_headers,
64+
source_label="null",
65+
)["size"]
7266

73-
result = execute_get_command(
74-
url=f"{model_catalog_rest_url[0]}/models?sourceLabel=invalid&pageSize=100",
75-
headers=model_registry_rest_headers,
76-
)
77-
invalid_size = result["size"]
67+
invalid_size = get_models_from_api(
68+
model_catalog_rest_url=model_catalog_rest_url,
69+
model_registry_rest_headers=model_registry_rest_headers,
70+
source_label="invalid",
71+
)["size"]
7872

7973
assert null_size == invalid_size == 0, (
8074
"Expected 0 models for null and invalid source label found {null_size} and {invalid_size}"
@@ -97,7 +91,7 @@ def test_search_model_catalog_invalid_source_label(
9791
def test_search_model_catalog_match(
9892
self: Self,
9993
model_catalog_rest_url: list[str],
100-
model_registry_rest_headers: str,
94+
model_registry_rest_headers: dict[str, str],
10195
randomly_picked_model: dict[Any, Any],
10296
source_filter: str,
10397
):
@@ -107,13 +101,11 @@ def test_search_model_catalog_match(
107101
random_model = randomly_picked_model
108102
random_model_name = random_model["name"]
109103
LOGGER.info(f"random_model_name: {random_model_name}")
110-
result = execute_get_command(
111-
url=(
112-
f"{model_catalog_rest_url[0]}/models?"
113-
f"sourceLabel={source_filter}&"
114-
f"filterQuery=name='{random_model_name}'&pageSize=100"
115-
),
116-
headers=model_registry_rest_headers,
104+
result = get_models_from_api(
105+
model_catalog_rest_url=model_catalog_rest_url,
106+
model_registry_rest_headers=model_registry_rest_headers,
107+
source_label=source_filter,
108+
additional_params=f"&filterQuery=name='{random_model_name}'",
117109
)
118110
assert random_model_name == result["items"][0]["name"]
119111
assert result["size"] == 1

tests/model_registry/model_catalog/utils.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77

88
from ocp_resources.pod import Pod
99
from ocp_resources.config_map import ConfigMap
10-
from tests.model_registry.model_catalog.constants import (
11-
DEFAULT_CATALOGS,
12-
)
10+
from tests.model_registry.model_catalog.constants import DEFAULT_CATALOGS
11+
from tests.model_registry.utils import execute_get_command
1312

1413
LOGGER = get_logger(name=__name__)
1514

@@ -243,3 +242,33 @@ def validate_model_catalog_configmap_data(configmap: ConfigMap, num_catalogs: in
243242
assert len(catalogs) == num_catalogs, f"{configmap.name} should have {num_catalogs} catalog"
244243
if num_catalogs:
245244
validate_default_catalog(catalogs=catalogs)
245+
246+
247+
def get_models_from_api(
248+
model_catalog_rest_url: list[str],
249+
model_registry_rest_headers: dict[str, str],
250+
page_size: int = 100,
251+
source_label: str | None = None,
252+
additional_params: str = "",
253+
) -> dict[str, Any]:
254+
"""
255+
Helper method to get models from API with optional filtering
256+
257+
Args:
258+
model_catalog_rest_url: REST URL for model catalog
259+
model_registry_rest_headers: Headers for model registry REST API
260+
page_size: Number of results per page
261+
source_label: Source label(s) to filter by (must be comma-separated for multiple filters)
262+
additional_params: Additional query parameters (e.g., "&filterQuery=name='model_name'")
263+
264+
Returns:
265+
Dictionary containing the API response
266+
"""
267+
url = f"{model_catalog_rest_url[0]}models?pageSize={page_size}"
268+
269+
if source_label:
270+
url += f"&sourceLabel={source_label}"
271+
272+
url += additional_params
273+
274+
return execute_get_command(url=url, headers=model_registry_rest_headers)

0 commit comments

Comments
 (0)