Skip to content

Commit 0a4e56c

Browse files
authored
update recording filtering to utilize sample frame id (#513)
1 parent bba7d1a commit 0a4e56c

3 files changed

Lines changed: 46 additions & 31 deletions

File tree

bats_ai/core/utils/grts_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
from __future__ import annotations
22

3+
from django.db.models import Case, F, IntegerField, Value, When
4+
5+
6+
def recording_effective_sample_frame_id_case() -> Case:
7+
"""Return a ``Case`` matching ``normalize_sample_frame_id`` for ``Recording`` rows."""
8+
return Case(
9+
When(sample_frame_id__isnull=True, then=Value(14)),
10+
When(sample_frame_id=19, then=Value(20)),
11+
default=F("sample_frame_id"),
12+
output_field=IntegerField(),
13+
)
14+
315

416
def normalize_sample_frame_id(sample_frame_id: int | None) -> int | None:
517
"""Normalize sample frame IDs for AKCAN compatibility.
@@ -12,4 +24,6 @@ def normalize_sample_frame_id(sample_frame_id: int | None) -> int | None:
1224
"""
1325
if sample_frame_id == 19:
1426
return 20
27+
if sample_frame_id is None:
28+
return 14
1529
return sample_frame_id

bats_ai/core/views/recording.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from bats_ai.core.models import (
2020
Annotations,
2121
CompressedSpectrogram,
22-
GRTSCells,
2322
PulseMetadata,
2423
Recording,
2524
RecordingAnnotation,
@@ -29,7 +28,7 @@
2928
Spectrogram,
3029
)
3130
from bats_ai.core.tasks.tasks import recording_compute_spectrogram
32-
from bats_ai.core.views.recording_location import _parse_bbox
31+
from bats_ai.core.views.recording_location import _parse_bbox, filter_recordings_by_map_bbox
3332
from bats_ai.core.views.species import SpeciesSchema
3433

3534
if TYPE_CHECKING:
@@ -525,13 +524,7 @@ def get_recordings( # noqa: C901
525524
if q.bbox and q.bbox.strip():
526525
min_lon, min_lat, max_lon, max_lat = _parse_bbox(q.bbox)
527526
bbox_poly = Polygon.from_bbox((min_lon, min_lat, max_lon, max_lat))
528-
# Need to check the GRTSCells centroids as well as the recording_location
529-
grts_cell_ids = GRTSCells.objects.filter(centroid_4326__intersects=bbox_poly).values_list(
530-
"grts_cell_id", flat=True
531-
)
532-
queryset = queryset.filter(
533-
Q(recording_location__intersects=bbox_poly) | Q(grts_cell_id__in=grts_cell_ids)
534-
)
527+
queryset = filter_recordings_by_map_bbox(queryset, bbox_poly)
535528

536529
sort_field = q.sort_by or "created"
537530
order_prefix = "" if q.sort_direction == "asc" else "-"
@@ -610,13 +603,7 @@ def apply_filters_and_sort(qs: QuerySet[Recording]) -> QuerySet[Recording]:
610603
if bbox and bbox.strip():
611604
min_lon, min_lat, max_lon, max_lat = _parse_bbox(bbox)
612605
bbox_poly = Polygon.from_bbox((min_lon, min_lat, max_lon, max_lat))
613-
# Need to check the GRTSCells centroids as well as the recording_location
614-
grts_cell_ids = GRTSCells.objects.filter(
615-
centroid_4326__intersects=bbox_poly
616-
).values_list("grts_cell_id", flat=True)
617-
qs = qs.filter(
618-
Q(recording_location__intersects=bbox_poly) | Q(grts_cell_id__in=grts_cell_ids)
619-
)
606+
qs = filter_recordings_by_map_bbox(qs, bbox_poly)
620607
order_prefix = "" if sort_direction == "asc" else "-"
621608
if sort_by == "owner_username":
622609
qs = qs.order_by(f"{order_prefix}owner__username")

bats_ai/core/views/recording_location.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Literal
55

66
from django.contrib.gis.geos import Polygon
7-
from django.db.models import Q, QuerySet
7+
from django.db.models import Exists, OuterRef, Q, QuerySet
88
from ninja import Query, Router, Schema
99
from ninja.errors import HttpError
1010

@@ -14,7 +14,10 @@
1414
Recording,
1515
RecordingAnnotation,
1616
)
17-
from bats_ai.core.utils.grts_utils import normalize_sample_frame_id
17+
from bats_ai.core.utils.grts_utils import (
18+
normalize_sample_frame_id,
19+
recording_effective_sample_frame_id_case,
20+
)
1821

1922
if TYPE_CHECKING:
2023
from django.http import HttpRequest
@@ -60,14 +63,33 @@ def _split_tags(tags: str | None) -> list[str]:
6063
return [t.strip() for t in tags.split(",") if t.strip()]
6164

6265

63-
def _apply_recording_filters_and_sort( # noqa: PLR0913
66+
def filter_recordings_by_map_bbox(
67+
qs: QuerySet[Recording],
68+
bbox_poly: Polygon,
69+
) -> QuerySet[Recording]:
70+
"""Keep recordings whose point lies in the bbox or whose GRTS cell centroid matches.
71+
72+
Cell matching uses ``(grts_cell_id, sample_frame_id)`` on ``GRTSCells``, with the
73+
same effective sample frame rules as ``normalize_sample_frame_id``.
74+
"""
75+
cell_centroid_in_bbox = GRTSCells.objects.filter(
76+
centroid_4326__intersects=bbox_poly,
77+
grts_cell_id=OuterRef("grts_cell_id"),
78+
sample_frame_id=OuterRef("_map_bbox_effective_sf"),
79+
)
80+
return qs.annotate(_map_bbox_effective_sf=recording_effective_sample_frame_id_case()).filter(
81+
Q(recording_location__intersects=bbox_poly)
82+
| (Q(grts_cell_id__isnull=False) & Exists(cell_centroid_in_bbox))
83+
)
84+
85+
86+
def _apply_recording_filters_and_sort(
6487
*,
6588
qs: QuerySet[Recording],
6689
exclude_submitted: bool,
6790
submitted_by_user: QuerySet[int] | None,
6891
tags: str | None,
6992
bbox_poly: Polygon | None,
70-
grts_cell_ids: QuerySet[int] | None,
7193
) -> QuerySet[Recording]:
7294
if exclude_submitted and submitted_by_user is not None:
7395
qs = qs.exclude(pk__in=submitted_by_user)
@@ -78,10 +100,8 @@ def _apply_recording_filters_and_sort( # noqa: PLR0913
78100
qs = qs.filter(tags__text=tag)
79101
qs = qs.distinct()
80102

81-
if bbox_poly is not None and grts_cell_ids is not None:
82-
qs = qs.filter(
83-
Q(recording_location__intersects=bbox_poly) | Q(grts_cell_id__in=grts_cell_ids)
84-
)
103+
if bbox_poly is not None:
104+
qs = filter_recordings_by_map_bbox(qs, bbox_poly)
85105

86106
# Keep deterministic ordering even though we don't expose sorting params.
87107
return qs.order_by("-created")
@@ -177,28 +197,22 @@ def get_recording_locations(
177197

178198
bbox = _parse_bbox(q.bbox)
179199
bbox_poly: Polygon | None = None
180-
grts_cell_ids: QuerySet[int] | None = None
181200
if bbox is not None:
182201
bbox_poly = Polygon.from_bbox((bbox[0], bbox[1], bbox[2], bbox[3]))
183-
grts_cell_ids = GRTSCells.objects.filter(centroid_4326__intersects=bbox_poly).values_list(
184-
"grts_cell_id", flat=True
185-
)
186202

187203
my_qs = _apply_recording_filters_and_sort(
188204
qs=my_qs,
189205
exclude_submitted=exclude_submitted,
190206
submitted_by_user=submitted_by_user,
191207
tags=q.tags,
192208
bbox_poly=bbox_poly,
193-
grts_cell_ids=grts_cell_ids,
194209
)
195210
shared_qs = _apply_recording_filters_and_sort(
196211
qs=shared_qs,
197212
exclude_submitted=exclude_submitted,
198213
submitted_by_user=submitted_by_user,
199214
tags=q.tags,
200215
bbox_poly=bbox_poly,
201-
grts_cell_ids=grts_cell_ids,
202216
)
203217

204218
my_list = list(
@@ -226,7 +240,7 @@ def get_recording_locations(
226240
sample_frame_cell_id_pairs = {
227241
(normalize_sample_frame_id(r.sample_frame_id), r.grts_cell_id)
228242
for r in recordings
229-
if r.sample_frame_id is not None and r.grts_cell_id is not None
243+
if r.grts_cell_id is not None
230244
}
231245
cell_centroids_by_sample_frame_id = _precompute_grts_cell_centroids(sample_frame_cell_id_pairs)
232246

0 commit comments

Comments
 (0)