Skip to content

Commit f81e76e

Browse files
authored
Hf model verifications (#950)
* HF tests for data validations * Move test and utilities based on review comments
1 parent bfb5d25 commit f81e76e

File tree

6 files changed

+1097
-871
lines changed

6 files changed

+1097
-871
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ dependencies = [
7474
"pytest-xdist==3.8.0",
7575
"dictdiffer>=0.9.0",
7676
"pytest>=9.0.0",
77+
"huggingface-hub>=1.2.3",
7778
]
7879

7980
[project.urls]

tests/model_registry/model_catalog/huggingface/__init__.py

Whitespace-only changes.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import pytest
2+
from huggingface_hub import HfApi
3+
4+
5+
@pytest.fixture()
def huggingface_api():
    """Provide a default-configured ``HfApi`` client to the requesting test."""
    hub_client = HfApi()
    return hub_client
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Self

import pytest
from huggingface_hub import HfApi
from ocp_resources.config_map import ConfigMap
from simple_logger.logger import get_logger

from tests.model_registry.model_catalog.constants import HF_MODELS
from tests.model_registry.model_catalog.huggingface.utils import (
    assert_huggingface_values_matches_model_catalog_api_values,
)
from tests.model_registry.model_catalog.utils import (
    get_hf_catalog_str,
)
13+
14+
# Module-level logger named after this module.
LOGGER = get_logger(name=__name__)

# Request these fixtures for every test collected from this module.
pytestmark = [
    pytest.mark.usefixtures(
        "updated_dsc_component_state_scope_session",
        "model_registry_namespace",
    ),
]
17+
18+
19+
@pytest.mark.parametrize(
    "updated_catalog_config_map, expected_catalog_values",
    [
        pytest.param(
            {
                "sources_yaml": get_hf_catalog_str(ids=["mixed"]),
            },
            HF_MODELS["mixed"],
            id="validate_hf_fields",
            marks=pytest.mark.install,
        ),
    ],
    # Both argument names are resolved through fixtures of the same name.
    indirect=True,
)
@pytest.mark.usefixtures("updated_catalog_config_map")
class TestHuggingFaceModelValidation:
    """Compare values reported by the HuggingFace Hub API with the values served by the Model Catalog API."""

    def test_huggingface_model_metadata(
        self: Self,
        updated_catalog_config_map: tuple[ConfigMap, str, str],
        model_catalog_rest_url: list[str],
        model_registry_rest_headers: dict[str, str],
        expected_catalog_values: dict[str, str],
        huggingface_api: HfApi,
    ):
        """
        Validate HuggingFace model metadata structure and required fields
        Cross-validate with actual HuggingFace Hub API

        The comparison itself, including the AssertionError raised on any mismatch, is
        delegated to assert_huggingface_values_matches_model_catalog_api_values.
        """
        assert_huggingface_values_matches_model_catalog_api_values(
            model_registry_rest_headers=model_registry_rest_headers,
            model_catalog_rest_url=model_catalog_rest_url,
            expected_catalog_values=expected_catalog_values,
            huggingface_api=huggingface_api,
        )
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import ast
2+
from typing import Any
3+
4+
from tests.model_registry.model_catalog.constants import HF_SOURCE_ID
5+
from tests.model_registry.model_catalog.utils import LOGGER
6+
from tests.model_registry.utils import execute_get_command
7+
from huggingface_hub import HfApi
8+
9+
10+
def get_huggingface_model_params(model_name: str, huggingface_api: HfApi) -> dict[str, Any]:
    """
    Get some of the fields from HuggingFace API for validation against our model catalog data

    Args:
        model_name: Repository id passed to ``HfApi.model_info``.
        huggingface_api: Client used to query the HuggingFace Hub.

    Returns:
        Dict with the keys "tags", "gated", "private", "architectures" and "model_type".
        "tags" is always a list with the "license:" entries removed; "gated" and "private"
        are lower-cased strings unless the lookup returned None; the remaining values are
        returned exactly as resolved (None when the attribute path does not exist).
    """
    hf_model_info = huggingface_api.model_info(repo_id=model_name)
    # result key -> dotted attribute path on the model info object
    fields_mapping = {
        "tags": "tags",
        "gated": "gated",
        "private": "private",
        "architectures": "config.architectures",
        "model_type": "config.model_type",
    }

    result: dict[str, Any] = {}
    for key, path in fields_mapping.items():
        value = get_huggingface_nested_attributes(obj=hf_model_info, attr_path=path)
        if key == "tags":
            # The lookup yields None when the model has no tags; treat that as an empty
            # list instead of failing while filtering. "license:" tags are excluded
            # from the comparison.
            value = [tag for tag in value or [] if not tag.startswith("license:")]
        elif key in ("gated", "private") and value is not None:
            # model registry converts them to lower case. So before validation we need to do the same
            value = str(value).lower()
        result[key] = value
    return result
34+
35+
36+
def get_huggingface_nested_attributes(obj: Any, attr_path: str) -> Any:
    """
    Get nested attribute using dot notation like 'config.architectures'

    Each path segment is resolved with key access when the current value is a dict and
    with attribute access otherwise.

    Args:
        obj: Object or dict to start the lookup from.
        attr_path: Dot-separated path of keys / attribute names.

    Returns:
        The resolved value, or None when any segment is missing or when an error is
        raised during the lookup (the error is logged, not propagated).
    """
    # Unique marker so that an attribute whose real value is None is not mistaken for a missing one.
    missing = object()
    try:
        current_obj = obj
        for attr in attr_path.split("."):
            if isinstance(current_obj, dict):
                # For dictionaries, use key access
                if attr not in current_obj:
                    return None
                current_obj = current_obj[attr]
            else:
                # For objects, use attribute access (single lookup instead of hasattr + getattr)
                current_obj = getattr(current_obj, attr, missing)
                if current_obj is missing:
                    return None
        return current_obj
    except AttributeError as e:
        # Reached e.g. when attr_path is not a string and therefore has no split()
        LOGGER.error(f"AttributeError getting '{attr_path}': {e}")
        return None
    except Exception as e:
        # Deliberate best effort: any other lookup failure is logged and reported as "no value"
        LOGGER.error(f"Unexpected error getting '{attr_path}': {e}")
        return None
61+
62+
63+
def assert_huggingface_values_matches_model_catalog_api_values(
    model_catalog_rest_url: list[str],
    model_registry_rest_headers: dict[str, str],
    expected_catalog_values: dict[str, str],
    huggingface_api: HfApi,
) -> None:
    """
    Compare HuggingFace Hub values with the values returned by the model catalog API

    For every model name in expected_catalog_values the model is fetched from the model
    catalog and from the HuggingFace Hub, and the fields gated, private, model_type,
    architectures and tags are compared. All models are checked before failing so that
    every mismatch is reported at once.

    Args:
        model_catalog_rest_url: Model catalog base URLs; only the first entry is used.
        model_registry_rest_headers: Headers sent with the model catalog request.
        expected_catalog_values: Iterated for the model names to validate.
        huggingface_api: Client used to query the HuggingFace Hub.

    Raises:
        AssertionError: If the catalog returns a different model name than requested,
            or if any compared field differs for any model.
    """
    mismatch: dict[str, str] = {}
    LOGGER.info("Validating HuggingFace model metadata:")
    for model_name in expected_catalog_values:
        url = f"{model_catalog_rest_url[0]}sources/{HF_SOURCE_ID}/models/{model_name}"
        result = execute_get_command(
            url=url,
            headers=model_registry_rest_headers,
        )
        assert result["name"] == model_name, (
            f"Model catalog returned model {result['name']} when {model_name} was requested"
        )
        hf_api_values = get_huggingface_model_params(model_name=model_name, huggingface_api=huggingface_api)
        custom_properties = result["customProperties"]
        errors: list[str] = []

        # Scalar fields are compared through their string representation.
        for field_name in ("gated", "private", "model_type"):
            model_catalog_value = custom_properties[f"hf_{field_name}"]["string_value"]
            if model_catalog_value != str(hf_api_values[field_name]):
                errors.append(
                    f"HuggingFace api value for {field_name} is {hf_api_values[field_name]} and "
                    f"value found from model catalog api call is {model_catalog_value}"
                )

        # List fields are parsed from the catalog's string value and compared order-insensitively.
        for field_name in ("architectures", "tags"):
            model_catalog_value = sorted(ast.literal_eval(custom_properties[f"hf_{field_name}"]["string_value"]))
            # The HuggingFace side is None when the model exposes no such field; compare it as an
            # empty list instead of crashing in sorted().
            # NOTE(review): assumes the catalog reports an empty list in that case - confirm.
            hf_api_value = sorted(hf_api_values[field_name] or [])
            if model_catalog_value != hf_api_value:
                errors.append(
                    f"HuggingFace api value for {field_name} is {hf_api_value} and "
                    f"value found from model catalog api call is {model_catalog_value}"
                )

        if errors:
            mismatch[model_name] = "; ".join(errors)

    if mismatch:
        LOGGER.error(f"mismatches are: {mismatch}")
        raise AssertionError(f"HF api call and model catalog hf models has value mismatch: {mismatch}")

0 commit comments

Comments
 (0)