Skip to content

Commit 509b776

Browse files
authored
[ENH] jsonb query modality (#1179)
* speed up slow query * add special logic for modality * make predicate match * run black
1 parent 03f9898 commit 509b776

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

store/backend/neurostore/models/data.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,13 @@ def validate_version(self, key, value):
753753

754754
class PipelineStudyResult(BaseMixin, db.Model):
755755
__tablename__ = "pipeline_study_results"
756+
__table_args__ = (
757+
sa.Index(
758+
"ix_pipeline_study_results__modality",
759+
sa.text("(result_data -> 'Modality')"),
760+
postgresql_using="gin",
761+
),
762+
)
756763

757764
config_id = db.Column(
758765
db.Text, db.ForeignKey("pipeline_configs.id", ondelete="CASCADE"), index=True

store/backend/neurostore/resources/data.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,34 @@ def view_search(self, q, args):
956956
for idx, (field_path, operator, value) in enumerate(
957957
filters["result_filters"]
958958
):
959+
normalized_field = field_path.replace("[]", "")
960+
if (
961+
pipeline_name == "TaskExtractor"
962+
and normalized_field == "Modality"
963+
and operator == "="
964+
):
965+
modality_values = [
966+
val.strip() for val in value.split("|") if val.strip()
967+
]
968+
if modality_values:
969+
modality_field = PipelineStudyResultAlias.result_data.op("->")(
970+
sa.literal_column("'Modality'")
971+
)
972+
modality_clauses = []
973+
for idx, modality_value in enumerate(modality_values):
974+
param_name = f"modality_filter_{pipeline_name}_{idx}"
975+
modality_clauses.append(
976+
modality_field.op("@>")(
977+
sa.func.jsonb_build_array(
978+
sa.bindparam(param_name, modality_value)
979+
)
980+
)
981+
)
982+
pipeline_query = pipeline_query.filter(
983+
sae.or_(*modality_clauses)
984+
)
985+
pipeline_subqueries.append(pipeline_query.subquery())
986+
continue
959987
jsonpath = build_jsonpath(field_path, operator, value)
960988
param_name = f"jsonpath_result_{pipeline_name}_{idx}"
961989
pipeline_query = pipeline_query.filter(

0 commit comments

Comments
 (0)