@@ -50,7 +50,9 @@ def collate_fn(features: List[Dict[str, Any]]) -> Dict[str, Any]:
5050 return {** tensors , ** non_tensors }
5151
5252
53- def process_image (image : Union [Dict [str , Any ], ImageObject , str ], min_pixels : int , max_pixels : int ) -> ImageObject :
53+ def process_image (
54+ image : Union [Dict [str , Any ], ImageObject , str ], min_pixels : Optional [int ], max_pixels : Optional [int ]
55+ ) -> ImageObject :
5456 if isinstance (image , str ):
5557 image = Image .open (image )
5658 elif isinstance (image , dict ):
@@ -59,12 +61,12 @@ def process_image(image: Union[Dict[str, Any], ImageObject, str], min_pixels: in
5961 image = Image .open (BytesIO (image ))
6062
6163 image .load () # avoid "Too many open files" errors
62- if (image .width * image .height ) > max_pixels :
64+ if max_pixels is not None and (image .width * image .height ) > max_pixels :
6365 resize_factor = math .sqrt (max_pixels / (image .width * image .height ))
6466 width , height = int (image .width * resize_factor ), int (image .height * resize_factor )
6567 image = image .resize ((width , height ))
6668
67- if (image .width * image .height ) < min_pixels :
69+ if min_pixels is not None and (image .width * image .height ) < min_pixels :
6870 resize_factor = math .sqrt (min_pixels / (image .width * image .height ))
6971 width , height = int (image .width * resize_factor ), int (image .height * resize_factor )
7072 image = image .resize ((width , height ))
@@ -92,8 +94,8 @@ def __init__(
9294 max_prompt_length : int = 1024 ,
9395 truncation : str = "error" ,
9496 format_prompt : Optional [str ] = None ,
95- max_pixels : Optional [int ] = None ,
9697 min_pixels : Optional [int ] = None ,
98+ max_pixels : Optional [int ] = None ,
9799 filter_overlong_prompts : bool = True ,
98100 ):
99101 self .tokenizer = tokenizer
@@ -104,8 +106,8 @@ def __init__(
104106 self .image_dir = image_dir
105107 self .max_prompt_length = max_prompt_length
106108 self .truncation = truncation
107- self .max_pixels = max_pixels
108109 self .min_pixels = min_pixels
110+ self .max_pixels = max_pixels
109111 self .filter_overlong_prompts = filter_overlong_prompts
110112
111113 if "@" in data_path :
@@ -169,17 +171,16 @@ def __getitem__(self, index):
169171 if self .image_key in example :
170172 prompt = self .processor .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
171173 images = example .pop (self .image_key )
172- if self .image_dir is not None and len (images ) != 0 and isinstance (images [0 ], str ): # image paths
174+ if self .image_dir is not None and len (images ) != 0 and isinstance (images [0 ], str ): # image paths
173175 images = [os .path .join (self .image_dir , image ) for image in images ]
174176
175177 resized_images = [
176- process_image (image , min_pixels = self .min_pixels , max_pixels = self .max_pixels )
177- for image in images
178+ process_image (image , min_pixels = self .min_pixels , max_pixels = self .max_pixels ) for image in images
178179 ]
179180 model_inputs = self .processor (resized_images , [prompt ], add_special_tokens = False , return_tensors = "pt" )
180181 input_ids = model_inputs .pop ("input_ids" )[0 ]
181182 attention_mask = model_inputs .pop ("attention_mask" )[0 ]
182- example ["multi_modal_inputs " ] = {"images" : images }
183+ example ["multi_modal_data " ] = {"images" : images }
183184 else :
184185 prompt = self .tokenizer .apply_chat_template (messages , add_generation_prompt = True , tokenize = False )
185186 model_inputs = self .tokenizer ([prompt ], add_special_tokens = False , return_tensors = "pt" )
0 commit comments