
Commit 405f105

[FIX] feature errors (#912)
* add more informative errors
* treat derived from as jsonb
* fix test errors
* fix other test names
* fix boolean
* style
* fix style
* fix style
* fix style
1 parent 411239c commit 405f105

File tree: 7 files changed, +151 −166 lines changed

store/neurostore/ingest/extracted_features.py

Lines changed: 4 additions & 0 deletions

@@ -102,6 +102,10 @@ def ingest_feature(feature_directory, overwrite=False):
             print(f"Skipping {paper_dir} as it contains invalid JSON in info.json")
             continue

+        # sometimes the model returns a boolean instead of a dict
+        if isinstance(results, bool):
+            results = {}
+
         # check for existing result
         existing_result = (
             db.session.query(PipelineStudyResult)
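The guard above exists because the ingest code goes on to treat `results` as a dict, and an extraction model occasionally returns a bare boolean. A minimal, self-contained sketch of what the new `isinstance()` check protects against (the `.get("predictions")` lookup is illustrative, not taken from the ingest code):

```python
# Illustrative only: the failure mode the new guard avoids.
results = True  # the model sometimes returns a boolean instead of a dict

if isinstance(results, bool):  # the guard added in this commit
    results = {}

# Without the guard, this would raise AttributeError:
# 'bool' object has no attribute 'get'
assert results.get("predictions") is None
```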

store/neurostore/models/data.py

Lines changed: 1 addition & 1 deletion

@@ -712,7 +712,7 @@ class Pipeline(BaseMixin, db.Model):
     study_dependent = db.Column(db.Boolean, default=False)
     ace_compatible = db.Column(db.Boolean, default=False)
     pubget_compatible = db.Column(db.Boolean, default=False)
-    derived_from = db.Column(db.Text)
+    derived_from = db.Column(JSONB)


 class PipelineConfig(BaseMixin, db.Model):
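Changing `derived_from` from `db.Text` to `JSONB` also implies a database migration, which is not shown in this diff. The snippet below is only a hedged sketch assuming Alembic and a `pipelines` table name; the `USING` cast and revision wiring would need to match the project's actual migration setup:

```python
# Hypothetical Alembic migration body (the "pipelines" table name and the
# text -> jsonb cast are assumptions, not taken from this commit).
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import JSONB


def upgrade():
    # Reinterpret the existing text column as jsonb; stored values must
    # already be valid JSON for the cast to succeed.
    op.alter_column(
        "pipelines",
        "derived_from",
        type_=JSONB(),
        postgresql_using="derived_from::jsonb",
    )


def downgrade():
    op.alter_column(
        "pipelines",
        "derived_from",
        type_=sa.Text(),
        postgresql_using="derived_from::text",
    )
```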

store/neurostore/resources/base.py

Lines changed: 22 additions & 10 deletions

@@ -288,7 +288,7 @@ def update_or_create(cls, data, id=None, user=None, record=None, flush=True):
                 description=(
                     "You do not have permission to modify this record. "
                     "You must be the owner or the compose bot."
-                )
+                ),
             )
         elif only_ids:
             to_commit.append(record)
@@ -536,7 +536,7 @@ def delete(self, id):
                 description=(
                     "You do not have permission to delete this record. "
                     "Only the owner can delete records."
-                )
+                ),
             )
         else:
             db.session.delete(record)
@@ -603,14 +603,26 @@ def join_tables(self, q, args):
             return q
         return q.options(selectinload(self._model.user))

-    def serialize_records(self, records, args, exclude=tuple()):
-        """serialize records from search"""
-        content = self._schema(
-            exclude=exclude,
-            many=True,
-            context=args,
-        ).dump(records)
-        return content
+    def serialize_records(self, records, args, exclude=None):
+        schema_many = self._schema(exclude=exclude, many=True, context=args)
+
+        try:
+            # Fast path
+            return schema_many.dump(records)
+        except Exception as e:
+            # Fall back to manual loop to isolate the problem
+            schema = self._schema(exclude=exclude, many=False, context=args)
+            for idx, record in enumerate(records):
+                try:
+                    schema.dump(record)
+                except Exception as rec_err:
+                    # logger.error("Serialization failed on record #%d: %s", idx, record)
+                    raise ValueError(
+                        f"Serialization failed on record #{idx}: {record}. Error: {rec_err}"
+                    ) from rec_err
+
+            # If somehow we didn't catch the failing record, re-raise the original error
+            raise e

     def create_metadata(self, q, total):
         return {"total_count": total}
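The rewritten `serialize_records` is the "more informative errors" part of the commit: dump the whole batch in one pass, and only if that fails, re-serialize record by record so the error names the offending row. A standalone sketch of the same pattern, assuming marshmallow-style schemas (`ResultSchema` and `serialize_with_isolation` are invented for illustration and are not part of the codebase):

```python
# Hedged sketch of the bulk-then-per-record serialization fallback.
from marshmallow import Schema, fields


class ResultSchema(Schema):  # stand-in schema, not from the repo
    id = fields.Str()
    result_data = fields.Dict()


def serialize_with_isolation(records, schema_cls=ResultSchema):
    try:
        # Fast path: serialize the whole batch at once.
        return schema_cls(many=True).dump(records)
    except Exception as bulk_err:
        # Slow path: find the record that breaks serialization.
        single = schema_cls()
        for idx, record in enumerate(records):
            try:
                single.dump(record)
            except Exception as rec_err:
                raise ValueError(
                    f"Serialization failed on record #{idx}: {record!r}"
                ) from rec_err
        # Nothing reproduced the failure; surface the original error.
        raise bulk_err
```

The extra pass over the data only happens on the error path, so the happy path stays a single `dump()` call.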

store/neurostore/tests/api/test_base_studies.py

Lines changed: 46 additions & 52 deletions

@@ -20,19 +20,16 @@ def test_features_query(auth_client, ingest_demographic_features):
     # flatten the features (flatten json objects)
     # test features organized like this: {top_key: ["list", "of", "values"]}
     result = auth_client.get(
-        (
-            "/api/base-studies/?feature_filter=ParticipantInfo:predictions.groups[].age_mean>10&"
-            "feature_filter=ParticipantInfo:predictions.groups[].age_mean<=100&"
-            "feature_display=ParticipantInfo&"
-            "feature_flatten=true"
-        )
+        "/api/base-studies/?feature_filter="
+        "ParticipantDemographicsExtractor:predictions.groups[].age_mean>10&"
+        "feature_filter=ParticipantDemographicsExtractor:predictions.groups[].age_mean<=100&"
+        "feature_display=ParticipantDemographicsExtractor&"
+        "feature_flatten=true"
     )
     assert result.status_code == 200
     assert "features" in result.json()["results"][0]
-    features = result.json()["results"][0]["features"]["ParticipantInfo"]
-    assert any(
-        key.startswith("predictions") and key.endswith("].age_mean") for key in features
-    )
+    features = result.json()["results"][0]["features"]["ParticipantDemographicsExtractor"]
+    assert any(key.startswith("predictions") and key.endswith("].age_mean") for key in features)


 def test_features_query_with_or(auth_client, ingest_demographic_features, session):
@@ -56,7 +53,9 @@ def test_features_query_with_or(auth_client, ingest_demographic_features, session):
             PipelineAlias,
             PipelineConfigAlias.pipeline_id == PipelineAlias.id,
         )
-        .filter(PipelineAlias.name == "ParticipantInfo")  # Filter for specific pipeline
+        .filter(
+            PipelineAlias.name == "ParticipantDemographicsExtractor"
+        )  # Filter for specific pipeline
         .group_by(PipelineStudyResultAlias.base_study_id)
         .subquery()
     )
@@ -72,12 +71,9 @@ def test_features_query_with_or(auth_client, ingest_demographic_features, session):
         .join(
             latest_results,
             (PipelineStudyResultAlias.base_study_id == latest_results.c.base_study_id)
-            & (
-                PipelineStudyResultAlias.date_executed
-                >= latest_results.c.max_date_executed
-            ),
+            & (PipelineStudyResultAlias.date_executed >= latest_results.c.max_date_executed),
         )
-        .filter(PipelineAlias.name == "ParticipantInfo")
+        .filter(PipelineAlias.name == "ParticipantDemographicsExtractor")
         .filter(
             text(
                 "jsonb_path_exists(result_data, '$.predictions.groups[*].diagnosis ?"
@@ -94,20 +90,18 @@ def test_features_query_with_or(auth_client, ingest_demographic_features, session):

     # Now make the API request
     result = auth_client.get(
-        (
-            "/api/base-studies/?feature_filter="
-            "ParticipantInfo:predictions.groups[].diagnosis=ADHD|ASD&"
-            "feature_display=ParticipantInfo&"
-            "feature_flatten=true"
-        )
+        "/api/base-studies/?feature_filter="
+        "ParticipantDemographicsExtractor:predictions.groups[].diagnosis=ADHD|ASD&"
+        "feature_display=ParticipantDemographicsExtractor&"
+        "feature_flatten=true"
     )

     assert result.status_code == 200
     assert "features" in result.json()["results"][0]

     api_diagnoses = set()
     for res in result.json()["results"]:
-        features = res["features"]["ParticipantInfo"]
+        features = res["features"]["ParticipantDemographicsExtractor"]
         # Get all diagnosis values from flattened structure
         diagnoses = [v for k, v in features.items() if k.endswith(".diagnosis")]
         api_diagnoses.update(diagnoses)
@@ -217,9 +211,7 @@ def test_info_base_study(auth_client, ingest_neurosynth, session):
     assert single_reg_resp.status_code == 200

     info_fields = [
-        f
-        for f, v in StudySchema._declared_fields.items()
-        if v.metadata.get("info_field")
+        f for f, v in StudySchema._declared_fields.items() if v.metadata.get("info_field")
     ]

     study = single_info_resp.json()["versions"][0]
@@ -395,9 +387,7 @@ def test_has_coordinates_images(auth_client, session):
     assert base_study_2.has_images is True

     # delete the full study
-    delete_study = auth_client.delete(
-        f"/api/studies/{create_full_study_again.json()['id']}"
-    )
+    delete_study = auth_client.delete(f"/api/studies/{create_full_study_again.json()['id']}")

     assert delete_study.status_code == 200
     session.refresh(base_study_2)
@@ -410,8 +400,9 @@ def test_config_and_feature_filters(auth_client, ingest_demographic_features, session):
     # Test combined feature and config filtering
     response = auth_client.get(
         "/api/base-studies/?"
-        "feature_filter=ParticipantInfo:1.0.0:predictions.groups[].age_mean>25&"
-        "pipeline_config=ParticipantInfo:1.0.0:extractor_kwargs.extraction_model=gpt-4-turbo"
+        "feature_filter=ParticipantDemographicsExtractor:1.0.0:predictions.groups[].age_mean>25&"
+        "pipeline_config=ParticipantDemographicsExtractor:"
+        "1.0.0:extractor_kwargs.extraction_model=gpt-4-turbo"
     )

     assert response.status_code == 200
@@ -420,16 +411,17 @@ def test_config_and_feature_filters(auth_client, ingest_demographic_features, session):
     # Test with mismatched version
     response = auth_client.get(
         "/api/base-studies/?"
-        "feature_filter=ParticipantInfo:2.0.0:predictions.groups[].age_mean>30&"
-        "pipeline_config=ParticipantInfo:2.0.0:extractor_kwargs.extraction_model=gpt-4-turbo"
+        "feature_filter=ParticipantDemographicsExtractor:2.0.0:predictions.groups[].age_mean>30&"
+        "pipeline_config=ParticipantDemographicsExtractor:2.0.0:"
+        "extractor_kwargs.extraction_model=gpt-4-turbo"
     )

     assert response.status_code == 200
     assert len(response.json()["results"]) == 0

     # Test error handling for invalid filter format
     response = auth_client.get(
-        "/api/base-studies/?" "pipeline_config=ParticipantInfo:invalid:filter:format"
+        "/api/base-studies/?pipeline_config=ParticipantDemographicsExtractor:invalid:filter:format"
     )

     assert response.status_code == 400
@@ -440,37 +432,39 @@ def test_feature_display_and_pipeline_config(auth_client, ingest_demographic_features):
     # Test feature display with version specified
     response = auth_client.get(
         "/api/base-studies/?"
-        "feature_display=ParticipantInfo:1.0.0&"
-        "pipeline_config=ParticipantInfo:1.0.0:extractor_kwargs.extraction_model=gpt-4-turbo"
+        "feature_display=ParticipantDemographicsExtractor:1.0.0&"
+        "pipeline_config=ParticipantDemographicsExtractor:1.0.0:"
+        "extractor_kwargs.extraction_model=gpt-4-turbo"
     )
     assert response.status_code == 200
     results = response.json()["results"]
     assert len(results) > 0
     assert "features" in results[0]
-    assert "ParticipantInfo" in results[0]["features"]
+    assert "ParticipantDemographicsExtractor" in results[0]["features"]

     # Test default behavior when version not specified (should use latest version)
     default_response = auth_client.get(
         "/api/base-studies/?"
-        "feature_display=ParticipantInfo&"
-        "pipeline_config=ParticipantInfo:extractor_kwargs.extraction_model=gpt-4-turbo"
+        "feature_display=ParticipantDemographicsExtractor&"
+        "pipeline_config=ParticipantDemographicsExtractor:"
+        "extractor_kwargs.extraction_model=gpt-4-turbo"
     )
     assert default_response.status_code == 200
     assert len(default_response.json()["results"]) > 0

     # Verify the output structure
     result = default_response.json()["results"][0]
     assert "features" in result
-    features = result["features"]["ParticipantInfo"]
+    features = result["features"]["ParticipantDemographicsExtractor"]
     assert isinstance(features, dict)
     if "predictions" in features:
         assert isinstance(features["predictions"], dict)

     # Test mismatched versions between feature_display and pipeline_config
     mismatch_response = auth_client.get(
         "/api/base-studies/?"
-        "feature_display=ParticipantInfo:1.0.0&"
-        "pipeline_config=ParticipantInfo:2.0.0:model_version=2"
+        "feature_display=ParticipantDemographicsExtractor:1.0.0&"
+        "pipeline_config=ParticipantDemographicsExtractor:2.0.0:model_version=2"
     )
     assert mismatch_response.status_code == 200
     assert len(mismatch_response.json()["results"]) == 0
@@ -479,12 +473,14 @@ def test_feature_display_and_pipeline_config(auth_client, ingest_demographic_features):
 def test_feature_flatten(auth_client, ingest_demographic_features):
     """Test flattening nested feature objects into dot notation"""
     # Get response without flattening
-    unflattened = auth_client.get("/api/base-studies/?feature_display=ParticipantInfo")
+    unflattened = auth_client.get(
+        "/api/base-studies/?feature_display=ParticipantDemographicsExtractor"
+    )
     assert unflattened.status_code == 200

     # Get response with flattening
     flattened = auth_client.get(
-        "/api/base-studies/?feature_display=ParticipantInfo&feature_flatten=true"
+        "/api/base-studies/?feature_display=ParticipantDemographicsExtractor&feature_flatten=true"
     )
     assert flattened.status_code == 200

@@ -494,21 +490,19 @@ def test_feature_flatten(auth_client, ingest_demographic_features):

     # Get the feature dictionaries
     unflattened_features = unflattened.json()["results"][0]["features"][
-        "ParticipantInfo"
+        "ParticipantDemographicsExtractor"
+    ]
+    flattened_features = flattened.json()["results"][0]["features"][
+        "ParticipantDemographicsExtractor"
     ]
-    flattened_features = flattened.json()["results"][0]["features"]["ParticipantInfo"]

     # Verify features are flattened in dot notation
     # Check nested predictions.groups objects are flattened
-    assert any(
-        key.startswith("predictions.groups") for key in flattened_features.keys()
-    )
+    assert any(key.startswith("predictions.groups") for key in flattened_features.keys())

     # Verify values are preserved after flattening
     # Example: predictions.groups[0].age_mean should equal the nested value
-    if "predictions" in unflattened_features and unflattened_features[
-        "predictions"
-    ].get("groups"):
+    if "predictions" in unflattened_features and unflattened_features["predictions"].get("groups"):
         nested_age = unflattened_features["predictions"]["groups"][0].get("age_mean")
         if nested_age is not None:
             flattened_age = flattened_features.get("predictions.groups[0].age_mean")
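Taken together, the test changes exercise the `pipeline[:version]:json.path<op>value` filter syntax against the renamed `ParticipantDemographicsExtractor` pipeline. A hedged client-side example of the same query follows; the base URL and the use of `requests` are illustrative assumptions (the tests themselves go through the Flask test client):

```python
# Illustrative request mirroring the filters used in the tests above.
import requests

resp = requests.get(
    "https://neurostore.org/api/base-studies/",  # assumed deployment URL
    params={
        "feature_filter": (
            "ParticipantDemographicsExtractor:1.0.0:"
            "predictions.groups[].age_mean>25"
        ),
        "feature_display": "ParticipantDemographicsExtractor",
        "feature_flatten": "true",
    },
)
resp.raise_for_status()

for study in resp.json()["results"]:
    # With feature_flatten=true the nested predictions come back in dot
    # notation, e.g. "predictions.groups[0].age_mean".
    print(study["features"]["ParticipantDemographicsExtractor"])
```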
