Skip to content

Commit 76ec8f7

Browse files
authored
Add HF positive tests with wildcard (#1003)
* Add HF positive tests with wildcard * Updates to add checks to catalog pod log * updates based on review comments
1 parent 74dc788 commit 76ec8f7

File tree

3 files changed

+172
-3
lines changed

3 files changed

+172
-3
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,36 @@
11
import pytest
22
from huggingface_hub import HfApi
3+
from simple_logger.logger import get_logger
4+
5+
LOGGER = get_logger(name=__name__)
36

47

58
@pytest.fixture()
69
def huggingface_api():
710
return HfApi()
11+
12+
13+
@pytest.fixture()
14+
def num_models_from_hf_api_with_matching_criteria(request: pytest.FixtureRequest, huggingface_api: HfApi) -> int:
15+
excluded_str = request.param.get("excluded_str")
16+
included_str = request.param.get("included_str")
17+
models = huggingface_api.list_models(author=request.param["org_name"], limit=10000)
18+
model_list = []
19+
for model in models:
20+
if excluded_str:
21+
if model.id.endswith(excluded_str):
22+
LOGGER.info(f"Skipping {model.id} due to {excluded_str}")
23+
continue
24+
else:
25+
LOGGER.info(f"Adding {model.id}")
26+
model_list.append(model.id)
27+
elif included_str:
28+
if model.id.startswith(included_str):
29+
LOGGER.info(f"Adding {model.id}")
30+
model_list.append(model.id)
31+
else:
32+
LOGGER.info(f"Skipping {model.id} due to {included_str}")
33+
continue
34+
else:
35+
model_list.append(model.id)
36+
return len(model_list)

tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
from typing import Self
33
from ocp_resources.config_map import ConfigMap
44
from simple_logger.logger import get_logger
5-
65
from tests.model_registry.model_catalog.constants import HF_MODELS
76
from tests.model_registry.model_catalog.utils import (
87
get_hf_catalog_str,
98
)
109
from tests.model_registry.model_catalog.huggingface.utils import (
1110
assert_huggingface_values_matches_model_catalog_api_values,
11+
wait_for_huggingface_retrival_match,
12+
wait_for_hugging_face_model_import,
1213
)
14+
from kubernetes.dynamic import DynamicClient
1315

1416
LOGGER = get_logger(name=__name__)
1517

@@ -54,3 +56,96 @@ def test_huggingface_model_metadata(
5456
expected_catalog_values=expected_catalog_values,
5557
huggingface_api=huggingface_api,
5658
)
59+
60+
61+
class TestHFPatternMatching:
62+
@pytest.mark.parametrize(
63+
"updated_catalog_config_map_scope_function, num_models_from_hf_api_with_matching_criteria",
64+
[
65+
pytest.param(
66+
"""
67+
catalogs:
68+
- name: HuggingFace Hub
69+
id: hf_id
70+
type: hf
71+
enabled: true
72+
includedModels:
73+
- huggingface-course/*
74+
""",
75+
{"org_name": "huggingface-course", "excluded_str": None},
76+
id="test_hf_source_wildcard",
77+
),
78+
pytest.param(
79+
"""
80+
catalogs:
81+
- name: HuggingFace Hub
82+
id: hf_id
83+
type: hf
84+
enabled: true
85+
properties:
86+
allowedOrganization: "huggingface-course"
87+
includedModels:
88+
- '*'
89+
""",
90+
{"org_name": "huggingface-course", "excluded_str": None},
91+
id="test_hf_source_allowed_org",
92+
),
93+
pytest.param(
94+
"""
95+
catalogs:
96+
- name: HuggingFace Hub
97+
id: hf_id
98+
type: hf
99+
enabled: true
100+
properties:
101+
allowedOrganization: "huggingface-course"
102+
includedModels:
103+
- '*'
104+
excludedModels:
105+
- '*-accelerate'
106+
""",
107+
{"org_name": "huggingface-course", "excluded_str": "-accelerate"},
108+
id="test_hf_source_allowed_org_exclude",
109+
),
110+
pytest.param(
111+
"""
112+
catalogs:
113+
- name: HuggingFace Hub
114+
id: hf_id
115+
type: hf
116+
enabled: true
117+
includedModels:
118+
- 'ibm-granite/granite-4.0-micro*'
119+
""",
120+
{"org_name": "ibm-granite", "included_str": "ibm-granite/granite-4.0-micro"},
121+
id="test_hf_source_allowed_org_include",
122+
),
123+
],
124+
indirect=True,
125+
)
126+
def test_hugging_face_models(
127+
self: Self,
128+
admin_client: DynamicClient,
129+
model_registry_namespace: str,
130+
updated_catalog_config_map_scope_function: ConfigMap,
131+
model_catalog_rest_url: list[str],
132+
model_registry_rest_headers: dict[str, str],
133+
huggingface_api: bool,
134+
num_models_from_hf_api_with_matching_criteria: int,
135+
):
136+
"""
137+
Test that excluded models do not appear in the catalog API response
138+
"""
139+
LOGGER.info("Testing HuggingFace model exclusion functionality")
140+
wait_for_hugging_face_model_import(
141+
admin_client=admin_client,
142+
model_registry_namespace=model_registry_namespace,
143+
hf_id="hf_id",
144+
expected_num_models_from_hf_api=num_models_from_hf_api_with_matching_criteria,
145+
)
146+
wait_for_huggingface_retrival_match(
147+
source_id="hf_id",
148+
model_registry_rest_headers=model_registry_rest_headers,
149+
model_catalog_rest_url=model_catalog_rest_url,
150+
expected_num_models_from_hf_api=num_models_from_hf_api_with_matching_criteria,
151+
)

tests/model_registry/model_catalog/huggingface/utils.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import ast
22
from typing import Any
3+
from simple_logger.logger import get_logger
34

45
from tests.model_registry.model_catalog.constants import HF_SOURCE_ID
5-
from tests.model_registry.model_catalog.utils import LOGGER
6-
from tests.model_registry.utils import execute_get_command
6+
from tests.model_registry.utils import execute_get_command, get_model_catalog_pod
77
from huggingface_hub import HfApi
8+
from timeout_sampler import retry
9+
from kubernetes.dynamic import DynamicClient
10+
11+
LOGGER = get_logger(name=__name__)
812

913

1014
def get_huggingface_model_params(model_name: str, huggingface_api: HfApi) -> dict[str, Any]:
@@ -95,3 +99,44 @@ def assert_huggingface_values_matches_model_catalog_api_values(
9599
if mismatch:
96100
LOGGER.error(f"mismatches are: {mismatch}")
97101
raise AssertionError("HF api call and model catalog hf models has value mismatch")
102+
103+
104+
@retry(wait_timeout=60, sleep=5)
105+
def wait_for_huggingface_retrival_match(
106+
source_id: str,
107+
model_catalog_rest_url: list[str],
108+
model_registry_rest_headers: dict[str, str],
109+
expected_num_models_from_hf_api: int,
110+
) -> bool | None:
111+
# Get all models from the catalog API for the given source
112+
url = f"{model_catalog_rest_url[0]}models?source={source_id}&pageSize=1000"
113+
response = execute_get_command(
114+
url=url,
115+
headers=model_registry_rest_headers,
116+
)
117+
LOGGER.info(f"response: {response['size']}")
118+
models_response = [model["name"] for model in response["items"]]
119+
if int(response["size"]) == expected_num_models_from_hf_api:
120+
LOGGER.info("All models present in the catalog API.")
121+
return True
122+
LOGGER.warning(
123+
f"Expected {expected_num_models_from_hf_api} "
124+
"models to be present in response. "
125+
f"Found {response['size']}. Models in "
126+
f"response: {models_response}"
127+
)
128+
129+
130+
@retry(wait_timeout=60, sleep=5)
131+
def wait_for_hugging_face_model_import(
132+
admin_client: DynamicClient, model_registry_namespace: str, hf_id: str, expected_num_models_from_hf_api: int
133+
) -> bool:
134+
LOGGER.info("Checking pod log for model import information")
135+
pod = get_model_catalog_pod(client=admin_client, model_registry_namespace=model_registry_namespace)[0]
136+
log = pod.log(container="catalog")
137+
if f"{hf_id}: loaded {expected_num_models_from_hf_api} models" in log and f"{hf_id}: cleaned up 0 models" in log:
138+
LOGGER.info(f"Found log entry confirming model(s) imported for id: {hf_id}")
139+
return True
140+
else:
141+
LOGGER.warning(f"No relevant log entry found: {log}")
142+
return False

0 commit comments

Comments
 (0)