Skip to content

Commit 55d1d9e

Browse files
Adding pan and scan to gemma 3
1 parent a5337f5 commit 55d1d9e

File tree

1 file changed

+255
-0
lines changed

1 file changed

+255
-0
lines changed
+255
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
from enum import Enum
2+
from typing import Optional, Union, Tuple, Dict, List
3+
import tensorflow as tf
4+
from keras import ops
5+
import logging
6+
import numpy as np
7+
8+
import math
9+
import itertools
10+
import re
11+
12+
import PIL.Image
13+
import PIL.ImageOps
14+
15+
logger = logging.getLogger(__name__)
16+
17+
class ExplicitEnum(str, Enum):
    """
    Enum base class whose failed-lookup error message lists the valid values.
    """

    @classmethod
    def _missing_(cls, value):
        # Invoked by the Enum machinery when `value` matches no member; fail
        # loudly with the full set of accepted values instead of the terse default.
        valid_values = list(cls._value2member_map_.keys())
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )
27+
28+
class ChannelDimension(ExplicitEnum):
    """Supported channel layouts for image arrays."""

    # Channel axis comes before the spatial dimensions.
    FIRST = "channels_first"
    # Channel axis is the trailing dimension.
    LAST = "channels_last"
31+
32+
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image.

    Raises:
        ValueError: If the image is not 3- or 4-dimensional, or if neither
            candidate axis has a plausible channel count.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    if image.ndim not in (3, 4):
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    # For 4-d input the leading axis is assumed to be the batch dimension.
    first_dim, last_dim = ((0, 2) if image.ndim == 3 else (1, 3))

    shape = image.shape
    first_matches = shape[first_dim] in num_channels
    last_matches = shape[last_dim] in num_channels

    if first_matches and last_matches:
        # Both axes look like channels (e.g. a 3x3 image): warn and pick FIRST.
        logger.warning(
            f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
        )
        return ChannelDimension.FIRST
    if first_matches:
        return ChannelDimension.FIRST
    if last_matches:
        return ChannelDimension.LAST
    raise ValueError("Unable to infer channel dimension format")
69+
70+
def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.

    Raises:
        ValueError: If `channel_dim` is not a recognized `ChannelDimension`.
    """
    # Fix: the annotation previously read `channel_dim: ChannelDimension = None`,
    # which contradicts the `None` default; it is `Optional[ChannelDimension]`.
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    image_shape = image.shape

    # Negative indexing handles both batched (4-d) and unbatched (3-d) arrays.
    if channel_dim == ChannelDimension.FIRST:
        return image_shape[-2], image_shape[-1]
    if channel_dim == ChannelDimension.LAST:
        return image_shape[-3], image_shape[-2]
    raise ValueError(f"Unsupported data format: {channel_dim}")
94+
95+
def pan_and_scan(
    image: np.ndarray,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Split a sufficiently elongated image into a grid of crops ("pan and scan").

    Args:
        image (`np.ndarray`):
            The image to crop.
        pan_and_scan_min_crop_size (`int`):
            Minimum allowed side length of a crop; if the computed crops would
            be smaller, no crops are produced.
        pan_and_scan_max_num_crops (`int`):
            Upper bound on the number of crops along the long axis.
        pan_and_scan_min_ratio_to_activate (`float`):
            Minimum aspect ratio (long side / short side) below which no
            cropping is performed.
        input_data_format (`str` or `ChannelDimension`, *optional*):
            Channel layout of `image`. If `None`, it is inferred.

    Returns:
        A list of crop arrays; empty when pan-and-scan is not activated.
    """
    # Fix: previously `input_data_format` was accepted but ignored here and the
    # layout was always re-inferred, which could misread ambiguous shapes.
    height, width = get_image_size(image, channel_dim=input_data_format)

    # Square or landscape image.
    if width >= height:
        # Only apply PaS if the image is sufficiently exaggerated
        if width / height < pan_and_scan_min_ratio_to_activate:
            return []

        # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
        num_crops_w = int(math.floor(width / height + 0.5))  # Half round up rounding.
        num_crops_w = min(int(math.floor(width / pan_and_scan_min_crop_size)), num_crops_w)

        # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
        num_crops_w = max(2, num_crops_w)
        num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
        num_crops_h = 1

    # Portrait image.
    else:
        # Only apply PaS if the image is sufficiently exaggerated
        if height / width < pan_and_scan_min_ratio_to_activate:
            return []

        # Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
        num_crops_h = int(math.floor(height / width + 0.5))
        num_crops_h = min(int(math.floor(height / pan_and_scan_min_crop_size)), num_crops_h)

        # Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
        num_crops_h = max(2, num_crops_h)
        num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
        num_crops_w = 1

    crop_size_w = int(math.ceil(width / num_crops_w))
    crop_size_h = int(math.ceil(height / num_crops_h))

    # Don't apply PaS if crop size is too small.
    if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
        return []

    crop_positions_w = [crop_size_w * i for i in range(num_crops_w)]
    crop_positions_h = [crop_size_h * i for i in range(num_crops_h)]

    # Slice on the spatial axes appropriate to the channel layout.
    if input_data_format == ChannelDimension.LAST:
        image_crops = [
            image[pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
            for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
        ]
    else:
        image_crops = [
            image[:, pos_h : pos_h + crop_size_h, pos_w : pos_w + crop_size_w]
            for pos_h, pos_w in itertools.product(crop_positions_h, crop_positions_w)
        ]

    return image_crops
156+
157+
def _process_images_for_pan_and_scan(
    images: np.ndarray,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
    input_data_format: Optional[Union[str, ChannelDimension]] = None,
):
    """
    Run pan-and-scan over a list of images.

    For each input image, produces a list containing the original image
    followed by its crops (which may be empty), along with the per-image
    crop counts.

    Returns:
        A tuple `(images_with_crops, crop_counts)`.
    """
    images_with_crops = []
    crop_counts = []

    for original in images:
        crops = pan_and_scan(
            image=original,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
            input_data_format=input_data_format,
        )
        # The original image always comes first; crops follow in order.
        images_with_crops.append([original, *crops])
        crop_counts.append(len(crops))

    return images_with_crops, crop_counts
180+
181+
def do_pan_and_scan(
    inputs: dict,
    pan_and_scan_min_crop_size: int,
    pan_and_scan_max_num_crops: int,
    pan_and_scan_min_ratio_to_activate: float,
):
    """
    Apply pan-and-scan to a batch of prompts and their images.

    Args:
        inputs (`dict`):
            Must contain `"prompts"` (list of strings with `<img>` tags) and
            `"images"` (list of per-sample image lists).
        pan_and_scan_min_crop_size (`int`): Minimum crop side length.
        pan_and_scan_max_num_crops (`int`): Maximum number of crops per image.
        pan_and_scan_min_ratio_to_activate (`float`): Minimum aspect ratio to crop.

    Returns:
        A dict with `"crops"` (per-sample lists of [image, *crops]) and
        `"modified_prompts"` (prompts with crop `<img>` tags spliced in).

    Raises:
        ValueError: If a prompt's `<img>` tag count does not match its image count.
    """
    crops_and_prompts = {'crops': [], 'modified_prompts': []}
    images = inputs.get("images", None)
    prompts = inputs["prompts"]
    image_tag = "<img>"

    # Assumes at least one image per sample; the channel layout of the first
    # image is used for the whole batch.
    input_data_format = infer_channel_dimension_format(images[0][0])

    # Fix: the original comprehension shadowed both `image` (result list vs.
    # loop variable) and `images` (in the later unpacking comprehensions);
    # distinct names avoid the confusion.
    processed = [
        _process_images_for_pan_and_scan(
            images=sample_images,
            pan_and_scan_min_crop_size=pan_and_scan_min_crop_size,
            pan_and_scan_max_num_crops=pan_and_scan_max_num_crops,
            pan_and_scan_min_ratio_to_activate=pan_and_scan_min_ratio_to_activate,
            input_data_format=input_data_format,
        )
        for sample_images in images
    ]

    images_and_crops_list = [imgs for imgs, _ in processed]
    num_crops = [counts for _, counts in processed]

    # Fix: dropped the unused `enumerate`/`batch_idx` from the original loop.
    for images_and_crops, prompt_text, num_of_crops in zip(images_and_crops_list, prompts, num_crops):

        image_tag_indexes = [m.start() for m in re.finditer(image_tag, prompt_text)]

        if len(images_and_crops) != len(image_tag_indexes):
            raise ValueError(
                f"Prompt contained {len(image_tag_indexes)} image tokens but received {len(images_and_crops)} images."
            )

        # Rewrite right-to-left so earlier tag offsets stay valid after splicing.
        for num, idx in reversed(list(zip(num_of_crops, image_tag_indexes))):
            if num:
                formatted_image_text = (
                    f"Here is the original image {image_tag} and here are some crops to help you see better "
                    + " ".join([image_tag] * num)
                )
                prompt_text = prompt_text[:idx] + formatted_image_text + prompt_text[idx + len(image_tag) :]

        crops_and_prompts['crops'].append(images_and_crops)
        crops_and_prompts['modified_prompts'].append(prompt_text)

    return crops_and_prompts
232+
233+
def to_pil_image(image, rescale=None):
    """
    Convert a numpy array to a `PIL.Image.Image`; pass anything else through.

    Args:
        image: A `np.ndarray` (HWC or CHW with 1 or 3 channels) or any other
            object, which is returned unchanged.
        rescale (`bool`, *optional*):
            Whether to multiply values by 255 before casting to uint8. If
            `None`, floating-point arrays are rescaled (assumed in [0, 1] —
            TODO confirm with callers) and integer arrays are not.
    """
    if isinstance(image, np.ndarray):
        if rescale is None:
            # rescale default to the array being of floating type.
            # Fix: np.issubdtype checks the dtype directly; the previous
            # `image.flat[0]` raised IndexError on empty arrays.
            rescale = np.issubdtype(image.dtype, np.floating)
        # If the channel as been moved to first dim, we put it back at the end.
        if image.ndim == 3 and image.shape[0] in [1, 3]:
            image = image.transpose(1, 2, 0)
        if rescale:
            image = image * 255
        image = image.astype(np.uint8)
        return PIL.Image.fromarray(image)
    return image
247+
248+
def resize(
    image,
    resample: PIL.Image.Resampling = PIL.Image.Resampling.BILINEAR,
    size: Tuple[int, int] = (896, 896),
):
    """
    Resize an image, converting numpy arrays to PIL first.

    Args:
        image: A `PIL.Image.Image` or `np.ndarray`.
        resample (`PIL.Image.Resampling`, *optional*, defaults to `BILINEAR`):
            Resampling filter passed to `PIL.Image.Image.resize`.
        size (`Tuple[int, int]`, *optional*, defaults to `(896, 896)`):
            Target size passed to `PIL.Image.Image.resize`, which interprets it
            as (width, height). The previous hard-coded 896x896 square is kept
            as the default, so existing callers are unaffected.

    Returns:
        The resized `PIL.Image.Image`.
    """
    if not isinstance(image, PIL.Image.Image):
        image = to_pil_image(image)
    return image.resize(size, resample=resample)
255+

0 commit comments

Comments
 (0)