Skip to content

Commit fdc219f

Browse files
committed
address comments
1 parent ee39766 commit fdc219f

3 files changed

Lines changed: 42 additions & 29 deletions

File tree

tests/model_registry/model_catalog/huggingface/conftest.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import time
33
from huggingface_hub import HfApi
44
from simple_logger.logger import get_logger
5-
from tests.model_registry.utils import execute_get_command
5+
6+
from tests.model_registry.model_catalog.huggingface.utils import get_huggingface_model_from_api
67

78
LOGGER = get_logger(name=__name__)
89

@@ -56,11 +57,11 @@ def initial_last_synced_values(
5657
"""
5758
Collect initial last_synced values for a given model.
5859
"""
59-
model_name = request.param
60-
url = f"{model_catalog_rest_url[0]}sources/hf_id/models/{model_name}"
61-
result = execute_get_command(
62-
url=url,
63-
headers=model_registry_rest_headers,
60+
result = get_huggingface_model_from_api(
61+
model_registry_rest_headers=model_registry_rest_headers,
62+
model_catalog_rest_url=model_catalog_rest_url,
63+
model_name=request.param,
64+
source_id="hf_id",
6465
)
6566

6667
return result["customProperties"]["last_synced"]["string_value"]

tests/model_registry/model_catalog/huggingface/test_huggingface_model_validation.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
from tests.model_registry.model_catalog.utils import (
77
get_hf_catalog_str,
88
)
9-
from tests.model_registry.utils import execute_get_command
109
from tests.model_registry.model_catalog.huggingface.utils import (
1110
assert_huggingface_values_matches_model_catalog_api_values,
1211
wait_for_huggingface_retrival_match,
1312
wait_for_hugging_face_model_import,
1413
wait_for_last_sync_update,
14+
get_huggingface_model_from_api,
1515
)
1616
from huggingface_hub import HfApi
1717
from kubernetes.dynamic import DynamicClient
@@ -60,6 +60,7 @@ def test_huggingface_last_synced_custom(
6060
model_registry_rest_headers=model_registry_rest_headers,
6161
model_catalog_rest_url=model_catalog_rest_url,
6262
model_name=model_name,
63+
source_id="hf_id",
6364
initial_last_synced_values=float(initial_last_synced_values),
6465
)
6566

@@ -85,7 +86,6 @@ class TestHuggingFaceModelValidation:
8586
def test_huggingface_model_metadata_last_synced(
8687
self: Self,
8788
epoch_time_before_config_map_update: float,
88-
updated_catalog_config_map: tuple[ConfigMap, str, str],
8989
model_catalog_rest_url: list[str],
9090
model_registry_rest_headers: dict[str, str],
9191
expected_catalog_values: dict[str, str],
@@ -100,12 +100,12 @@ def test_huggingface_model_metadata_last_synced(
100100
)
101101
error = {}
102102
for model_name in expected_catalog_values:
103-
url = f"{model_catalog_rest_url[0]}sources/{HF_SOURCE_ID}/models/{model_name}"
104-
result = execute_get_command(
105-
url=url,
106-
headers=model_registry_rest_headers,
103+
result = get_huggingface_model_from_api(
104+
model_catalog_rest_url=model_catalog_rest_url,
105+
model_registry_rest_headers=model_registry_rest_headers,
106+
model_name=model_name,
107+
source_id=HF_SOURCE_ID,
107108
)
108-
109109
error_msg = ""
110110
if result["name"] != model_name:
111111
error_msg += f"Expected model name {model_name}, but got {result['name']}. "
@@ -115,17 +115,13 @@ def test_huggingface_model_metadata_last_synced(
115115
LOGGER.info(f"Model {model_name} last synced at: {last_synced}")
116116

117117
# Validate that last_synced field exists and is not empty
118-
if last_synced is None:
119-
error_msg += f"last_synced field is None for model {model_name}. "
120-
elif last_synced == "":
121-
error_msg += f"last_synced field is empty for model {model_name}. "
122-
else:
123-
# Compare timestamps: current_epoch_time should be earlier than last_synced
124-
if epoch_time_before_config_map_update > float(last_synced):
125-
error_msg += (
126-
f"Model {model_name} last_synced ({last_synced}) should be after "
127-
f"test start time ({epoch_time_before_config_map_update}). "
128-
)
118+
if not last_synced or last_synced == "":
119+
error_msg += f"last_synced field is not present for model {model_name}. "
120+
elif epoch_time_before_config_map_update > float(last_synced):
121+
error_msg += (
122+
f"Model {model_name} last_synced ({last_synced}) should be after "
123+
f"test start time ({epoch_time_before_config_map_update}). "
124+
)
129125
if error_msg:
130126
error[model_name] = error_msg
131127
if error:

tests/model_registry/model_catalog/huggingface/utils.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import ast
22
from typing import Any
3+
34
from simple_logger.logger import get_logger
45

56
from tests.model_registry.model_catalog.constants import HF_SOURCE_ID
@@ -142,20 +143,35 @@ def wait_for_hugging_face_model_import(
142143
return False
143144

144145

146+
def get_huggingface_model_from_api(
147+
model_catalog_rest_url: list[str],
148+
model_registry_rest_headers: dict[str, str],
149+
model_name: str,
150+
source_id: str,
151+
) -> dict[str, Any]:
152+
url = f"{model_catalog_rest_url[0]}sources/{source_id}/models/{model_name}"
153+
return execute_get_command(
154+
url=url,
155+
headers=model_registry_rest_headers,
156+
)
157+
158+
145159
@retry(wait_timeout=135, sleep=15)
146160
def wait_for_last_sync_update(
147161
model_catalog_rest_url: list[str],
148162
model_registry_rest_headers: dict[str, str],
149163
model_name: str,
164+
source_id: str,
150165
initial_last_synced_values: float,
151166
) -> bool:
152167
"""Wait for the last_synced value to be updated with exact 120-second difference"""
153-
url = f"{model_catalog_rest_url[0]}sources/hf_id/models/{model_name}"
154-
result = execute_get_command(
155-
url=url,
156-
headers=model_registry_rest_headers,
157-
)
158168

169+
result = get_huggingface_model_from_api(
170+
model_registry_rest_headers=model_registry_rest_headers,
171+
model_catalog_rest_url=model_catalog_rest_url,
172+
model_name=model_name,
173+
source_id=source_id,
174+
)
159175
current_last_synced = float(result["customProperties"]["last_synced"]["string_value"])
160176
if current_last_synced != initial_last_synced_values:
161177
# Calculate difference in milliseconds and convert to seconds

0 commit comments

Comments
 (0)