Skip to content

Commit b5726aa

Browse files
committed
fix: update calls
1 parent 2ca5c47 commit b5726aa

File tree

9 files changed

+172
-19
lines changed

9 files changed

+172
-19
lines changed

tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def test_include_models_by_pattern(
6666

6767
# Ensure baseline model state is restored for subsequent tests
6868
ensure_baseline_model_state(
69+
admin_client=admin_client,
6970
model_catalog_rest_url=model_catalog_rest_url,
7071
model_registry_rest_headers=model_registry_rest_headers,
7172
model_registry_namespace=model_registry_namespace,
@@ -110,6 +111,7 @@ def test_exclude_models_by_pattern(
110111

111112
# Ensure baseline model state is restored for subsequent tests
112113
ensure_baseline_model_state(
114+
admin_client=admin_client,
113115
model_catalog_rest_url=model_catalog_rest_url,
114116
model_registry_rest_headers=model_registry_rest_headers,
115117
model_registry_namespace=model_registry_namespace,
@@ -190,7 +192,7 @@ def test_combined_include_exclude_filtering(
190192
)
191193

192194
db_models = get_models_from_database_by_source(
193-
source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
195+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
194196
)
195197

196198
is_valid, error_msg = validate_model_filtering_consistency(api_models=api_models, db_models=db_models)
@@ -206,6 +208,7 @@ def test_combined_include_exclude_filtering(
206208

207209
# Ensure baseline model state is restored for subsequent tests
208210
ensure_baseline_model_state(
211+
admin_client=admin_client,
209212
model_catalog_rest_url=model_catalog_rest_url,
210213
model_registry_rest_headers=model_registry_rest_headers,
211214
model_registry_namespace=model_registry_namespace,
@@ -255,7 +258,7 @@ def test_model_cleanup_on_exclusion_change(
255258
pytest.fail(f"Phase 1: Timeout waiting for granite models {granite_models}: {e}")
256259

257260
phase1_db_models = get_models_from_database_by_source(
258-
source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
261+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
259262
)
260263

261264
assert phase1_api_models == granite_models, (
@@ -292,7 +295,7 @@ def test_model_cleanup_on_exclusion_change(
292295
pytest.fail(f"Phase 2: Timeout waiting for prometheus models {prometheus_models}: {e}")
293296

294297
phase2_db_models = get_models_from_database_by_source(
295-
source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
298+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
296299
)
297300

298301
# Should only have prometheus models now
@@ -307,6 +310,7 @@ def test_model_cleanup_on_exclusion_change(
307310

308311
# Ensure baseline model state is restored for subsequent tests
309312
ensure_baseline_model_state(
313+
admin_client=admin_client,
310314
model_catalog_rest_url=model_catalog_rest_url,
311315
model_registry_rest_headers=model_registry_rest_headers,
312316
model_registry_namespace=model_registry_namespace,
@@ -351,14 +355,15 @@ def test_source_disabling_removes_models(
351355

352356
# Verify database is also cleaned
353357
db_models = get_models_from_database_by_source(
354-
source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
358+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
355359
)
356360
assert len(db_models) == 0, f"Database should be clean when source disabled, found: {db_models}"
357361

358362
LOGGER.info("SUCCESS: Source disabling removed all models")
359363

360364
# Ensure baseline model state is restored for subsequent tests
361365
ensure_baseline_model_state(
366+
admin_client=admin_client,
362367
model_catalog_rest_url=model_catalog_rest_url,
363368
model_registry_rest_headers=model_registry_rest_headers,
364369
model_registry_namespace=model_registry_namespace,
@@ -416,6 +421,7 @@ def test_model_removal_logging(
416421

417422
# Ensure baseline model state is restored for subsequent tests
418423
ensure_baseline_model_state(
424+
admin_client=admin_client,
419425
model_catalog_rest_url=model_catalog_rest_url,
420426
model_registry_rest_headers=model_registry_rest_headers,
421427
model_registry_namespace=model_registry_namespace,
@@ -467,6 +473,7 @@ def test_source_disabling_logging(
467473

468474
# Ensure baseline model state is restored for subsequent tests
469475
ensure_baseline_model_state(
476+
admin_client=admin_client,
470477
model_catalog_rest_url=model_catalog_rest_url,
471478
model_registry_rest_headers=model_registry_rest_headers,
472479
model_registry_namespace=model_registry_namespace,

tests/model_registry/model_catalog/catalog_config/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,12 @@ def validate_model_catalog_configmap_data(configmap: ConfigMap, num_catalogs: in
132132
validate_default_catalog(catalogs=catalogs)
133133

134134

135-
def get_models_from_database_by_source(source_id: str, namespace: str) -> set[str]:
135+
def get_models_from_database_by_source(admin_client: DynamicClient, source_id: str, namespace: str) -> set[str]:
136136
"""
137137
Query database directly to get all model names for a specific source.
138138
139139
Args:
140+
admin_client: DynamicClient for Kubernetes API access
140141
source_id: Catalog source ID to filter by
141142
namespace: OpenShift namespace for database access
142143
@@ -145,7 +146,7 @@ def get_models_from_database_by_source(source_id: str, namespace: str) -> set[st
145146
"""
146147

147148
query = GET_MODELS_BY_SOURCE_ID_DB_QUERY.format(source_id=source_id)
148-
result = execute_database_query(query=query, namespace=namespace)
149+
result = execute_database_query(admin_client=admin_client, query=query, namespace=namespace)
149150
parsed = parse_psql_output(psql_output=result)
150151
return set(parsed.get("values", []))
151152

@@ -478,7 +479,7 @@ def execute_inclusion_exclusion_filter_test(
478479
pytest.fail(f"Timeout waiting for {pattern} models to appear. Expected: {expected_models}, {e}")
479480

480481
db_models = get_models_from_database_by_source(
481-
source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
482+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
482483
)
483484

484485
# Validate consistency
@@ -493,6 +494,7 @@ def execute_inclusion_exclusion_filter_test(
493494

494495
@retry(wait_timeout=300, sleep=10, exceptions_dict={Exception: []}, print_log=False)
495496
def _validate_baseline_models(
497+
admin_client: DynamicClient,
496498
model_catalog_rest_url: list[str],
497499
model_registry_rest_headers: dict[str, str],
498500
model_registry_namespace: str,
@@ -513,7 +515,9 @@ def _validate_baseline_models(
513515
api_models = {model["name"] for model in api_response.get("items", [])}
514516

515517
# Fetch current models from database
516-
db_models = get_models_from_database_by_source(source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace)
518+
db_models = get_models_from_database_by_source(
519+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
520+
)
517521

518522
count = len(api_models)
519523

@@ -542,6 +546,7 @@ def _validate_baseline_models(
542546

543547

544548
def ensure_baseline_model_state(
549+
admin_client: DynamicClient,
545550
model_catalog_rest_url: list[str],
546551
model_registry_rest_headers: dict[str, str],
547552
model_registry_namespace: str,
@@ -574,6 +579,7 @@ def ensure_baseline_model_state(
574579
# Use retry decorator for automatic polling and eventual reconciliation
575580
try:
576581
_validate_baseline_models(
582+
admin_client=admin_client,
577583
model_catalog_rest_url=model_catalog_rest_url,
578584
model_registry_rest_headers=model_registry_rest_headers,
579585
model_registry_namespace=model_registry_namespace,

tests/model_registry/model_catalog/conftest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,10 @@ def model_catalog_rest_url(model_registry_namespace: str, model_catalog_routes:
447447

448448
@pytest.fixture(scope="function")
449449
def baseline_redhat_ai_models(
450-
model_catalog_rest_url: list[str], model_registry_rest_headers: dict[str, str], model_registry_namespace: str
450+
admin_client: DynamicClient,
451+
model_catalog_rest_url: list[str],
452+
model_registry_rest_headers: dict[str, str],
453+
model_registry_namespace: str,
451454
) -> dict[str, set[str] | int]:
452455
"""
453456
fixture providing baseline model data for redhat_ai_models source.
@@ -463,6 +466,8 @@ def baseline_redhat_ai_models(
463466
)
464467
api_models = {model["name"] for model in api_response.get("items", [])}
465468

466-
db_models = get_models_from_database_by_source(source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace)
469+
db_models = get_models_from_database_by_source(
470+
admin_client=admin_client, source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace
471+
)
467472

468473
return {"api_models": api_models, "db_models": db_models, "count": len(api_models)}

tests/model_registry/model_catalog/db_check/test_model_catalog_db_validation.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from kubernetes.dynamic import DynamicClient
33
from ocp_resources.network_policy import NetworkPolicy
4+
from ocp_resources.pod import Pod
45
from simple_logger.logger import get_logger
56
from timeout_sampler import TimeoutSampler
67

@@ -77,14 +78,30 @@ def test_postgres_network_policy_allows_only_catalog_pods(self, model_catalog_po
7778
"Only model-catalog pods should be allowed to access postgres"
7879
)
7980

81+
def test_postgres_network_policy_has_correct_labels(self, model_catalog_postgres_network_policy):
82+
"""Test that NetworkPolicy has correct operator-managed labels"""
83+
labels = model_catalog_postgres_network_policy.instance.metadata.labels
84+
assert labels["app.kubernetes.io/created-by"] == "model-registry-operator", (
85+
"NetworkPolicy should be created by model-registry-operator"
86+
)
87+
assert labels["app.kubernetes.io/part-of"] == "model-catalog", "NetworkPolicy should be part of model-catalog"
88+
assert labels["app.kubernetes.io/managed-by"] == "model-registry-operator", (
89+
"NetworkPolicy should be managed by model-registry-operator"
90+
)
91+
8092
@pytest.mark.dependency(name="test_postgres_network_policy_recreation")
81-
def test_postgres_network_policy_recreated_after_deletion(
93+
def test_postgres_network_policy_recreated_on_reconciliation(
8294
self,
8395
admin_client: DynamicClient,
8496
model_catalog_postgres_network_policy,
8597
model_registry_namespace: str,
8698
):
87-
"""Test that operator recreates NetworkPolicy after deletion"""
99+
"""Test that operator recreates NetworkPolicy when reconciliation is triggered.
100+
101+
The NetworkPolicy is NOT watched directly by the operator, so deleting it alone
102+
won't trigger recreation. Deleting the postgres pod triggers a Deployment change,
103+
which triggers reconciliation and recreates the NetworkPolicy.
104+
"""
88105
model_catalog_postgres_network_policy.delete()
89106
get_postgres_pod_in_namespace(admin_client=admin_client, namespace=model_registry_namespace).delete()
90107
for np in TimeoutSampler(
@@ -98,3 +115,84 @@ def test_postgres_network_policy_recreated_after_deletion(
98115
if np.exists:
99116
LOGGER.info("NetworkPolicy has been recreated by operator")
100117
break
118+
119+
@pytest.mark.dependency(depends=["test_postgres_network_policy_recreation"])
120+
def test_postgres_network_policy_spec_preserved_after_recreation(
121+
self,
122+
admin_client: DynamicClient,
123+
model_registry_namespace: str,
124+
):
125+
"""Test that recreated NetworkPolicy has the same correct spec"""
126+
recreated_np = NetworkPolicy(
127+
client=admin_client,
128+
name="model-catalog-postgres",
129+
namespace=model_registry_namespace,
130+
)
131+
assert recreated_np.exists, "Recreated NetworkPolicy should exist"
132+
133+
spec = recreated_np.instance.spec
134+
135+
# Verify podSelector targets postgres pods
136+
assert spec.podSelector.matchLabels["app.kubernetes.io/name"] == "model-catalog-postgres", (
137+
"Recreated NetworkPolicy should still target postgres pods"
138+
)
139+
140+
# Verify ingress policy type
141+
assert "Ingress" in spec.policyTypes, "Recreated NetworkPolicy should have Ingress policy type"
142+
143+
# Verify port restriction
144+
assert len(spec.ingress) == 1, "Recreated NetworkPolicy should have exactly one ingress rule"
145+
port = spec.ingress[0].ports[0]
146+
assert port.port == 5432, "Recreated NetworkPolicy should allow only PostgreSQL port 5432"
147+
assert port.protocol == "TCP", "Recreated NetworkPolicy port should use TCP protocol"
148+
149+
# Verify from selector allows only catalog pods
150+
from_selector = spec.ingress[0]["from"][0].podSelector.matchLabels
151+
assert from_selector["component"] == "model-catalog", (
152+
"Recreated NetworkPolicy should still allow only model-catalog pods"
153+
)
154+
155+
# Verify labels
156+
labels = recreated_np.instance.metadata.labels
157+
assert labels["app.kubernetes.io/created-by"] == "model-registry-operator"
158+
assert labels["app.kubernetes.io/part-of"] == "model-catalog"
159+
160+
LOGGER.info("Recreated NetworkPolicy spec and labels match expected configuration")
161+
162+
def test_postgres_network_policy_recreated_after_operator_restart(
163+
self,
164+
admin_client: DynamicClient,
165+
model_registry_operator_pod: Pod,
166+
model_registry_namespace: str,
167+
):
168+
"""Test that operator restart recreates a deleted NetworkPolicy via initial reconciliation.
169+
170+
This simulates a production scenario where the operator pod is restarted
171+
(e.g., during upgrades) and must reconcile all managed resources including
172+
the NetworkPolicy.
173+
"""
174+
# Delete the NetworkPolicy first
175+
np = NetworkPolicy(
176+
client=admin_client,
177+
name="model-catalog-postgres",
178+
namespace=model_registry_namespace,
179+
)
180+
assert np.exists, "NetworkPolicy should exist before operator restart"
181+
np.delete()
182+
183+
# Restart the operator pod to trigger initial reconciliation
184+
LOGGER.info(f"Deleting operator pod {model_registry_operator_pod.name} to trigger reconciliation")
185+
model_registry_operator_pod.delete(wait=True)
186+
187+
# Wait for the NetworkPolicy to be recreated
188+
for recreated_np in TimeoutSampler(
189+
wait_timeout=180,
190+
sleep=10,
191+
func=NetworkPolicy,
192+
client=admin_client,
193+
name="model-catalog-postgres",
194+
namespace=model_registry_namespace,
195+
):
196+
if recreated_np.exists:
197+
LOGGER.info("NetworkPolicy has been recreated after operator restart")
198+
break

tests/model_registry/model_catalog/metadata/test_filter_options_endpoint.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from typing import Self
3+
from kubernetes.dynamic import DynamicClient
34
from simple_logger.logger import get_logger
45
from tests.model_registry.model_catalog.metadata.utils import (
56
validate_filter_options_structure,
@@ -103,6 +104,7 @@ def test_filter_options_endpoint_validation(
103104
)
104105
def test_comprehensive_coverage_against_database(
105106
self: Self,
107+
admin_client: DynamicClient,
106108
model_catalog_rest_url: list[str],
107109
user_token_for_api_calls: str,
108110
model_registry_namespace: str,
@@ -130,7 +132,9 @@ def test_comprehensive_coverage_against_database(
130132

131133
LOGGER.info(f"Executing database query in namespace: {model_registry_namespace}")
132134

133-
db_result = execute_database_query(query=FILTER_OPTIONS_DB_QUERY, namespace=model_registry_namespace)
135+
db_result = execute_database_query(
136+
admin_client=admin_client, query=FILTER_OPTIONS_DB_QUERY, namespace=model_registry_namespace
137+
)
134138
parsed_result = parse_psql_output(psql_output=db_result)
135139

136140
db_properties = parsed_result.get("properties", {})

0 commit comments

Comments
 (0)