Skip to content

Commit f8f16ea

Browse files
committed
support new processor arg video_maxlen_ttl
1 parent 2b7d564 commit f8f16ea

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

src/llamafactory/data/mm_plugin.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,25 @@ def _get_number_of_features(self, orig_height: int, orig_width: int, height: int
8686
pass
8787

8888

89+
def _cal_max_frames_each_video(durations: list, video_maxlen_ttl: int, video_maxlen: int) -> list[int]:
90+
"""Calculate `max_num_of_frames` for each video based on their durations, and return a list of `max_num_of_frames`. Every `max_num_of_frames` should be in [2, video_maxlen]."""
91+
dura_ttl = sum(durations)
92+
max_nums_of_frames = [ # 2 < max_num_of_frames < video_maxlen
93+
min(max(int(video_maxlen_ttl * dura / dura_ttl), 2), video_maxlen) for dura in durations
94+
] # list of `max_num_of_frames`
95+
if sum(max_nums_of_frames) > video_maxlen_ttl: # may be bigger if some are set 2
96+
delta = sum(max_nums_of_frames) - video_maxlen_ttl
97+
for _ in range(delta): #
98+
max_idx = max_nums_of_frames.index(max(max_nums_of_frames))
99+
if max(max_nums_of_frames) - 1 >= 2: # should still >= 2
100+
max_nums_of_frames[max_idx] -= 1
101+
else:
102+
raise ValueError(
103+
f"Too many videos. Couldn't satisfy the requirement of having at least 2 frames for each video. Please decrease the number of videos or increase `video_maxlen_ttl` (e.g. >={2 * len(max_nums_of_frames)})."
104+
)
105+
return max_nums_of_frames
106+
107+
89108
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
90109
r"""Concatenate a list of lists, numpy arrays or torch tensors.
91110
@@ -247,10 +266,20 @@ def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str,
247266
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
248267
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
249268
results = []
250-
for video in videos:
269+
video_streams = []
270+
durations = []
271+
for video in videos: # prepare durations first
251272
container = av.open(video, "r")
252273
video_stream = next(stream for stream in container.streams if stream.type == "video")
253-
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
274+
durations.append(video_stream.duration * video_stream.time_base) # unit: second
275+
video_streams.append(video_stream)
276+
max_frames_each_video = _cal_max_frames_each_video(durations, **kwargs)
277+
for video_stream, max_frames in zip(video_streams, max_frames_each_video):
278+
sample_indices = self._get_video_sample_indices(
279+
video_stream,
280+
video_fps=kwargs["video_fps"],
281+
video_maxlen=max_frames,
282+
)
254283
frames: list[ImageObject] = []
255284
container.seek(0)
256285
for frame_idx, frame in enumerate(container.decode(video_stream)):
@@ -340,6 +369,7 @@ def _get_mm_inputs(
340369
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
341370
video_fps=getattr(processor, "video_fps", 2.0),
342371
video_maxlen=getattr(processor, "video_maxlen", 128),
372+
video_maxlen_ttl=getattr(processor, "video_maxlen_ttl", 128 * len(videos)), # disabled by default
343373
)["videos"]
344374
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
345375
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
@@ -516,6 +546,7 @@ def _get_mm_inputs(
516546
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
517547
video_fps=getattr(processor, "video_fps", 2.0),
518548
video_maxlen=getattr(processor, "video_maxlen", 128),
549+
video_maxlen_ttl=getattr(processor, "video_maxlen_ttl", 128 * len(videos)), # disabled by default
519550
)["videos"]
520551

521552
if len(images) != 0:
@@ -1055,6 +1086,7 @@ def _get_mm_inputs(
10551086
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
10561087
video_fps=getattr(processor, "video_fps", 2.0),
10571088
video_maxlen=getattr(processor, "video_maxlen", 128),
1089+
video_maxlen_ttl=getattr(processor, "video_maxlen_ttl", 128 * len(videos)), # disabled by default
10581090
)["videos"]
10591091
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
10601092
mm_inputs.update(video_inputs)
@@ -1439,10 +1471,20 @@ def _regularize_videos(
14391471
self, videos: list["VideoInput"], **kwargs
14401472
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
14411473
results, fps_per_video = [], []
1474+
video_streams = []
1475+
durations = []
14421476
for video in videos:
14431477
container = av.open(video, "r")
14441478
video_stream = next(stream for stream in container.streams if stream.type == "video")
1445-
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
1479+
durations.append(video_stream.duration * video_stream.time_base) # unit: second
1480+
video_streams.append(video_stream)
1481+
max_frames_each_video = _cal_max_frames_each_video(durations, **kwargs)
1482+
for video_stream, max_frames in zip(video_streams, max_frames_each_video):
1483+
sample_indices = self._get_video_sample_indices(
1484+
video_stream,
1485+
video_fps=kwargs["video_fps"],
1486+
video_maxlen=max_frames,
1487+
)
14461488
frames: list[ImageObject] = []
14471489
container.seek(0)
14481490
for frame_idx, frame in enumerate(container.decode(video_stream)):
@@ -1486,6 +1528,7 @@ def _get_mm_inputs(
14861528
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
14871529
video_fps=getattr(processor, "video_fps", 2.0),
14881530
video_maxlen=getattr(processor, "video_maxlen", 128),
1531+
video_maxlen_ttl=getattr(processor, "video_maxlen_ttl", 128 * len(videos)), # disabled by default
14891532
)
14901533
mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
14911534
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
@@ -1577,6 +1620,7 @@ def _get_mm_inputs(
15771620
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
15781621
video_fps=getattr(processor, "video_fps", 2.0),
15791622
video_maxlen=getattr(processor, "video_maxlen", 128),
1623+
video_maxlen_ttl=getattr(processor, "video_maxlen_ttl", 128 * len(videos)), # disabled by default
15801624
)
15811625
mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
15821626
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)

src/llamafactory/hparams/model_args.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,11 @@ class ProcessorArguments:
249249
)
250250
video_maxlen: int = field(
251251
default=128,
252-
metadata={"help": "The maximum number of sampled frames for video inputs."},
252+
metadata={"help": "The unified maximum number of sampled frames for each video inputs."},
253+
)
254+
video_maxlen_ttl: int = field(
255+
default=128 * 50, # assume 50 videos at max in 1 input
256+
metadata={"help": "The maximum number of total sampled frames of all video inputs."},
253257
)
254258
audio_sampling_rate: int = field(
255259
default=16000,

src/llamafactory/model/patcher.py

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def patch_processor(
8484
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
8585
setattr(processor, "video_fps", model_args.video_fps)
8686
setattr(processor, "video_maxlen", model_args.video_maxlen)
87+
setattr(processor, "video_maxlen_ttl", model_args.video_maxlen_ttl)
8788
setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
8889
setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
8990

0 commit comments

Comments
 (0)