Skip to content

Commit d618e7c

Browse files
Improve video grid loading speed (#183)
1 parent 462b300 commit d618e7c

File tree

22 files changed

+224
-108
lines changed

22 files changed

+224
-108
lines changed

lightly_studio/src/lightly_studio/api/routes/api/video.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@
1010
from lightly_studio.db_manager import SessionDep
1111
from lightly_studio.models.video import VideoTable, VideoView, VideoViewsWithCount
1212
from lightly_studio.resolvers import video_resolver
13-
from lightly_studio.resolvers.video_resolver import VideosWithCount
1413

1514
video_router = APIRouter(prefix="/datasets/{dataset_id}/video", tags=["video"])
1615

1716

18-
@video_router.get("/", response_model=VideoViewsWithCount)
17+
@video_router.get("/")
1918
def get_all_videos(
2019
session: SessionDep,
2120
dataset_id: Annotated[UUID, Path(title="Dataset Id")],
2221
pagination: Annotated[PaginatedWithCursor, Depends()],
23-
) -> VideosWithCount:
22+
) -> VideoViewsWithCount:
2423
"""Retrieve a list of all videos for a given dataset ID with pagination.
2524
2625
Args:

lightly_studio/src/lightly_studio/examples/example_video_annotations.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def load_annotations(session: Session, dataset_id: UUID, annotations_path: Path)
9292
Temporarily use internal add_samples API until labelformat supports videos natively.
9393
"""
9494
print("Loading video annotations...")
95-
videos = video_resolver.get_all_by_dataset_id(session=session, dataset_id=dataset_id).samples
95+
videos = video_resolver.get_all_by_dataset_id_with_frames(
96+
session=session, dataset_id=dataset_id
97+
)
9698
video_name_to_video = {video.file_name: video for video in videos}
9799
yvis_input = YouTubeVISObjectDetectionInput(input_file=annotations_path)
98100
label_map = add_samples._create_label_map( # noqa: SLF001

lightly_studio/src/lightly_studio/models/video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,13 @@ class VideoView(SQLModel):
6060

6161
width: int
6262
height: int
63-
duration_s: float
63+
duration_s: Optional[float] = None
6464
fps: float
6565
file_name: str
6666
file_path_abs: str
6767
sample_id: UUID
6868
sample: SampleView
69-
frames: List["FrameView"] = []
69+
frame: Optional["FrameView"] = None
7070

7171

7272
class VideoViewsWithCount(BaseModel):

lightly_studio/src/lightly_studio/resolvers/video_resolver/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
from lightly_studio.resolvers.video_resolver.create_many import create_many
44
from lightly_studio.resolvers.video_resolver.filter_new_paths import filter_new_paths
55
from lightly_studio.resolvers.video_resolver.get_all_by_dataset_id import (
6-
VideosWithCount,
76
get_all_by_dataset_id,
7+
get_all_by_dataset_id_with_frames,
88
)
99
from lightly_studio.resolvers.video_resolver.get_by_id import get_by_id
1010

1111
__all__ = [
12-
"VideosWithCount",
1312
"create_many",
1413
"filter_new_paths",
1514
"get_all_by_dataset_id",
15+
"get_all_by_dataset_id_with_frames",
1616
"get_by_id",
1717
]

lightly_studio/src/lightly_studio/resolvers/video_resolver/get_all_by_dataset_id.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,71 @@
55
from collections.abc import Sequence
66
from uuid import UUID
77

8-
from pydantic import BaseModel
8+
from sqlalchemy import and_
9+
from sqlalchemy.orm import joinedload, selectinload
910
from sqlmodel import Session, col, func, select
1011

1112
from lightly_studio.api.routes.api.validators import Paginated
12-
from lightly_studio.models.sample import SampleTable
13-
from lightly_studio.models.video import VideoTable
14-
15-
16-
class VideosWithCount(BaseModel):
17-
"""Result of getting all samples."""
18-
19-
samples: Sequence[VideoTable]
20-
total_count: int
21-
next_cursor: int | None = None
13+
from lightly_studio.models.sample import SampleTable, SampleView
14+
from lightly_studio.models.video import (
15+
FrameView,
16+
VideoFrameTable,
17+
VideoTable,
18+
VideoView,
19+
VideoViewsWithCount,
20+
)
2221

2322

2423
def get_all_by_dataset_id(
2524
session: Session,
2625
dataset_id: UUID,
2726
pagination: Paginated | None = None,
2827
sample_ids: list[UUID] | None = None,
29-
) -> VideosWithCount:
28+
) -> VideoViewsWithCount:
3029
"""Retrieve a dataset's videos, each with its first frame, with optional filtering."""
30+
# Subquery to find the minimum frame_number for each video
31+
min_frame_subquery = (
32+
select(
33+
VideoFrameTable.parent_sample_id,
34+
func.min(col(VideoFrameTable.frame_number)).label("min_frame_number"),
35+
)
36+
.group_by(col(VideoFrameTable.parent_sample_id))
37+
.subquery()
38+
)
39+
# TODO(Horatiu, 11/2025): Check if it is possible to optimize this query.
40+
# Query to get videos with their first frame (frame with min frame_number)
41+
# First join the subquery to VideoTable, then join VideoFrameTable
3142
samples_query = (
32-
select(VideoTable).join(VideoTable.sample).where(SampleTable.dataset_id == dataset_id)
43+
select(VideoTable, VideoFrameTable)
44+
.join(VideoTable.sample)
45+
.outerjoin(
46+
min_frame_subquery,
47+
min_frame_subquery.c.parent_sample_id == VideoTable.sample_id,
48+
)
49+
.outerjoin(
50+
VideoFrameTable,
51+
and_(
52+
col(VideoFrameTable.parent_sample_id) == col(VideoTable.sample_id),
53+
col(VideoFrameTable.frame_number) == min_frame_subquery.c.min_frame_number,
54+
),
55+
)
56+
.where(SampleTable.dataset_id == dataset_id)
57+
.options(
58+
selectinload(VideoFrameTable.sample).options(
59+
joinedload(SampleTable.tags),
60+
# Ignore type checker error - false positive from TYPE_CHECKING.
61+
joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
62+
selectinload(SampleTable.captions),
63+
),
64+
selectinload(VideoTable.sample).options(
65+
joinedload(SampleTable.tags),
66+
# Ignore type checker error - false positive from TYPE_CHECKING.
67+
joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
68+
selectinload(SampleTable.captions),
69+
),
70+
)
3371
)
72+
3473
total_count_query = (
3574
select(func.count())
3675
.select_from(VideoTable)
@@ -54,8 +93,55 @@ def get_all_by_dataset_id(
5493
if pagination and pagination.offset + pagination.limit < total_count:
5594
next_cursor = pagination.offset + pagination.limit
5695

57-
return VideosWithCount(
58-
samples=session.exec(samples_query).all(),
96+
# Fetch videos with their first frames and convert to VideoView
97+
results = session.exec(samples_query).all()
98+
video_views = [
99+
_convert_video_table_to_view(video=video, first_frame=first_frame)
100+
for video, first_frame in results
101+
]
102+
103+
return VideoViewsWithCount(
104+
samples=video_views,
59105
total_count=total_count,
60106
next_cursor=next_cursor,
61107
)
108+
109+
110+
# TODO(Horatiu, 11/2025): Delete this once there is a proper way of getting all frames for
111+
# a video.
112+
def get_all_by_dataset_id_with_frames(
113+
session: Session,
114+
dataset_id: UUID,
115+
) -> Sequence[VideoTable]:
116+
"""Retrieve all videos of a dataset as VideoTable rows, ordered by file_path_abs."""
117+
samples_query = (
118+
select(VideoTable).join(VideoTable.sample).where(SampleTable.dataset_id == dataset_id)
119+
)
120+
samples_query = samples_query.order_by(col(VideoTable.file_path_abs).asc())
121+
return session.exec(samples_query).all()
122+
123+
124+
def _convert_video_table_to_view(
125+
video: VideoTable, first_frame: VideoFrameTable | None
126+
) -> VideoView:
127+
"""Convert VideoTable to VideoView with only the first frame."""
128+
first_frame_view = None
129+
if first_frame:
130+
first_frame_view = FrameView(
131+
frame_number=first_frame.frame_number,
132+
frame_timestamp_s=first_frame.frame_timestamp_s,
133+
sample_id=first_frame.sample_id,
134+
sample=SampleView.model_validate(first_frame.sample),
135+
)
136+
137+
return VideoView(
138+
width=video.width,
139+
height=video.height,
140+
duration_s=video.duration_s,
141+
fps=video.fps,
142+
file_name=video.file_name,
143+
file_path_abs=video.file_path_abs,
144+
sample_id=video.sample_id,
145+
sample=SampleView.model_validate(video.sample),
146+
frame=first_frame_view,
147+
)

lightly_studio/tests/api/routes/api/test_frame.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
from lightly_studio.api.routes.api.status import HTTP_STATUS_OK
99
from lightly_studio.models.dataset import SampleType
1010
from tests.helpers_resolvers import create_dataset
11-
from tests.resolvers.video_frame_resolver.helpers import create_video_with_frames
12-
from tests.resolvers.video_resolver.helpers import VideoStub
11+
from tests.resolvers.video.helpers import VideoStub, create_video_with_frames
1312

1413

1514
def test_get_all_frames(

lightly_studio/tests/api/routes/api/test_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from lightly_studio.models.dataset import SampleType
88
from lightly_studio.resolvers import video_resolver
99
from tests.helpers_resolvers import create_dataset
10-
from tests.resolvers.video_resolver.helpers import VideoStub, create_videos
10+
from tests.resolvers.video.helpers import VideoStub, create_videos
1111

1212

1313
def test_get_all_videos(test_client: TestClient, db_session: Session) -> None:

lightly_studio/tests/core/test_add_videos.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def test_load_into_dataset_from_paths(db_session: Session, tmp_path: Path) -> No
5757
video = videos[0]
5858
assert video.file_name == "test_video_0.mp4"
5959
assert video.file_path_abs == str(second_video_path)
60+
assert video.frame is not None
61+
assert video.frame.frame_number == 0
6062
video = videos[1]
6163
assert video.file_name == "test_video_1.mp4"
6264
assert video.file_path_abs == str(first_video_path)

lightly_studio/tests/resolvers/test_annotation_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
create_image,
3131
create_tag,
3232
)
33-
from tests.resolvers.video_resolver.helpers import VideoStub, create_videos
33+
from tests.resolvers.video.helpers import VideoStub, create_videos
3434

3535

3636
@dataclass

0 commit comments

Comments
 (0)