Skip to content

Commit da3aca5

Browse files
committed
feat: add test to validate api response against DB query
Signed-off-by: lugi0 <lgiorgi@redhat.com>
1 parent efee474 commit da3aca5

File tree

3 files changed

+255
-54
lines changed

3 files changed

+255
-54
lines changed

tests/model_registry/model_catalog/constants.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,42 @@
3636

3737
# URL-encoded source-label filters used by model-catalog API tests.
REDHAT_AI_FILTER: str = "Red+Hat+AI"
REDHAT_AI_VALIDATED_FILTER: str = "Red+Hat+AI+Validated"


# SQL query for filter_options endpoint database validation.
# Replicates the exact database query used by GetFilterableProperties for the filter_options endpoint
# in kubeflow/model-registry catalog/internal/db/service/catalog_model.go
# Note: Uses dynamic type_id lookup via 'kf.CatalogModel' name since type_id appears to be dynamic.
#
# Structure (two UNIONed branches over "ContextProperty" rows belonging to catalog-model contexts):
#   1. scalar string properties (non-NULL, non-empty, not a JSON array)
#   2. JSON-array properties, unnested to one element per row via json_array_elements_text
# The outer query aggregates values per property name, dropping properties whose
# longest value exceeds 100 characters (mirrors the API's length cutoff).
FILTER_OPTIONS_DB_QUERY = """
SELECT name, array_agg(string_value) FROM (
    SELECT
        name,
        string_value
    FROM "ContextProperty" WHERE
    context_id IN (
        SELECT id FROM "Context" WHERE type_id = (
            SELECT id FROM "Type" WHERE name = 'kf.CatalogModel'
        )
    )
    AND string_value IS NOT NULL
    AND string_value != ''
    AND string_value IS NOT JSON ARRAY

    UNION

    SELECT
        name,
        json_array_elements_text(string_value::json) AS string_value
    FROM "ContextProperty" WHERE
    context_id IN (
        SELECT id FROM "Context" WHERE type_id = (
            SELECT id FROM "Type" WHERE name = 'kf.CatalogModel'
        )
    )
    AND string_value IS JSON ARRAY
)
GROUP BY name HAVING MAX(CHAR_LENGTH(string_value)) <= 100;
"""

# Fields that are explicitly filtered out by the filter_options endpoint API.
# From db_catalog.go:204-206 in kubeflow/model-registry GetFilterOptions method.
API_EXCLUDED_FILTER_FIELDS = {"source_id", "logo", "license_link"}

tests/model_registry/model_catalog/test_filter_options_endpoint.py

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22
from typing import Self
33
from simple_logger.logger import get_logger
44

5-
from tests.model_registry.model_catalog.utils import execute_get_command, validate_filter_options_structure
5+
from tests.model_registry.model_catalog.utils import (
6+
execute_get_command,
7+
validate_filter_options_structure,
8+
parse_psql_array_agg_output,
9+
get_postgres_pod_in_namespace,
10+
compare_filter_options_with_database,
11+
)
12+
from tests.model_registry.model_catalog.constants import FILTER_OPTIONS_DB_QUERY, API_EXCLUDED_FILTER_FIELDS
613
from tests.model_registry.utils import get_rest_headers
714
from utilities.user_utils import UserTestSession
815

@@ -13,30 +20,30 @@
1320
]
1421

1522

16-
@pytest.mark.parametrize(
17-
"user_token_for_api_calls,",
18-
[
19-
pytest.param(
20-
{},
21-
id="test_filter_options_admin_user",
22-
),
23-
pytest.param(
24-
{"user_type": "test"},
25-
id="test_filter_options_non_admin_user",
26-
),
27-
pytest.param(
28-
{"user_type": "sa_user"},
29-
id="test_filter_options_service_account",
30-
),
31-
],
32-
indirect=["user_token_for_api_calls"],
33-
)
3423
class TestFilterOptionsEndpoint:
3524
"""
3625
Test class for validating the models/filter_options endpoint
3726
RHOAIENG-36696
3827
"""
3928

29+
@pytest.mark.parametrize(
30+
"user_token_for_api_calls,",
31+
[
32+
pytest.param(
33+
{},
34+
id="test_filter_options_admin_user",
35+
),
36+
pytest.param(
37+
{"user_type": "test"},
38+
id="test_filter_options_non_admin_user",
39+
),
40+
pytest.param(
41+
{"user_type": "sa_user"},
42+
id="test_filter_options_service_account",
43+
),
44+
],
45+
indirect=["user_token_for_api_calls"],
46+
)
4047
def test_filter_options_endpoint_validation(
4148
self: Self,
4249
model_catalog_rest_url: list[str],
@@ -74,48 +81,66 @@ def test_filter_options_endpoint_validation(
7481
LOGGER.info(f"Found {len(filters)} filter properties: {list(filters.keys())}")
7582
LOGGER.info("All filter options validation passed successfully")
7683

77-
@pytest.mark.skip(reason="TODO: Implement after investigating backend DB queries")
84+
# Cannot use non-admin user for this test as it cannot list the pods in the namespace
85+
@pytest.mark.parametrize(
86+
"user_token_for_api_calls,",
87+
[
88+
pytest.param(
89+
{},
90+
id="test_filter_options_admin_user",
91+
),
92+
pytest.param(
93+
{"user_type": "sa_user"},
94+
id="test_filter_options_service_account",
95+
),
96+
],
97+
indirect=["user_token_for_api_calls"],
98+
)
7899
def test_comprehensive_coverage_against_database(
79100
self: Self,
80101
model_catalog_rest_url: list[str],
81102
user_token_for_api_calls: str,
82-
test_idp_user: UserTestSession,
103+
model_registry_namespace: str,
83104
):
84105
"""
85-
STUBBED: Validate filter options are comprehensive across all sources/models in DB.
106+
Validate filter options are comprehensive across all sources/models in DB.
86107
Acceptance Criteria: The returned options are comprehensive and not limited to a
87108
subset of models or a single source.
88109
89-
TODO IMPLEMENTATION PLAN:
90-
1. Investigate backend endpoint logic:
91-
- Find the source code for /models/filter_options endpoint in kubeflow/model-registry
92-
- Understand what DB tables it queries (likely model/artifact tables)
93-
- Identify the exact SQL queries used to build filter values
94-
- Determine database schema and column names
95-
96-
2. Replicate queries via pod shell:
97-
- Use get_model_catalog_pod() to access catalog pod
98-
- Execute psql commands via pod.execute()
99-
- Query same tables/columns the endpoint uses
100-
- Extract all distinct values for string properties: SELECT DISTINCT license FROM models;
101-
- Extract min/max ranges for numeric properties: SELECT MIN(metric), MAX(metric) FROM models;
102-
103-
3. Compare results:
104-
- API response filter values should match DB query results exactly
105-
- Ensure no values are missing (comprehensive coverage)
106-
- Validate across all sources, not just one
107-
108-
4. DB Access Pattern Example:
109-
catalog_pod = get_model_catalog_pod(client, namespace)[0]
110-
result = catalog_pod.execute(
111-
command=["psql", "-U", "catalog_user", "-d", "catalog_db", "-c", "SELECT DISTINCT license FROM models;"],
112-
container="catalog"
113-
)
114-
115-
5. Implementation considerations:
116-
- Handle different data types (strings vs arrays like tasks)
117-
- Parse psql output correctly
118-
- Handle null/empty values
119-
- Ensure database connection credentials are available
110+
This test executes the exact same SQL query the API uses and compares results
111+
to catch any discrepancies between database content and API response.
112+
113+
Expected failure because of RHOAIENG-37069
120114
"""
121-
pytest.skip("TODO: Implement comprehensive coverage validation after backend investigation")
115+
api_url = f"{model_catalog_rest_url[0]}models/filter_options"
116+
LOGGER.info(f"Testing comprehensive database coverage for: {api_url}")
117+
118+
api_response = execute_get_command(
119+
url=api_url,
120+
headers=get_rest_headers(token=user_token_for_api_calls),
121+
)
122+
123+
api_filters = api_response["filters"]
124+
LOGGER.info(f"API returned {len(api_filters)} filter properties: {list(api_filters.keys())}")
125+
126+
postgres_pod = get_postgres_pod_in_namespace(namespace=model_registry_namespace)
127+
LOGGER.info(f"Using PostgreSQL pod: {postgres_pod.name}")
128+
129+
db_result = postgres_pod.execute(
130+
command=["psql", "-U", "catalog_user", "-d", "model_catalog", "-c", FILTER_OPTIONS_DB_QUERY],
131+
container="postgresql",
132+
)
133+
134+
db_properties = parse_psql_array_agg_output(psql_output=db_result)
135+
LOGGER.info(f"Raw database query returned {len(db_properties)} properties: {list(db_properties.keys())}")
136+
137+
is_valid, comparison_errors = compare_filter_options_with_database(
138+
api_filters=api_filters, db_properties=db_properties, excluded_fields=API_EXCLUDED_FILTER_FIELDS
139+
)
140+
141+
if not is_valid:
142+
failure_msg = "Filter options API response does not match database content"
143+
failure_msg += "\nDetailed comparison errors:\n" + "\n".join(comparison_errors)
144+
assert False, failure_msg
145+
146+
LOGGER.info("Comprehensive database coverage validation passed - API matches database exactly")

tests/model_registry/model_catalog/utils.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,140 @@ def validate_model_catalog_configmap_data(configmap: ConfigMap, num_catalogs: in
351351
assert len(catalogs) == num_catalogs, f"{configmap.name} should have {num_catalogs} catalog"
352352
if num_catalogs:
353353
validate_default_catalog(catalogs=catalogs)
354+
355+
356+
def parse_psql_array_agg_output(psql_output: str) -> dict[str, list[str]]:
    """
    Parse psql output from an array_agg query into a Python dict.

    Expected format:
        name     | array_agg
        ---------+----------
        license  | {apache-2.0,mit,bsd}
        provider | {Meta,Microsoft}

    PostgreSQL double-quotes array elements that contain commas or other
    special characters (e.g. {"Meta, Inc",Microsoft}); this parser honors
    those quotes instead of splitting blindly on every comma.

    Args:
        psql_output: Raw stdout of a `psql -c` invocation of the query.

    Returns:
        dict mapping property names to lists of values.
    """
    result: dict[str, list[str]] = {}
    data_started = False

    for raw_line in psql_output.strip().split("\n"):
        line = raw_line.strip()
        # Skip blanks, the header separator ("----+----"), and the trailing
        # row-count footer ("(N rows)"), none of which carry data.
        if not line or line.startswith("-") or "|" not in line:
            continue

        # The header row ("name | array_agg") marks where data rows begin.
        if "array_agg" in line and not data_started:
            data_started = True
            continue
        if not data_started:
            continue

        # Parse data row: "property_name | {val1,val2,val3}"
        parts = line.split("|", 1)
        if len(parts) != 2:
            continue

        property_name = parts[0].strip()
        array_str = parts[1].strip()

        # Parse PostgreSQL array format: {val1,val2,val3}
        if array_str.startswith("{") and array_str.endswith("}"):
            result[property_name] = _split_pg_array_elements(values_str=array_str[1:-1])

    return result


def _split_pg_array_elements(values_str: str) -> list[str]:
    """Split the interior of a PostgreSQL array literal on top-level commas.

    Commas inside double-quoted elements are NOT separators, and a backslash
    escapes the following character. A naive split(",") would corrupt values
    such as {"Meta, Inc",Microsoft}.
    """
    if not values_str:
        return []

    elements: list[str] = []
    current: list[str] = []
    in_quotes = False
    i = 0
    while i < len(values_str):
        ch = values_str[i]
        if ch == "\\" and i + 1 < len(values_str):
            # Escaped character: take the next character literally.
            current.append(values_str[i + 1])
            i += 2
            continue
        if ch == '"':
            in_quotes = not in_quotes
            current.append(ch)
        elif ch == "," and not in_quotes:
            elements.append("".join(current).strip().strip('"'))
            current = []
        else:
            current.append(ch)
        i += 1
    elements.append("".join(current).strip().strip('"'))
    return elements
407+
408+
409+
def get_postgres_pod_in_namespace(namespace: str = "rhoai-model-registries") -> Pod:
    """Return the PostgreSQL pod backing the model catalog database.

    Looks the pod up by its expected label selector first; if no pod matches,
    falls back to scanning every pod in the namespace for 'postgres' in its name.
    """
    candidates = list(Pod.get(namespace=namespace, label_selector="app=model-catalog-postgres"))

    if not candidates:
        # Fallback: the label may be absent; match on the pod name instead.
        candidates = [candidate for candidate in Pod.get(namespace=namespace) if "postgres" in candidate.name]

    assert candidates, f"No PostgreSQL pod found in namespace {namespace}"
    return candidates[0]
420+
421+
422+
def compare_filter_options_with_database(
    api_filters: dict[str, Any], db_properties: dict[str, list[str]], excluded_fields: set[str]
) -> Tuple[bool, List[str]]:
    """
    Compare API filter options response with database query results.

    Args:
        api_filters: The "filters" dict from API response
        db_properties: Raw database properties before API filtering
        excluded_fields: Fields that API excludes from response

    Returns:
        Tuple of (is_valid, list_of_error_messages)
    """
    errors: List[str] = []

    # Mirror the API's exclusion logic so both sides describe the same property set.
    expected_properties = {key: vals for key, vals in db_properties.items() if key not in excluded_fields}

    LOGGER.info(f"Database returned {len(db_properties)} total properties")
    LOGGER.info(
        f"After applying API filtering, expecting {len(expected_properties)} properties: {list(expected_properties.keys())}"  # noqa: E501
    )

    expected_names = set(expected_properties.keys())
    api_names = set(api_filters.keys())
    missing_in_api = expected_names - api_names
    extra_in_api = api_names - expected_names

    # Walk every property seen on either side and record value-level mismatches.
    for prop_name in sorted(expected_names | api_names):
        present_in_db = prop_name in expected_properties
        present_in_api = prop_name in api_filters

        if present_in_db and present_in_api:
            db_values = set(expected_properties[prop_name])
            api_values = set(api_filters[prop_name]["values"])
            only_in_db = db_values - api_values
            only_in_api = api_values - db_values

            if only_in_db:
                message = (
                    f"Property '{prop_name}': DB has {len(only_in_db)} values missing from API: {only_in_db}"
                )
                LOGGER.error(message)
                errors.append(message)
            if only_in_api:
                message = (
                    f"Property '{prop_name}': API has {len(only_in_api)} values missing from DB: {only_in_api}"
                )
                LOGGER.error(message)
                errors.append(message)
            if not only_in_db and not only_in_api:
                LOGGER.info(f"Property '{prop_name}': Perfect match ({len(api_values)} values)")
        elif present_in_db:
            message = f"Property '{prop_name}': In DB ({len(expected_properties[prop_name])} values) but NOT in API"
            LOGGER.error(message)
            errors.append(message)
        else:
            message = f"Property '{prop_name}': In API ({len(api_filters[prop_name]['values'])} values) but NOT in DB"
            LOGGER.error(message)
            errors.append(message)

    # Property-level mismatches are summarized after the per-property details.
    if missing_in_api:
        errors.append(f"API missing properties found in database: {missing_in_api}")

    if extra_in_api:
        errors.append(f"API has extra properties not in database: {extra_in_api}")

    return not errors, errors

0 commit comments

Comments
 (0)