from enum import Enum
from typing import Optional, Union, Tuple
import logging
import math
import itertools
import re

import numpy as np
import PIL.Image

logger = logging.getLogger(__name__)

class ExplicitEnum(str, Enum):
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )

class ChannelDimension(ExplicitEnum):
    FIRST = "channels_first"
    LAST = "channels_last"

def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.
    """
    num_channels = num_channels if num_channels is not None else (1, 3)
    num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels

    if image.ndim == 3:
        first_dim, last_dim = 0, 2
    elif image.ndim == 4:
        first_dim, last_dim = 1, 3
    else:
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    image_shape = image.shape

    if image_shape[first_dim] in num_channels and image_shape[last_dim] in num_channels:
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    elif image_shape[first_dim] in num_channels:
        return ChannelDimension.FIRST
    elif image_shape[last_dim] in num_channels:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")

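# Illustrative examples (not part of the pipeline): a (H, W, C) array is
# inferred as channels-last, a (C, H, W) or (N, C, H, W) array as
# channels-first, and shapes where both candidate dimensions look like a
# channel count (e.g. (3, 32, 3)) fall back to channels-first with a warning.
#
#   infer_channel_dimension_format(np.zeros((32, 32, 3)))     # ChannelDimension.LAST
#   infer_channel_dimension_format(np.zeros((1, 3, 32, 32)))  # ChannelDimension.FIRST
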
def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    image_shape = image.shape

    if channel_dim == ChannelDimension.FIRST:
        return image_shape[-2], image_shape[-1]
    elif channel_dim == ChannelDimension.LAST:
        return image_shape[-3], image_shape[-2]
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")

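# Example (illustrative only): both layouts of a 224x336 RGB image yield the
# same (height, width) pair.
#
#   get_image_size(np.zeros((3, 224, 336)), ChannelDimension.FIRST)  # (224, 336)
#   get_image_size(np.zeros((224, 336, 3)), ChannelDimension.LAST)   # (224, 336)
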
def pan_and_scan(
    image: np.ndarray,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Splits an elongated image into between 2 and `pan_and_scan_max_num_crops` crops along
    its longer side. Returns an empty list if the aspect ratio is below
    `pan_and_scan_min_ratio_to_activate` or if the resulting crops would be smaller than
    `pan_and_scan_min_crop_size`.
    """
    height, width = get_image_size(image, channel_dim=input_data_format)

    # Square or landscape image.
    if width >= height:
        # Only apply PaS if the image is sufficiently exaggerated.
        if width / height < pan_and_scan_min_ratio_to_activate:
            return []

        # Select an ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
        num_crops_w = int(math.floor(width / height + 0.5))  # Half round up rounding.
        num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)

        # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
        num_crops_w = max(2, num_crops_w)
        num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
        num_crops_h = 1

    # Portrait image.
    else:
        # Only apply PaS if the image is sufficiently exaggerated.
        if height / width < pan_and_scan_min_ratio_to_activate:
            return []

        # Select an ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
        num_crops_h = int(math.floor(height / width + 0.5))
        num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)

        # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
        num_crops_h = max(2, num_crops_h)
        num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
        num_crops_w = 1

    crop_size_w = int(math.ceil(width / num_crops_w))
    crop_size_h = int(math.ceil(height / num_crops_h))

    # Don't apply PaS if crop size is too small.
    if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
        return []

    crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
    crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]

    if input_data_format == ChannelDimension.LAST:
        image_crops = [
            image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
            for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
        ]
    else:
        image_crops = [
            image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
            for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
        ]

    return image_crops

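# Worked example (illustrative numbers, not defaults from any config): for a
# 448x1024 (HxW) image with pan_and_scan_min_crop_size=256,
# pan_and_scan_max_num_crops=4 and pan_and_scan_min_ratio_to_activate=1.2,
# the aspect ratio 1024/448 ~= 2.29 activates PaS, num_crops_w =
# floor(2.29 + 0.5) = 2, crop_size_w = ceil(1024 / 2) = 512, and the function
# returns two 448x512 crops taken at column offsets 0 and 512.
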
def _process_images_for_pan_and_scan(
    images: np.ndarray,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Applies `pan_and_scan` to each image and prepends the original image to its crops.
    Returns the per-image `[original, *crops]` lists and the number of crops per image.
    """
    batched_pas_images_list = []
    num_crops = []

    for image in images:
        pas_images = pan_and_scan(
            image=image,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
            input_data_format=input_data_format,
        )

        batched_pas_images_list.append([image] + pas_images)
        num_crops.append(len(pas_images))

    return batched_pas_images_list, num_crops

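# For instance (illustrative): two images where the first yields 2 crops and
# the second is left uncropped come back as
#   ([[img0, crop0a, crop0b], [img1]], [2, 0]).
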
def do_pan_and_scan(
    inputs: dict,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
):
    """
    Applies pan and scan to every image in `inputs` and rewrites each prompt so that it
    contains one image tag per original image and crop.
    """
    crops_and_prompts = {"crops": [], "modified_prompts": []}
    images = inputs.get("images", None)
    prompts = inputs["prompts"]
    image_tag = "<img>"

    if not images:
        raise ValueError("Cannot apply pan and scan: `inputs` contains no images.")

    input_data_format = infer_channel_dimension_format(images[0][0])

    processed = [
        _process_images_for_pan_and_scan(
            images=batch_images,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
            input_data_format=input_data_format,
        )
        for batch_images in images
    ]

    images_and_crops_list = [images_and_crops for images_and_crops, _ in processed]
    num_crops = [num_crops for _, num_crops in processed]

    for images_and_crops, prompt_text, num_of_crops in zip(images_and_crops_list, prompts, num_crops):

        image_tag_indexes = [m.start() for m in re.finditer(image_tag, prompt_text)]

        if len(images_and_crops) != len(image_tag_indexes):
            raise ValueError(
                f"Prompt contained {len(image_tag_indexes)} image tokens but received {len(images_and_crops)} images."
            )

        # Iterate in reverse so earlier tag offsets stay valid while the prompt is edited in place.
        for num, idx in reversed(list(zip(num_of_crops, image_tag_indexes))):
            if num:
                formatted_image_text = (
                    f"Here is the original image {image_tag} and here are some crops to help you see better "
                    + " ".join([image_tag] * num)
                )
                prompt_text = prompt_text[:idx] + formatted_image_text + prompt_text[idx + len(image_tag) :]

        crops_and_prompts["crops"].append(images_and_crops)
        crops_and_prompts["modified_prompts"].append(prompt_text)

    return crops_and_prompts

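# Illustrative example of the prompt rewrite: with a single image that produced
# two crops, the prompt "Describe <img>" becomes
# "Describe Here is the original image <img> and here are some crops to help
# you see better <img> <img>", so the tag count again matches the number of
# images plus crops fed to the model.
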
def to_pil_image(image, rescale=None):
    """
    Converts `image` to a `PIL.Image.Image`, moving a leading channel dimension to the
    end and rescaling floating point values to the 0-255 range if needed.
    """
    if isinstance(image, np.ndarray):
        if rescale is None:
            # Rescaling defaults to True for arrays of floating point type.
            rescale = isinstance(image.flat[0], np.floating)
        # If the channel has been moved to the first dim, we put it back at the end.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            image = image.transpose(1, 2, 0)
        if rescale:
            image = image * 255
        image = image.astype(np.uint8)
        return PIL.Image.fromarray(image)
    return image

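# For example (illustrative): a float array in [0, 1] with channels first is
# transposed to channels last, scaled by 255 and converted to uint8.
#
#   to_pil_image(np.random.rand(3, 64, 64))  # 64x64 RGB PIL image
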
def resize(image, size: Tuple[int, int] = (896, 896), resample: PIL.Image.Resampling = PIL.Image.Resampling.BILINEAR):
    """
    Resizes `image` to `size`, converting it to a `PIL.Image.Image` first if necessary.
    The 896x896 default matches the fixed input resolution the rest of this pipeline assumes.
    """
    if not isinstance(image, PIL.Image.Image):
        image = to_pil_image(image)
    return image.resize(size, resample=resample)
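
# Typical usage (illustrative): any input accepted by to_pil_image comes back
# as an 896x896 PIL image ready for the model.
#
#   resize(np.zeros((224, 224, 3), dtype=np.uint8)).size  # (896, 896)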