Skip to content

Commit 125893f

Browse files
authored
Add test to validate filter_options api response against DB query (#748)
* feat: add test to validate api response against DB query Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: resolve merge conflict Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: add xfail marker to prevent red failure Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: add small comment Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: add additional jira id for xfail Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: move db constants to their own file Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: push untracked change Signed-off-by: lugi0 <lgiorgi@redhat.com> * fix: fail if db pod not found with label Signed-off-by: lugi0 <lgiorgi@redhat.com> --------- Signed-off-by: lugi0 <lgiorgi@redhat.com>
1 parent 780ee13 commit 125893f

File tree

4 files changed

+1059
-864
lines changed

4 files changed

+1059
-864
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Constants useful for querying the model catalog database and parsing its responses

# SQL query for filter_options endpoint database validation
# Replicates the exact database query used by GetFilterableProperties for the filter_options endpoint
# in kubeflow/model-registry catalog/internal/db/service/catalog_model.go
# Note: Uses dynamic type_id lookup via 'kf.CatalogModel' name since type_id appears to be dynamic
#
# Shape of the query:
#   * first branch  - scalar string properties (non-null, non-empty, not a JSON array)
#   * second branch - properties stored as JSON arrays, expanded to one row per array element
#   * UNION (not UNION ALL) de-duplicates identical (name, string_value) rows across both branches
#   * outer SELECT  - aggregates the values per property name; the HAVING clause drops every property
#                     that has at least one value longer than 100 characters
# NOTE(review): the `IS [NOT] JSON ARRAY` predicate looks like PostgreSQL 16+ syntax - confirm the
# catalog database version before reusing this query elsewhere.
FILTER_OPTIONS_DB_QUERY = """
SELECT name, array_agg(string_value) FROM (
SELECT
name,
string_value
FROM "ContextProperty" WHERE
context_id IN (
SELECT id FROM "Context" WHERE type_id = (
SELECT id FROM "Type" WHERE name = 'kf.CatalogModel'
)
)
AND string_value IS NOT NULL
AND string_value != ''
AND string_value IS NOT JSON ARRAY

UNION

SELECT
name,
json_array_elements_text(string_value::json) AS string_value
FROM "ContextProperty" WHERE
context_id IN (
SELECT id FROM "Context" WHERE type_id = (
SELECT id FROM "Type" WHERE name = 'kf.CatalogModel'
)
)
AND string_value IS JSON ARRAY
)
GROUP BY name HAVING MAX(CHAR_LENGTH(string_value)) <= 100;
"""

# Fields that are explicitly filtered out by the filter_options endpoint API
# From db_catalog.go:204-206 in kubeflow/model-registry GetFilterOptions method
# These names are removed from the database result before it is compared with the API response.
API_EXCLUDED_FILTER_FIELDS = {"source_id", "logo", "license_link"}

tests/model_registry/model_catalog/test_filter_options_endpoint.py

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,13 @@
22
from typing import Self
33
from simple_logger.logger import get_logger
44

5-
from tests.model_registry.model_catalog.utils import validate_filter_options_structure
5+
from tests.model_registry.model_catalog.utils import (
6+
validate_filter_options_structure,
7+
parse_psql_array_agg_output,
8+
get_postgres_pod_in_namespace,
9+
compare_filter_options_with_database,
10+
)
11+
from tests.model_registry.model_catalog.db_constants import FILTER_OPTIONS_DB_QUERY, API_EXCLUDED_FILTER_FIELDS
612
from tests.model_registry.utils import get_rest_headers, execute_get_command
713
from utilities.user_utils import UserTestSession
814

@@ -13,30 +19,30 @@
1319
]
1420

1521

16-
@pytest.mark.parametrize(
17-
"user_token_for_api_calls,",
18-
[
19-
pytest.param(
20-
{},
21-
id="test_filter_options_admin_user",
22-
),
23-
pytest.param(
24-
{"user_type": "test"},
25-
id="test_filter_options_non_admin_user",
26-
),
27-
pytest.param(
28-
{"user_type": "sa_user"},
29-
id="test_filter_options_service_account",
30-
),
31-
],
32-
indirect=["user_token_for_api_calls"],
33-
)
3422
class TestFilterOptionsEndpoint:
3523
"""
3624
Test class for validating the models/filter_options endpoint
3725
RHOAIENG-36696
3826
"""
3927

28+
@pytest.mark.parametrize(
29+
"user_token_for_api_calls,",
30+
[
31+
pytest.param(
32+
{},
33+
id="test_filter_options_admin_user",
34+
),
35+
pytest.param(
36+
{"user_type": "test"},
37+
id="test_filter_options_non_admin_user",
38+
),
39+
pytest.param(
40+
{"user_type": "sa_user"},
41+
id="test_filter_options_service_account",
42+
),
43+
],
44+
indirect=["user_token_for_api_calls"],
45+
)
4046
def test_filter_options_endpoint_validation(
4147
self: Self,
4248
model_catalog_rest_url: list[str],
@@ -74,48 +80,67 @@ def test_filter_options_endpoint_validation(
7480
LOGGER.info(f"Found {len(filters)} filter properties: {list(filters.keys())}")
7581
LOGGER.info("All filter options validation passed successfully")
7682

77-
@pytest.mark.skip(reason="TODO: Implement after investigating backend DB queries")
83+
# Cannot use non-admin user for this test as it cannot list the pods in the namespace
@pytest.mark.parametrize(
    "user_token_for_api_calls,",
    [
        pytest.param(
            {},
            id="test_filter_options_admin_user",
        ),
        pytest.param(
            {"user_type": "sa_user"},
            id="test_filter_options_service_account",
        ),
    ],
    indirect=["user_token_for_api_calls"],
)
@pytest.mark.xfail(strict=True, reason="RHOAIENG-37069: backend/API discrepancy expected")
def test_comprehensive_coverage_against_database(
    self: Self,
    model_catalog_rest_url: list[str],
    user_token_for_api_calls: str,
    model_registry_namespace: str,
):
    """
    Validate filter options are comprehensive across all sources/models in DB.
    Acceptance Criteria: The returned options are comprehensive and not limited to a
    subset of models or a single source.

    This test executes the exact same SQL query the API uses and compares results
    to catch any discrepancies between database content and API response.

    Steps:
        1. GET ``models/filter_options`` from the first model catalog REST URL.
        2. Run ``FILTER_OPTIONS_DB_QUERY`` with ``psql`` inside the PostgreSQL pod.
        3. Parse the psql output and compare it with the API ``filters`` payload,
           ignoring ``API_EXCLUDED_FILTER_FIELDS``.

    Expected failure because of RHOAIENG-37069 & RHOAIENG-37226
    """
    api_url = f"{model_catalog_rest_url[0]}models/filter_options"
    LOGGER.info(f"Testing comprehensive database coverage for: {api_url}")

    api_response = execute_get_command(
        url=api_url,
        headers=get_rest_headers(token=user_token_for_api_calls),
    )

    api_filters = api_response["filters"]
    LOGGER.info(f"API returned {len(api_filters)} filter properties: {list(api_filters.keys())}")

    postgres_pod = get_postgres_pod_in_namespace(namespace=model_registry_namespace)
    LOGGER.info(f"Using PostgreSQL pod: {postgres_pod.name}")

    db_result = postgres_pod.execute(
        command=["psql", "-U", "catalog_user", "-d", "model_catalog", "-c", FILTER_OPTIONS_DB_QUERY],
        container="postgresql",
    )

    db_properties = parse_psql_array_agg_output(psql_output=db_result)
    LOGGER.info(f"Raw database query returned {len(db_properties)} properties: {list(db_properties.keys())}")

    is_valid, comparison_errors = compare_filter_options_with_database(
        api_filters=api_filters, db_properties=db_properties, excluded_fields=API_EXCLUDED_FILTER_FIELDS
    )

    if not is_valid:
        failure_msg = "Filter options API response does not match database content"
        failure_msg += "\nDetailed comparison errors:\n" + "\n".join(comparison_errors)
        # pytest.fail instead of `assert False`: the assert statement is removed under
        # `python -O`, which would let a mismatch pass silently.
        pytest.fail(failure_msg)

    LOGGER.info("Comprehensive database coverage validation passed - API matches database exactly")

tests/model_registry/model_catalog/utils.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,140 @@ def validate_model_catalog_configmap_data(configmap: ConfigMap, num_catalogs: in
244244
validate_default_catalog(catalogs=catalogs)
245245

246246

247+
def _split_pg_array_elements(array_body: str) -> list[str]:
    """Split the text between the braces of a one-dimensional PostgreSQL array literal.

    Double-quoted elements may contain commas; inside them a backslash escapes the
    next character (``\\"`` and ``\\\\``). Quotes are removed from the returned values.
    """
    elements: list[str] = []
    current: list[str] = []
    in_quotes = False
    was_quoted = False  # True once the element being read has contained a quoted section
    escaped = False

    for char in array_body:
        if escaped:
            # Character following a backslash is taken literally
            current.append(char)
            escaped = False
        elif char == "\\":
            escaped = True
        elif char == '"':
            in_quotes = not in_quotes
            was_quoted = True
        elif in_quotes:
            current.append(char)
        elif char == ",":
            text = "".join(current)
            elements.append(text if was_quoted else text.strip())
            current = []
            was_quoted = False
        elif char.isspace() and (was_quoted or not current):
            # Padding around an element is not part of its value
            continue
        else:
            current.append(char)

    text = "".join(current)
    elements.append(text if was_quoted else text.strip())
    return elements


def parse_psql_array_agg_output(psql_output: str) -> dict[str, list[str]]:
    """
    Parse psql output from array_agg query into Python dict.

    Expected format:
         name    | array_agg
        ---------+----------
         license | {apache-2.0,mit,bsd}
         provider | {Meta,"Red Hat, Inc."}

    Blank lines, separator lines (starting with "-") and lines without a "|"
    (such as the "(N rows)" footer) are ignored. Everything up to and including
    the header row containing "array_agg" is skipped. Rows whose second column
    is not wrapped in braces are ignored.

    Array elements are parsed following the PostgreSQL array output syntax, so a
    quoted element keeps embedded commas and escaped quotes instead of being split.

    Args:
        psql_output: Raw text printed by psql in its default aligned format

    Returns:
        dict mapping property names to lists of values
    """
    result: dict[str, list[str]] = {}
    header_seen = False

    for raw_line in psql_output.strip().split("\n"):
        line = raw_line.strip()
        if not line or line.startswith("-") or "|" not in line:
            continue

        if not header_seen:
            # Nothing before the header row is data
            if "array_agg" in line:
                header_seen = True
            continue

        # Split on the first "|" only: array values may themselves contain "|"
        property_name, _, array_str = line.partition("|")
        property_name = property_name.strip()
        array_str = array_str.strip()

        if not (array_str.startswith("{") and array_str.endswith("}")):
            continue

        array_body = array_str[1:-1]
        result[property_name] = _split_pg_array_elements(array_body) if array_body else []

    return result
298+
299+
300+
def get_postgres_pod_in_namespace(namespace: str = "rhoai-model-registries") -> Pod:
    """
    Return the first pod in ``namespace`` carrying the model catalog PostgreSQL label.

    Args:
        namespace: Namespace to search for the database pod

    Returns:
        The first pod matching the label selector

    Raises:
        AssertionError: If no pod in the namespace matches the label selector
    """
    label_selector = "app.kubernetes.io/name=model-catalog-postgres"
    matching_pods = [*Pod.get(namespace=namespace, label_selector=label_selector)]
    assert matching_pods, f"No PostgreSQL pod found in namespace {namespace}"
    first_pod, *_ = matching_pods
    return first_pod
305+
306+
307+
def compare_filter_options_with_database(
    api_filters: dict[str, Any], db_properties: dict[str, list[str]], excluded_fields: set[str]
) -> Tuple[bool, List[str]]:
    """
    Check that the filter_options API payload and the database agree.

    Properties listed in ``excluded_fields`` are dropped from the database side
    first. Each remaining property is then compared by name and by value set;
    every mismatch is logged and collected.

    Note: each API entry is read through its "values" key, i.e. all properties
    are treated as string-valued. Numeric/range properties are not handled.

    Args:
        api_filters: The "filters" dict from API response
        db_properties: Raw database properties before API filtering
        excluded_fields: Fields that API excludes from response

    Returns:
        Tuple of (is_valid, list_of_error_messages)
    """
    errors = []

    def _record(message: str) -> None:
        # Per-property problems are both logged and reported back to the caller
        LOGGER.error(message)
        errors.append(message)

    # Mirror the API: excluded fields never reach the response
    expected = {}
    for name, values in db_properties.items():
        if name not in excluded_fields:
            expected[name] = values

    LOGGER.info(f"Database returned {len(db_properties)} total properties")
    LOGGER.info(
        f"After applying API filtering, expecting {len(expected)} properties: {list(expected.keys())}"  # noqa: E501
    )

    only_in_db = set(expected.keys()) - set(api_filters.keys())
    only_in_api = set(api_filters.keys()) - set(expected.keys())

    for prop in sorted(set(expected.keys()) | set(api_filters.keys())):
        if prop in only_in_db:
            _record(f"Property '{prop}': In DB ({len(expected[prop])} values) but NOT in API")
            continue
        if prop in only_in_api:
            _record(f"Property '{prop}': In API ({len(api_filters[prop]['values'])} values) but NOT in DB")
            continue

        # Present on both sides: compare the value sets
        from_db = set(expected[prop])
        from_api = set(api_filters[prop]["values"])
        absent_from_api = from_db - from_api
        absent_from_db = from_api - from_db

        if absent_from_api:
            _record(f"Property '{prop}': DB has {len(absent_from_api)} values missing from API: {absent_from_api}")
        if absent_from_db:
            _record(f"Property '{prop}': API has {len(absent_from_db)} values missing from DB: {absent_from_db}")
        if not (absent_from_api or absent_from_db):
            LOGGER.info(f"Property '{prop}': Perfect match ({len(from_api)} values)")

    # Summary lines are returned to the caller but not logged
    if only_in_db:
        errors.append(f"API missing properties found in database: {only_in_db}")
    if only_in_api:
        errors.append(f"API has extra properties not in database: {only_in_api}")

    return not errors, errors
379+
380+
247381
def get_models_from_catalog_api(
248382
model_catalog_rest_url: list[str],
249383
model_registry_rest_headers: dict[str, str],

0 commit comments

Comments
 (0)