55from collections .abc import Sequence
66from uuid import UUID
77
8- from pydantic import BaseModel
8+ from sqlalchemy import and_
9+ from sqlalchemy .orm import joinedload , selectinload
910from sqlmodel import Session , col , func , select
1011
1112from lightly_studio .api .routes .api .validators import Paginated
12- from lightly_studio .models .sample import SampleTable
13- from lightly_studio .models .video import VideoTable
14-
15-
16- class VideosWithCount (BaseModel ):
17- """Result of getting all samples."""
18-
19- samples : Sequence [VideoTable ]
20- total_count : int
21- next_cursor : int | None = None
13+ from lightly_studio .models .sample import SampleTable , SampleView
14+ from lightly_studio .models .video import (
15+ FrameView ,
16+ VideoFrameTable ,
17+ VideoTable ,
18+ VideoView ,
19+ VideoViewsWithCount ,
20+ )
2221
2322
2423def get_all_by_dataset_id (
2524 session : Session ,
2625 dataset_id : UUID ,
2726 pagination : Paginated | None = None ,
2827 sample_ids : list [UUID ] | None = None ,
29- ) -> VideosWithCount :
28+ ) -> VideoViewsWithCount :
3029 """Retrieve samples for a specific dataset with optional filtering."""
30+ # Subquery to find the minimum frame_number for each video
31+ min_frame_subquery = (
32+ select (
33+ VideoFrameTable .parent_sample_id ,
34+ func .min (col (VideoFrameTable .frame_number )).label ("min_frame_number" ),
35+ )
36+ .group_by (col (VideoFrameTable .parent_sample_id ))
37+ .subquery ()
38+ )
39+ # TODO(Horatiu, 11/2025): Check if it is possible to optimize this query.
40+ # Query to get videos with their first frame (frame with min frame_number)
41+ # First join the subquery to VideoTable, then join VideoFrameTable
3142 samples_query = (
32- select (VideoTable ).join (VideoTable .sample ).where (SampleTable .dataset_id == dataset_id )
43+ select (VideoTable , VideoFrameTable )
44+ .join (VideoTable .sample )
45+ .outerjoin (
46+ min_frame_subquery ,
47+ min_frame_subquery .c .parent_sample_id == VideoTable .sample_id ,
48+ )
49+ .outerjoin (
50+ VideoFrameTable ,
51+ and_ (
52+ col (VideoFrameTable .parent_sample_id ) == col (VideoTable .sample_id ),
53+ col (VideoFrameTable .frame_number ) == min_frame_subquery .c .min_frame_number ,
54+ ),
55+ )
56+ .where (SampleTable .dataset_id == dataset_id )
57+ .options (
58+ selectinload (VideoFrameTable .sample ).options (
59+ joinedload (SampleTable .tags ),
60+ # Ignore type checker error - false positive from TYPE_CHECKING.
61+ joinedload (SampleTable .metadata_dict ), # type: ignore[arg-type]
62+ selectinload (SampleTable .captions ),
63+ ),
64+ selectinload (VideoTable .sample ).options (
65+ joinedload (SampleTable .tags ),
66+ # Ignore type checker error - false positive from TYPE_CHECKING.
67+ joinedload (SampleTable .metadata_dict ), # type: ignore[arg-type]
68+ selectinload (SampleTable .captions ),
69+ ),
70+ )
3371 )
72+
3473 total_count_query = (
3574 select (func .count ())
3675 .select_from (VideoTable )
@@ -54,8 +93,55 @@ def get_all_by_dataset_id(
5493 if pagination and pagination .offset + pagination .limit < total_count :
5594 next_cursor = pagination .offset + pagination .limit
5695
57- return VideosWithCount (
58- samples = session .exec (samples_query ).all (),
96+ # Fetch videos with their first frames and convert to VideoView
97+ results = session .exec (samples_query ).all ()
98+ video_views = [
99+ _convert_video_table_to_view (video = video , first_frame = first_frame )
100+ for video , first_frame in results
101+ ]
102+
103+ return VideoViewsWithCount (
104+ samples = video_views ,
59105 total_count = total_count ,
60106 next_cursor = next_cursor ,
61107 )
108+
109+
110+ # TODO(Horatiu, 11/2025): This should be deleted when we have a proper way of getting all frames
111+ # for a video.
def get_all_by_dataset_id_with_frames(
    session: Session,
    dataset_id: UUID,
) -> Sequence[VideoTable]:
    """Return every video of a dataset, sorted by absolute file path.

    Args:
        session: Database session used to execute the query.
        dataset_id: ID of the dataset whose videos are selected.

    Returns:
        The ``VideoTable`` rows whose sample belongs to ``dataset_id``,
        in ascending ``file_path_abs`` order.
    """
    # The dataset ID lives on the sample row, hence the join through VideoTable.sample.
    belongs_to_dataset = SampleTable.dataset_id == dataset_id
    ordered_videos_query = (
        select(VideoTable)
        .join(VideoTable.sample)
        .where(belongs_to_dataset)
        .order_by(col(VideoTable.file_path_abs).asc())
    )
    return session.exec(ordered_videos_query).all()
122+
123+
def _convert_video_table_to_view(
    video: VideoTable, first_frame: VideoFrameTable | None
) -> VideoView:
    """Convert a VideoTable row to a VideoView carrying only its first frame.

    Args:
        video: Video row to convert. Its ``sample`` relationship is validated
            into a ``SampleView``.
        first_frame: Frame to expose as ``frame`` on the resulting view, or
            ``None`` when there is no frame to attach.

    Returns:
        A ``VideoView`` copying the video's fields, with ``frame`` set to the
        converted ``first_frame`` or ``None`` if no frame was given.
    """
    first_frame_view = None
    # Check against None explicitly: relying on truthiness would silently drop
    # a frame object that happens to evaluate as falsy.
    if first_frame is not None:
        first_frame_view = FrameView(
            frame_number=first_frame.frame_number,
            frame_timestamp_s=first_frame.frame_timestamp_s,
            sample_id=first_frame.sample_id,
            sample=SampleView.model_validate(first_frame.sample),
        )

    return VideoView(
        width=video.width,
        height=video.height,
        duration_s=video.duration_s,
        fps=video.fps,
        file_name=video.file_name,
        file_path_abs=video.file_path_abs,
        sample_id=video.sample_id,
        sample=SampleView.model_validate(video.sample),
        frame=first_frame_view,
    )
0 commit comments