@@ -89,15 +89,30 @@ def filter_long_prompt(origin_samples: list[Sample], tokenizer, processor, max_l
8989 return origin_samples
9090
9191 if processor :
92- filtered_samples = []
92+ # Use processor only for samples with actual multimodal content; use batched tokenizer for text-only.
93+ text_only = []
94+ multimodal = []
9395 for sample in origin_samples :
96+ if sample .multimodal_inputs and any (v is not None for v in sample .multimodal_inputs .values ()):
97+ multimodal .append (sample )
98+ else :
99+ text_only .append (sample )
100+ filtered_samples = []
101+ if text_only :
102+ prompts = [s .prompt for s in text_only ]
103+ input_ids_list = tokenizer (prompts , add_special_tokens = False )["input_ids" ]
104+ for sample , input_ids in zip (text_only , input_ids_list , strict = True ):
105+ if len (input_ids ) <= max_length :
106+ filtered_samples .append (sample )
107+ if multimodal :
94108 from slime .utils .processing_utils import process_vision_info
95109
96- multimodal_inputs = process_vision_info (sample .prompt , processor )
97- processor_output = processor (text = sample .prompt , ** multimodal_inputs )
98- input_ids = processor_output ["input_ids" ][0 ]
99- if len (input_ids ) <= max_length :
100- filtered_samples .append (sample )
110+ for sample in multimodal :
111+ multimodal_inputs = process_vision_info (sample .prompt , processor )
112+ processor_output = processor (text = sample .prompt , ** multimodal_inputs )
113+ input_ids = processor_output ["input_ids" ][0 ]
114+ if len (input_ids ) <= max_length :
115+ filtered_samples .append (sample )
101116 else :
102117 prompts = [sample .prompt for sample in origin_samples ]
103118 input_ids_list = tokenizer (prompts , add_special_tokens = False )["input_ids" ]
0 commit comments