Skip to content

Commit 8e1ac45

Browse files
authored
[FIX] flatten display of features and fix query list of lists (#881)
* flatten display of features and fix query list of lists
* fix tests
1 parent 9064cd7 commit 8e1ac45

File tree

7 files changed

+101
-41
lines changed

7 files changed

+101
-41
lines changed

store/neurostore/resources/pipeline.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ def build_jsonpath(field_path: str, operator: str, value: str) -> str:
5252
Returns:
5353
PostgreSQL jsonpath query string
5454
"""
55-
5655
# Handle regular field queries
5756
cast_val, is_numeric = determine_value_type(value)
5857

@@ -83,35 +82,35 @@ def build_jsonpath(field_path: str, operator: str, value: str) -> str:
8382
raw_value = f'"{cast_val}"'
8483
raw_value = f"@ {sql_op} {raw_value}"
8584

86-
# Check if we're querying an array field
85+
# Check if we're querying array fields
8786
path_parts = field_path.split(".")
8887
if any(p.endswith("[]") for p in path_parts):
89-
# Handle array field queries
88+
query = "$"
9089
path_segments = []
91-
for i, part in enumerate(path_parts):
90+
91+
for part in path_parts:
9292
if part.endswith("[]"):
93-
# Convert path up to this point into the base path
94-
base_path = ".".join(path_segments)
93+
# When we hit an array, add previous path segments if any
94+
if path_segments:
95+
query += "." + ".".join(path_segments)
96+
path_segments = []
97+
# Add the array access
9598
array_field = part[:-2]
96-
remaining_path = ".".join(path_parts[i + 1:])
97-
98-
if remaining_path:
99-
full_path = (
100-
f"{base_path}.{array_field}" if base_path else array_field
101-
)
102-
return f"$.{full_path}[*] ? ({raw_value})".replace(
103-
"@", f"@.{remaining_path}"
104-
)
105-
else:
106-
full_path = (
107-
f"{base_path}.{array_field}" if base_path else array_field
108-
)
109-
return f"$.{full_path}[*] ? ({raw_value})"
99+
query += f".{array_field}[*]"
110100
else:
111101
path_segments.append(part)
112-
else:
113-
# Regular field query
114-
return f"$.{field_path} ? ({raw_value})"
102+
103+
# Add any remaining path segments
104+
if path_segments:
105+
query += "." + ".".join(path_segments)
106+
107+
# Add the filter condition
108+
query += f" ? ({raw_value})"
109+
110+
return query
111+
112+
# Regular field query
113+
return f"$.{field_path} ? ({raw_value})"
115114

116115

117116
def validate_pipeline_name(pipeline_name: str) -> None:
@@ -185,11 +184,6 @@ def parse_json_filter(filter_str: str) -> tuple:
185184
pipeline_name, field_spec = parts
186185
validate_pipeline_name(pipeline_name)
187186

188-
# Match array queries first
189-
# array_match = re.match(r"(.+?)\[\]=(.+)", field_spec)
190-
# if array_match:
191-
# return pipeline_name, array_match.group(1), "[]", array_match.group(2)
192-
193187
# Then match regular field queries
194188
match = re.match(r"(.+?)(~|=|>=|<=|>|<)(.+)", field_spec)
195189
if not match:

store/neurostore/schemas/data.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,27 @@ class BaseStudySchema(BaseDataSchema):
337337
features = fields.Method("get_features")
338338

339339
def get_features(self, obj):
340+
from .pipeline import PipelineStudyResultSchema
341+
340342
pipelines = self.context.get("feature_display", None)
341343

342344
if pipelines is None:
343345
return {}
344346

345-
return obj.display_features(pipelines)
347+
features = obj.display_features(pipelines)
348+
# Flatten each pipeline's predictions
349+
if features:
350+
flattened_features = {}
351+
for pipeline_name, feature_data in features.items():
352+
if isinstance(feature_data, dict):
353+
flattened_features[pipeline_name] = (
354+
PipelineStudyResultSchema.flatten_dict(feature_data)
355+
)
356+
else:
357+
flattened_features[pipeline_name] = feature_data
358+
return flattened_features
359+
360+
return features
346361

347362
class Meta:
348363
additional = (

store/neurostore/schemas/pipeline.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,43 @@ class PipelineStudyResultSchema(BaseSchema):
4242
class Meta:
4343
model = PipelineStudyResult
4444

45+
@classmethod
46+
def flatten_dict(cls, d, parent_key="", sep="."):
47+
"""Flatten nested dictionaries and arrays containing dictionaries."""
48+
items = []
49+
for k, v in d.items():
50+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
51+
52+
if isinstance(v, dict):
53+
items.extend(cls.flatten_dict(v, new_key, sep=sep).items())
54+
elif isinstance(v, list):
55+
if v and all(isinstance(item, dict) for item in v):
56+
# For arrays of dictionaries, include index in the key
57+
for idx, dict_item in enumerate(v):
58+
array_key = f"{new_key}[{idx}]"
59+
flattened = cls.flatten_dict(dict_item)
60+
for sub_key, sub_value in flattened.items():
61+
items.append((f"{array_key}.{sub_key}", sub_value))
62+
else:
63+
# Keep non-dictionary arrays intact
64+
items.append((new_key, v))
65+
else:
66+
items.append((new_key, v))
67+
return dict(items)
68+
4569
@post_dump
46-
def remove_none(self, data, **kwargs):
47-
"""Remove null values from serialized output."""
48-
return {key: value for key, value in data.items() if value is not None}
70+
def remove_none_and_flatten(self, data, **kwargs):
71+
"""Remove null values and flatten nested dictionaries in result_data."""
72+
# Remove None values
73+
data = {key: value for key, value in data.items() if value is not None}
74+
75+
# Flatten result_data if it exists
76+
if "result_data" in data and isinstance(data["result_data"], dict):
77+
# Get predictions section which contains our nested data
78+
result_data = data["result_data"]
79+
data["result_data"] = self.flatten_dict(result_data)
80+
81+
return data
4982

5083

5184
# Register schemas

store/neurostore/tests/api/test_base_studies.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ def test_features_query(auth_client, ingest_demographic_features):
2828
)
2929
assert result.status_code == 200
3030
assert "features" in result.json()["results"][0]
31-
assert (
32-
"age_mean"
33-
in result.json()["results"][0]["features"]["ParticipantInfo"]["predictions"][
34-
"groups"
35-
][0]
31+
features = result.json()["results"][0]["features"]["ParticipantInfo"]
32+
assert any(
33+
key.startswith("predictions") and key.endswith("].age_mean") for key in features
3634
)
3735

3836

@@ -98,8 +96,10 @@ def test_features_query_with_or(auth_client, ingest_demographic_features, sessio
9896

9997
api_diagnoses = set()
10098
for res in result.json()["results"]:
101-
for group in res["features"]["ParticipantInfo"]["predictions"]["groups"]:
102-
api_diagnoses.add(group["diagnosis"])
99+
features = res["features"]["ParticipantInfo"]
100+
# Get all diagnosis values from flattened structure
101+
diagnoses = [v for k, v in features.items() if k.endswith(".diagnosis")]
102+
api_diagnoses.update(diagnoses)
103103

104104
# Compare database and API results
105105
assert db_diagnoses == api_diagnoses

store/neurostore/tests/api/test_json_queries.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ def test_pipeline_multiple_filters(auth_client, study_pipeline_data):
145145
assert resp.status_code == 200
146146

147147

148+
def test_search_list_of_lists(auth_client, study_pipeline_data):
149+
"""Test search queries on lists of lists."""
150+
# Test searching for a specific task name in a list of lists
151+
resp = auth_client.get(
152+
(
153+
"/api/pipeline-study-results?feature_filter="
154+
"TaskInfo:predictions.fMRITasks[].Concepts[]~emotion"
155+
)
156+
)
157+
assert resp.status_code == 200
158+
results = resp.json()["results"]
159+
assert len(results) > 0
160+
161+
148162
@pytest.mark.parametrize(
149163
"query,expected_error",
150164
[

store/neurostore/tests/api/test_pipeline_resources.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,13 @@ def result2(pipeline_study_result_payload, session):
9090
"TestPipeline:nested.array[]=nested1",
9191
1,
9292
"nested1",
93-
lambda x: x["result_data"]["nested"]["array"],
93+
lambda x: x["result_data"]["nested.array"],
9494
),
9595
(
9696
"TestPipeline:nested.string~other",
9797
1,
9898
"other",
99-
lambda x: x["result_data"]["nested"]["string"],
99+
lambda x: x["result_data"]["nested.string"],
100100
),
101101
(
102102
"TestPipeline:array_field[]=value3",

store/neurostore/tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,10 @@ def create_pipeline_results(session, ingest_neurosynth, tmp_path):
702702
"fMRITasks": [
703703
{
704704
"TaskName": random.choice(["oddball", "n-back", "rest"]),
705+
"Concepts": random.sample(
706+
["emotion", "memory", "attention", "learning"],
707+
k=random.randint(1, 3),
708+
),
705709
"TaskDescription": (
706710
"Participants performed a "
707711
f"{random.choice(['visual', 'auditory'])} task"

0 commit comments

Comments (0)