diff --git a/store/backend/neurostore/resources/base.py b/store/backend/neurostore/resources/base.py index 2ca5b2a6..bfcd2958 100644 --- a/store/backend/neurostore/resources/base.py +++ b/store/backend/neurostore/resources/base.py @@ -160,26 +160,29 @@ def update_base_studies(self, base_studies): if not base_studies: return - # Subquery for new_has_coordinates - new_has_coordinates_subquery = ( - sa.select(sa.func.count(Point.id) > 0) - .select_from(Study) - .join(Analysis, Analysis.study_id == Study.id) - .join(Point, Point.analysis_id == Analysis.id) + studies_for_base_study = ( + sa.select(Study.id) .where(Study.base_study_id == BaseStudy.id) .correlate(BaseStudy) .scalar_subquery() ) - # Subquery for new_has_images + # Subquery for new_has_coordinates using EXISTS for early exit + new_has_coordinates_subquery = ( + sa.select(sa.literal(1)) + .select_from(Analysis) + .join(Point, Point.analysis_id == Analysis.id) + .where(Analysis.study_id.in_(studies_for_base_study)) + .exists() + ) + + # Subquery for new_has_images using EXISTS for early exit new_has_images_subquery = ( - sa.select(sa.func.count(Image.id) > 0) - .select_from(Study) - .join(Analysis, Analysis.study_id == Study.id) + sa.select(sa.literal(1)) + .select_from(Analysis) .join(Image, Image.analysis_id == Analysis.id) - .where(Study.base_study_id == BaseStudy.id) - .correlate(BaseStudy) - .scalar_subquery() + .where(Analysis.study_id.in_(studies_for_base_study)) + .exists() ) # Main query diff --git a/store/backend/neurostore/resources/data.py b/store/backend/neurostore/resources/data.py index 9e92c132..9d83223b 100644 --- a/store/backend/neurostore/resources/data.py +++ b/store/backend/neurostore/resources/data.py @@ -722,25 +722,33 @@ def view_search(self, q, args): radius = float(radius) except Exception: abort_validation("Spatial parameters must be numeric.") - # Join BaseStudy -> Study -> Analysis -> Point - q = q.join(Study, Study.base_study_id == self._model.id) - q = q.join(Analysis, Analysis.study_id == Study.id) - q = q.join(Point, Point.analysis_id == Analysis.id) - # Box filter first, then Euclidean distance - q = q.filter( - Point.x <= x + radius, - Point.x >= x - radius, - Point.y <= y + radius, - Point.y >= y - radius, - Point.z <= z + radius, - Point.z >= z - radius, - (Point.x - x) * (Point.x - x) - + (Point.y - y) * (Point.y - y) - + (Point.z - z) * (Point.z - z) - <= radius * radius, + # Use EXISTS so we do not duplicate base studies when filtering by spatial criteria + spatial_point = aliased(Point) + spatial_analysis = aliased(Analysis) + spatial_study = aliased(Study) + + spatial_filter = ( + sa.select(sa.literal(True)) + .select_from(spatial_study) + .join(spatial_analysis, spatial_analysis.study_id == spatial_study.id) + .join(spatial_point, spatial_point.analysis_id == spatial_analysis.id) + .where( + spatial_study.base_study_id == self._model.id, + spatial_point.x <= x + radius, + spatial_point.x >= x - radius, + spatial_point.y <= y + radius, + spatial_point.y >= y - radius, + spatial_point.z <= z + radius, + spatial_point.z >= z - radius, + (spatial_point.x - x) * (spatial_point.x - x) + + (spatial_point.y - y) * (spatial_point.y - y) + + (spatial_point.z - z) * (spatial_point.z - z) + <= radius * radius, + ) + .correlate(self._model) + .exists() ) - # Only return distinct base studies - q = q.distinct() + q = q.filter(spatial_filter) elif any(v is not None for v in [x, y, z, radius]): abort_validation("Spatial query requires x, y, z, and radius together.")