Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions store/backend/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,29 @@ def update_base_studies(self, base_studies):
if not base_studies:
return

# Subquery for new_has_coordinates
new_has_coordinates_subquery = (
sa.select(sa.func.count(Point.id) > 0)
.select_from(Study)
.join(Analysis, Analysis.study_id == Study.id)
.join(Point, Point.analysis_id == Analysis.id)
studies_for_base_study = (
sa.select(Study.id)
.where(Study.base_study_id == BaseStudy.id)
.correlate(BaseStudy)
.scalar_subquery()
)

# Subquery for new_has_images
# Subquery for new_has_coordinates using EXISTS for early exit
new_has_coordinates_subquery = (
sa.select(sa.literal(1))
.select_from(Analysis)
.join(Point, Point.analysis_id == Analysis.id)
.where(Analysis.study_id.in_(studies_for_base_study))
.exists()
)

# Subquery for new_has_images using EXISTS for early exit
new_has_images_subquery = (
sa.select(sa.func.count(Image.id) > 0)
.select_from(Study)
.join(Analysis, Analysis.study_id == Study.id)
sa.select(sa.literal(1))
.select_from(Analysis)
.join(Image, Image.analysis_id == Analysis.id)
.where(Study.base_study_id == BaseStudy.id)
.correlate(BaseStudy)
.scalar_subquery()
.where(Analysis.study_id.in_(studies_for_base_study))
.exists()
)

# Main query
Expand Down
44 changes: 26 additions & 18 deletions store/backend/neurostore/resources/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,25 +722,33 @@ def view_search(self, q, args):
radius = float(radius)
except Exception:
abort_validation("Spatial parameters must be numeric.")
# Join BaseStudy -> Study -> Analysis -> Point
q = q.join(Study, Study.base_study_id == self._model.id)
q = q.join(Analysis, Analysis.study_id == Study.id)
q = q.join(Point, Point.analysis_id == Analysis.id)
# Box filter first, then Euclidean distance
q = q.filter(
Point.x <= x + radius,
Point.x >= x - radius,
Point.y <= y + radius,
Point.y >= y - radius,
Point.z <= z + radius,
Point.z >= z - radius,
(Point.x - x) * (Point.x - x)
+ (Point.y - y) * (Point.y - y)
+ (Point.z - z) * (Point.z - z)
<= radius * radius,
# Use EXISTS so we do not duplicate base studies when filtering by spatial criteria
spatial_point = aliased(Point)
spatial_analysis = aliased(Analysis)
spatial_study = aliased(Study)

spatial_filter = (
sa.select(sa.literal(True))
.select_from(spatial_study)
.join(spatial_analysis, spatial_analysis.study_id == spatial_study.id)
.join(spatial_point, spatial_point.analysis_id == spatial_analysis.id)
.where(
spatial_study.base_study_id == self._model.id,
spatial_point.x <= x + radius,
spatial_point.x >= x - radius,
spatial_point.y <= y + radius,
spatial_point.y >= y - radius,
spatial_point.z <= z + radius,
spatial_point.z >= z - radius,
(spatial_point.x - x) * (spatial_point.x - x)
+ (spatial_point.y - y) * (spatial_point.y - y)
+ (spatial_point.z - z) * (spatial_point.z - z)
<= radius * radius,
)
.correlate(self._model)
.exists()
)
# Only return distinct base studies
q = q.distinct()
q = q.filter(spatial_filter)
elif any(v is not None for v in [x, y, z, radius]):
abort_validation("Spatial query requires x, y, z, and radius together.")

Expand Down
Loading