Skip to content

Commit fa88b70

Browse files
committed
Merge branch 'new_branch_video_len_ttl' of github.com:Luffy-ZY-Wang/LLaMA-Factory into new_branch_video_len_ttl
2 parents 4c57cde + caab07e commit fa88b70

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

src/llamafactory/data/mm_plugin.py

+19
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,25 @@ def _cal_max_frames_each_video(durations: list, video_maxlen_ttl: int, video_max
104104
return max_nums_of_frames
105105

106106

107+
def _cal_max_frames_each_video(durations: list, video_maxlen_ttl: int, video_maxlen: int) -> list[int]:
108+
"""Calculate `max_num_of_frames` for each video based on their durations, and return a list of `max_num_of_frames`. Every `max_num_of_frames` should be in [2, video_maxlen]."""
109+
dura_ttl = sum(durations)
110+
max_nums_of_frames = [ # 2 < max_num_of_frames < video_maxlen
111+
min(max(int(video_maxlen_ttl * dura / dura_ttl), 2), video_maxlen) for dura in durations
112+
] # list of `max_num_of_frames`
113+
if sum(max_nums_of_frames) > video_maxlen_ttl: # may be bigger if some are set 2
114+
delta = sum(max_nums_of_frames) - video_maxlen_ttl
115+
for _ in range(delta): #
116+
max_idx = max_nums_of_frames.index(max(max_nums_of_frames))
117+
if max(max_nums_of_frames) - 1 >= 2: # should still >= 2
118+
max_nums_of_frames[max_idx] -= 1
119+
else:
120+
raise ValueError(
121+
f"Too many videos. Couldn't satisfy the requirement of having at least 2 frames for each video. Please decrease the number of videos or increase `video_maxlen_ttl` (e.g. >={2 * len(max_nums_of_frames)})."
122+
)
123+
return max_nums_of_frames
124+
125+
107126
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
108127
r"""Get paligemma token type ids for computing loss.
109128

0 commit comments

Comments
 (0)