diff --git a/unsloth_zoo/vision_utils.py b/unsloth_zoo/vision_utils.py index a655ec1d..dd72da20 100644 --- a/unsloth_zoo/vision_utils.py +++ b/unsloth_zoo/vision_utils.py @@ -114,7 +114,11 @@ def smart_resize( pass -def fetch_image(ele: dict[Union[Tuple[str, str], Image.Image]], size_factor: int = IMAGE_FACTOR) -> Image.Image: +def fetch_image( + ele: dict[Union[Tuple[str, str], Image.Image]], + size_factor: int = IMAGE_FACTOR, + max_image_size: int = None +) -> Image.Image: if "image" in ele: image = ele["image"] else: @@ -136,17 +140,19 @@ def fetch_image(ele: dict[Union[Tuple[str, str], Image.Image]], size_factor: int if image_obj is None: raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") image = image_obj.convert("RGB") - ## resize + + # Resize logic if "resized_height" in ele and "resized_width" in ele: resized_height, resized_width = smart_resize( ele["resized_height"], ele["resized_width"], factor=size_factor, + max_pixels=max_image_size if max_image_size is not None else MAX_PIXELS, ) else: width, height = image.size min_pixels = ele.get("min_pixels", MIN_PIXELS) - max_pixels = ele.get("max_pixels", MAX_PIXELS) + max_pixels = max_image_size if max_image_size is not None else ele.get("max_pixels", MAX_PIXELS) resized_height, resized_width = smart_resize( height, width, @@ -155,7 +161,6 @@ def fetch_image(ele: dict[Union[Tuple[str, str], Image.Image]], size_factor: int max_pixels=max_pixels, ) image = image.resize((resized_width, resized_height)) - return image pass @@ -181,16 +186,16 @@ def extract_vision_info(conversations: Union[list[dict], list[list[dict]]]) -> l def process_vision_info( conversations: Union[list[dict], list[list[dict]]], + max_image_size: int = None # New parameter ) -> tuple[Union[list[Image.Image], None], Union[list[Union[torch.Tensor, list[Image.Image]]], None]]: vision_infos = extract_vision_info(conversations) - ## Read images or videos image_inputs = [] video_inputs = [] for vision_info in vision_infos: if "image" in vision_info or "image_url" in vision_info: - image_inputs.append(fetch_image(vision_info)) + image_inputs.append(fetch_image(vision_info, max_image_size=max_image_size)) elif "video" in vision_info: - video_inputs.append(fetch_video(vision_info)) + video_inputs.append(fetch_video(vision_info)) # [TODO] Handle video max size if needed else: raise ValueError("image, image_url or video should in content.") if len(image_inputs) == 0: @@ -242,18 +247,28 @@ def _get_dtype(dtype): class UnslothVisionDataCollator: # All Unsloth Zoo code licensed under LGPLv3 - __slots__ = "padding_token_ids", "dtype", "ignore_index", "processor", "formatting_func" + __slots__ = "padding_token_ids", "dtype", "ignore_index", "processor", "formatting_func", "max_image_size" - def __init__(self, model, processor, formatting_func = None, ignore_index = -100): + def __init__(self, model, processor, formatting_func=None, ignore_index=-100, max_image_size=None): self.padding_token_ids = get_padding_tokens_ids(processor) self.dtype = _get_dtype( - model.config.torch_dtype \ - if hasattr(model.config, "torch_dtype") else \ + model.config.torch_dtype + if hasattr(model.config, "torch_dtype") else model.get_input_embeddings().weight.dtype ) self.ignore_index = ignore_index self.processor = processor self.formatting_func = formatting_func + + # Determine max_image_size: use provided value, fall back to model config, or default to MAX_PIXELS + if max_image_size is not None: + self.max_image_size = max_image_size + elif hasattr(model.config, "vision_config") and hasattr(model.config.vision_config, "max_image_size"): + self.max_image_size = model.config.vision_config.max_image_size + elif hasattr(model.config, "max_image_size"): + self.max_image_size = model.config.max_image_size + else: + self.max_image_size = MAX_PIXELS return pass @@ -277,7 +292,7 @@ def __call__(self, examples): if "images" in example: image = example["images"][0] else: - image, video = process_vision_info(messages) + image, video = process_vision_info(messages, max_image_size=self.max_image_size) texts .append(message) images.append(image) pass