|
1 | | -from typing import Any, Tuple, List |
| 1 | +from typing import Any, Tuple, List, Dict |
2 | 2 | import yaml |
3 | 3 |
|
4 | 4 | from kubernetes.dynamic import DynamicClient |
|
22 | 22 | CATALOG_CONTAINER, |
23 | 23 | PERFORMANCE_DATA_DIR, |
24 | 24 | ) |
| 25 | +from tests.model_registry.constants import DEFAULT_CUSTOM_MODEL_CATALOG, DEFAULT_MODEL_CATALOG_CM |
25 | 26 | from tests.model_registry.utils import execute_get_command |
| 27 | +from tests.model_registry.utils import get_rest_headers |
26 | 28 |
|
27 | 29 | LOGGER = get_logger(name=__name__) |
28 | 30 |
|
@@ -1121,3 +1123,75 @@ def validate_model_artifacts_match_criteria_or( |
1121 | 1123 |
|
1122 | 1124 | LOGGER.error(f"Model {model_name} failed all OR validations") |
1123 | 1125 | return False |
| 1126 | + |
| 1127 | + |
| 1128 | +def get_labels_from_configmaps(admin_client: DynamicClient, namespace: str) -> List[Dict[str, Any]]: |
| 1129 | + """ |
| 1130 | + Get all labels from both model catalog ConfigMaps. |
| 1131 | +
|
| 1132 | + Args: |
| 1133 | + admin_client: Kubernetes client |
| 1134 | + namespace: Namespace containing the ConfigMaps |
| 1135 | +
|
| 1136 | + Returns: |
| 1137 | + List of all label dictionaries from both ConfigMaps |
| 1138 | + """ |
| 1139 | + labels = [] |
| 1140 | + |
| 1141 | + # Get labels from default ConfigMap |
| 1142 | + default_cm = ConfigMap(name=DEFAULT_MODEL_CATALOG_CM, client=admin_client, namespace=namespace) |
| 1143 | + default_data = yaml.safe_load(default_cm.instance.data["sources.yaml"]) |
| 1144 | + if "labels" in default_data: |
| 1145 | + labels.extend(default_data["labels"]) |
| 1146 | + |
| 1147 | + # Get labels from sources ConfigMap |
| 1148 | + sources_cm = ConfigMap(name=DEFAULT_CUSTOM_MODEL_CATALOG, client=admin_client, namespace=namespace) |
| 1149 | + sources_data = yaml.safe_load(sources_cm.instance.data["sources.yaml"]) |
| 1150 | + if "labels" in sources_data: |
| 1151 | + labels.extend(sources_data["labels"]) |
| 1152 | + |
| 1153 | + return labels |
| 1154 | + |
| 1155 | + |
| 1156 | +def get_labels_from_api(model_catalog_rest_url: str, user_token: str) -> List[Dict[str, Any]]: |
| 1157 | + """ |
| 1158 | + Get labels from the API endpoint. |
| 1159 | +
|
| 1160 | + Args: |
| 1161 | + model_catalog_rest_url: Base URL for model catalog API |
| 1162 | + user_token: Authentication token |
| 1163 | +
|
| 1164 | + Returns: |
| 1165 | + List of label dictionaries from API response |
| 1166 | + """ |
| 1167 | + url = f"{model_catalog_rest_url}labels" |
| 1168 | + headers = get_rest_headers(token=user_token) |
| 1169 | + response = execute_get_command(url=url, headers=headers) |
| 1170 | + return response["items"] |
| 1171 | + |
| 1172 | + |
| 1173 | +def verify_labels_match(expected_labels: List[Dict[str, Any]], api_labels: List[Dict[str, Any]]) -> None: |
| 1174 | + """ |
| 1175 | + Verify that all expected labels are present in the API response. |
| 1176 | +
|
| 1177 | + Args: |
| 1178 | + expected_labels: Labels expected from ConfigMaps |
| 1179 | + api_labels: Labels returned by API |
| 1180 | +
|
| 1181 | + Raises: |
| 1182 | + AssertionError: If any expected label is not found in API response |
| 1183 | + """ |
| 1184 | + LOGGER.info(f"Verifying {len(expected_labels)} expected labels against {len(api_labels)} API labels") |
| 1185 | + |
| 1186 | + for expected_label in expected_labels: |
| 1187 | + found = False |
| 1188 | + for api_label in api_labels: |
| 1189 | + if ( |
| 1190 | + expected_label.get("name") == api_label.get("name") |
| 1191 | + and expected_label.get("displayName") == api_label.get("displayName") |
| 1192 | + and expected_label.get("description") == api_label.get("description") |
| 1193 | + ): |
| 1194 | + found = True |
| 1195 | + break |
| 1196 | + |
| 1197 | + assert found, f"Expected label not found in API response: {expected_label}" |
0 commit comments