Skip to content

Commit 98fd129

Browse files
lugi0mwaykole
authored andcommitted
feat: wait for default source state after each test executes (opendatahub-io#1057)
Signed-off-by: lugi0 <lgiorgi@redhat.com>
1 parent 883bd4b commit 98fd129

3 files changed

Lines changed: 144 additions & 96 deletions

File tree

tests/model_registry/model_catalog/catalog_config/test_default_source_inclusion_exclusion_cleanup.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616
validate_cleanup_logging,
1717
filter_models_by_pattern,
1818
execute_inclusion_exclusion_filter_test,
19+
ensure_baseline_model_state,
1920
)
2021
from tests.model_registry.utils import wait_for_model_catalog_api
2122

2223
LOGGER = get_logger(name=__name__)
2324

2425
pytestmark = [
25-
pytest.mark.usefixtures(
26-
"updated_dsc_component_state_scope_session", "model_registry_namespace", "baseline_model_state"
27-
),
26+
pytest.mark.usefixtures("updated_dsc_component_state_scope_session", "model_registry_namespace"),
2827
]
2928

3029

@@ -65,6 +64,13 @@ def test_include_models_by_pattern(
6564
model_registry_rest_headers=model_registry_rest_headers,
6665
)
6766

67+
# Ensure baseline model state is restored for subsequent tests
68+
ensure_baseline_model_state(
69+
model_catalog_rest_url=model_catalog_rest_url,
70+
model_registry_rest_headers=model_registry_rest_headers,
71+
model_registry_namespace=model_registry_namespace,
72+
)
73+
6874

6975
class TestModelExclusionFiltering:
7076
"""Test exclusion filtering functionality (RHOAIENG-41841 part 2)"""
@@ -102,6 +108,13 @@ def test_exclude_models_by_pattern(
102108
model_registry_rest_headers=model_registry_rest_headers,
103109
)
104110

111+
# Ensure baseline model state is restored for subsequent tests
112+
ensure_baseline_model_state(
113+
model_catalog_rest_url=model_catalog_rest_url,
114+
model_registry_rest_headers=model_registry_rest_headers,
115+
model_registry_namespace=model_registry_namespace,
116+
)
117+
105118

106119
class TestCombinedIncludeExcludeFiltering:
107120
"""Test combined include+exclude filtering (RHOAIENG-41841 part 3)"""
@@ -191,6 +204,13 @@ def test_combined_include_exclude_filtering(
191204
f"SUCCESS: {len(api_models)} {include_pattern} models after excluding {exclude_pattern} variants"
192205
)
193206

207+
# Ensure baseline model state is restored for subsequent tests
208+
ensure_baseline_model_state(
209+
model_catalog_rest_url=model_catalog_rest_url,
210+
model_registry_rest_headers=model_registry_rest_headers,
211+
model_registry_namespace=model_registry_namespace,
212+
)
213+
194214

195215
class TestModelCleanupLifecycle:
196216
"""Test automatic model cleanup during lifecycle changes (RHOAIENG-41846)"""
@@ -285,6 +305,13 @@ def test_model_cleanup_on_exclusion_change(
285305
f"Phase 2 SUCCESS: Granite models cleaned up, {len(phase2_api_models)} prometheus models remain"
286306
)
287307

308+
# Ensure baseline model state is restored for subsequent tests
309+
ensure_baseline_model_state(
310+
model_catalog_rest_url=model_catalog_rest_url,
311+
model_registry_rest_headers=model_registry_rest_headers,
312+
model_registry_namespace=model_registry_namespace,
313+
)
314+
288315

289316
class TestSourceLifecycleCleanup:
290317
"""Test source disabling cleanup scenarios (RHOAIENG-41846)"""
@@ -330,6 +357,13 @@ def test_source_disabling_removes_models(
330357

331358
LOGGER.info("SUCCESS: Source disabling removed all models")
332359

360+
# Ensure baseline model state is restored for subsequent tests
361+
ensure_baseline_model_state(
362+
model_catalog_rest_url=model_catalog_rest_url,
363+
model_registry_rest_headers=model_registry_rest_headers,
364+
model_registry_namespace=model_registry_namespace,
365+
)
366+
333367

334368
class TestLoggingValidation:
335369
"""Test cleanup operation logging (RHOAIENG-41846)"""
@@ -380,6 +414,13 @@ def test_model_removal_logging(
380414
except TimeoutExpiredError as e:
381415
pytest.fail(f"Expected log patterns not found: {e}")
382416

417+
# Ensure baseline model state is restored for subsequent tests
418+
ensure_baseline_model_state(
419+
model_catalog_rest_url=model_catalog_rest_url,
420+
model_registry_rest_headers=model_registry_rest_headers,
421+
model_registry_namespace=model_registry_namespace,
422+
)
423+
383424
@pytest.mark.sanity
384425
def test_source_disabling_logging(
385426
self,
@@ -423,3 +464,10 @@ def test_source_disabling_logging(
423464
LOGGER.info(f"SUCCESS: Found expected source disabling log patterns: {found_patterns}")
424465
except TimeoutExpiredError as e:
425466
pytest.fail(f"Expected source disabling log patterns not found: {e}")
467+
468+
# Ensure baseline model state is restored for subsequent tests
469+
ensure_baseline_model_state(
470+
model_catalog_rest_url=model_catalog_rest_url,
471+
model_registry_rest_headers=model_registry_rest_headers,
472+
model_registry_namespace=model_registry_namespace,
473+
)

tests/model_registry/model_catalog/catalog_config/utils.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,96 @@ def execute_inclusion_exclusion_filter_test(
489489
assert api_models == expected_models, f"Expected {test_description}: {expected_models}, got {api_models}"
490490

491491
LOGGER.info(f"SUCCESS: {len(api_models)} {pattern} models {filter_type}")
492+
493+
494+
@retry(wait_timeout=300, sleep=10, exceptions_dict={Exception: []}, print_log=False)
495+
def _validate_baseline_models(
496+
model_catalog_rest_url: list[str],
497+
model_registry_rest_headers: dict[str, str],
498+
model_registry_namespace: str,
499+
expected_models: set[str],
500+
expected_count: int,
501+
) -> None:
502+
"""
503+
Validate that baseline model expectations are met.
504+
Raises exception if validation fails (triggers retry).
505+
Returns None if successful (stops retry).
506+
"""
507+
# Fetch current models from API
508+
api_response = get_models_from_catalog_api(
509+
model_catalog_rest_url=model_catalog_rest_url,
510+
model_registry_rest_headers=model_registry_rest_headers,
511+
source_label="Red Hat AI",
512+
)
513+
api_models = {model["name"] for model in api_response.get("items", [])}
514+
515+
# Fetch current models from database
516+
db_models = get_models_from_database_by_source(source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace)
517+
518+
count = len(api_models)
519+
520+
# Validate all expectations - raise on any failure
521+
if count != expected_count:
522+
raise AssertionError(f"Expected {expected_count} models, got {count}")
523+
524+
if api_models != db_models:
525+
raise AssertionError(f"API models {api_models} don't match database models {db_models}")
526+
527+
if api_models != expected_models:
528+
raise AssertionError(f"Models {api_models} don't match expected set {expected_models}")
529+
530+
# Additional category validation
531+
granite_models = {model for model in api_models if "granite" in model}
532+
prometheus_models = {model for model in api_models if "prometheus" in model}
533+
534+
if len(granite_models) != 6 or len(prometheus_models) != 1:
535+
raise AssertionError(
536+
f"""Expected 6 granite + 1 prometheus models, \
537+
got {len(granite_models)} granite + {len(prometheus_models)} prometheus"""
538+
)
539+
540+
LOGGER.info("Baseline model validation successful: 7 models (6 granite, 1 prometheus)")
541+
return True
542+
543+
544+
def ensure_baseline_model_state(
545+
model_catalog_rest_url: list[str],
546+
model_registry_rest_headers: dict[str, str],
547+
model_registry_namespace: str,
548+
) -> None:
549+
"""
550+
Utility function to ensure that our baseline assumptions about the model data are correct.
551+
This should be called at the end of tests to ensure state consistency for subsequent tests.
552+
Uses @retry decorator for automatic polling (300s timeout, 10s interval) and eventual reconciliation.
553+
554+
Args:
555+
model_catalog_rest_url: URL for model catalog API
556+
model_registry_rest_headers: Headers for API requests
557+
model_registry_namespace: Namespace for model registry
558+
559+
Raises:
560+
pytest.FailError: If baseline state cannot be achieved after timeout
561+
"""
562+
# Expected baseline data
563+
expected_models = {
564+
"granite-3.1-8b-lab-v1",
565+
"granite-7b-redhat-lab",
566+
"granite-8b-code-base",
567+
"granite-8b-code-instruct",
568+
"granite-8b-lab-v1",
569+
"granite-8b-starter-v1",
570+
"prometheus-8x7b-v2-0",
571+
}
572+
expected_count = 7
573+
574+
# Use retry decorator for automatic polling and eventual reconciliation
575+
try:
576+
_validate_baseline_models(
577+
model_catalog_rest_url=model_catalog_rest_url,
578+
model_registry_rest_headers=model_registry_rest_headers,
579+
model_registry_namespace=model_registry_namespace,
580+
expected_models=expected_models,
581+
expected_count=expected_count,
582+
)
583+
except TimeoutExpiredError:
584+
pytest.fail("Failed to restore baseline model state after 300s timeout")

tests/model_registry/model_catalog/conftest.py

Lines changed: 0 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import requests
44

55
from simple_logger.logger import get_logger
6-
from timeout_sampler import retry, TimeoutExpiredError
76
import yaml
87
import pytest
98
from kubernetes.dynamic import DynamicClient
@@ -467,95 +466,3 @@ def baseline_redhat_ai_models(
467466
db_models = get_models_from_database_by_source(source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace)
468467

469468
return {"api_models": api_models, "db_models": db_models, "count": len(api_models)}
470-
471-
472-
@retry(wait_timeout=300, sleep=10, exceptions_dict={Exception: []}, print_log=False)
473-
def _validate_baseline_models(
474-
model_catalog_rest_url: list[str],
475-
model_registry_rest_headers: dict[str, str],
476-
model_registry_namespace: str,
477-
expected_models: set[str],
478-
expected_count: int,
479-
) -> None:
480-
"""
481-
Validate that baseline model expectations are met.
482-
Raises exception if validation fails (triggers retry).
483-
Returns None if successful (stops retry).
484-
"""
485-
# Fetch current models from API
486-
api_response = get_models_from_catalog_api(
487-
model_catalog_rest_url=model_catalog_rest_url,
488-
model_registry_rest_headers=model_registry_rest_headers,
489-
source_label="Red Hat AI",
490-
)
491-
api_models = {model["name"] for model in api_response.get("items", [])}
492-
493-
# Fetch current models from database
494-
db_models = get_models_from_database_by_source(source_id=REDHAT_AI_CATALOG_ID, namespace=model_registry_namespace)
495-
496-
count = len(api_models)
497-
498-
# Validate all expectations - raise on any failure
499-
if count != expected_count:
500-
raise AssertionError(f"Expected {expected_count} models, got {count}")
501-
502-
if api_models != db_models:
503-
raise AssertionError(f"API models {api_models} don't match database models {db_models}")
504-
505-
if api_models != expected_models:
506-
raise AssertionError(f"Models {api_models} don't match expected set {expected_models}")
507-
508-
# Additional category validation
509-
granite_models = {model for model in api_models if "granite" in model}
510-
prometheus_models = {model for model in api_models if "prometheus" in model}
511-
512-
if len(granite_models) != 6 or len(prometheus_models) != 1:
513-
raise AssertionError(
514-
f"""Expected 6 granite + 1 prometheus models, \
515-
got {len(granite_models)} granite + {len(prometheus_models)} prometheus"""
516-
)
517-
518-
LOGGER.info("Baseline model validation successful: 7 models (6 granite, 1 prometheus)")
519-
return True
520-
521-
522-
@pytest.fixture(scope="function")
523-
def baseline_model_state(
524-
model_catalog_rest_url: list[str],
525-
model_registry_rest_headers: dict[str, str],
526-
model_registry_namespace: str,
527-
) -> None:
528-
"""
529-
Validate that our baseline assumptions about the model data are correct.
530-
This fixture should be used by all test classes to ensure data consistency.
531-
Uses @retry decorator for automatic polling (300s timeout, 10s interval) and eventual reconciliation.
532-
533-
Args:
534-
model_catalog_rest_url: URL for model catalog API
535-
model_registry_rest_headers: Headers for API requests
536-
model_registry_namespace: Namespace for model registry
537-
"""
538-
539-
# Expected baseline data
540-
expected_models = {
541-
"granite-3.1-8b-lab-v1",
542-
"granite-7b-redhat-lab",
543-
"granite-8b-code-base",
544-
"granite-8b-code-instruct",
545-
"granite-8b-lab-v1",
546-
"granite-8b-starter-v1",
547-
"prometheus-8x7b-v2-0",
548-
}
549-
expected_count = 7
550-
551-
# Use retry decorator for automatic polling and eventual reconciliation
552-
try:
553-
_validate_baseline_models(
554-
model_catalog_rest_url=model_catalog_rest_url,
555-
model_registry_rest_headers=model_registry_rest_headers,
556-
model_registry_namespace=model_registry_namespace,
557-
expected_models=expected_models,
558-
expected_count=expected_count,
559-
)
560-
except TimeoutExpiredError:
561-
pytest.fail("Failed to fetch model data after 300s timeout")

0 commit comments

Comments
 (0)