Skip to content

Commit 992ba69

Browse files
abheesht17 and The tunix Authors
authored and committed
Add image processor
This PR adds a basic image processor, which takes in a batch of images and processes them (resizing, normalising, etc.).

Verification: https://colab.research.google.com/gist/abheesht17/3ca408a919bbda9d4400f6c30f193dcd/-tunix-vlm-image-processor-verification.ipynb

PiperOrigin-RevId: 867024240
1 parent b39d48b commit 992ba69

File tree

8 files changed

+311
-61
lines changed

8 files changed

+311
-61
lines changed

tests/models/gemma3/utils_test.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,19 @@
1717
import jax.numpy as jnp
1818
import numpy as np
1919
from tunix.models.gemma3 import utils
20-
from tunix.models.gemma3 import vision
20+
21+
22+
_TOKEN_PLACEHOLDER = 219
2123

2224

2325
class UtilsTest(parameterized.TestCase):
2426

25-
def test_get_positions_and_attention_mask_not_multimodal(self):
27+
def test_get_attention_mask_not_multimodal(self):
2628
tokens = jnp.array([[1, 2, 3, utils._PADDING_ID, utils._PADDING_ID]])
27-
result = utils.get_positions_and_attention_mask(tokens)
28-
positions = result['positions']
29-
attention_mask = result['attention_mask']
29+
attention_mask = utils.get_attention_mask(
30+
tokens, token_placeholder_id=_TOKEN_PLACEHOLDER
31+
)
3032

31-
expected_positions = jnp.array([[0, 1, 2, 2, 2]])
3233
expected_attention_mask = jnp.array(
3334
[[
3435
[1, 0, 0, 0, 0],
@@ -39,23 +40,21 @@ def test_get_positions_and_attention_mask_not_multimodal(self):
3940
]],
4041
dtype=jnp.bool_,
4142
)
42-
np.testing.assert_array_equal(positions, expected_positions)
4343
np.testing.assert_array_equal(attention_mask, expected_attention_mask)
4444

45-
def test_get_positions_and_attention_mask_multimodal(self):
45+
def test_get_attention_mask_multimodal(self):
4646
tokens = jnp.array([[
4747
1,
4848
2,
49-
vision.TOKEN_PLACEHOLDER,
50-
vision.TOKEN_PLACEHOLDER,
49+
_TOKEN_PLACEHOLDER,
50+
_TOKEN_PLACEHOLDER,
5151
3,
5252
utils._PADDING_ID,
5353
]])
54-
result = utils.get_positions_and_attention_mask(tokens)
55-
positions = result['positions']
56-
attention_mask = result['attention_mask']
54+
attention_mask = utils.get_attention_mask(
55+
tokens, token_placeholder_id=_TOKEN_PLACEHOLDER
56+
)
5757

58-
expected_positions = jnp.array([[0, 1, 2, 3, 4, 4]])
5958
expected_attention_mask = jnp.array(
6059
[[
6160
[1, 0, 0, 0, 0, 0],
@@ -67,19 +66,15 @@ def test_get_positions_and_attention_mask_multimodal(self):
6766
]],
6867
dtype=jnp.bool_,
6968
)
70-
np.testing.assert_array_equal(positions, expected_positions)
7169
np.testing.assert_array_equal(attention_mask, expected_attention_mask)
7270

73-
def test_get_positions_and_attention_mask_precomputed_mask(self):
71+
def test_get_attention_mask_precomputed_mask(self):
7472
tokens = jnp.array([[1, 2, 3, utils._PADDING_ID, utils._PADDING_ID]])
7573
inputs_mask = jnp.array([[1, 0, 1, 0, 0]])
76-
result = utils.get_positions_and_attention_mask(
77-
tokens, inputs_mask=inputs_mask
74+
attention_mask = utils.get_attention_mask(
75+
tokens, inputs_mask=inputs_mask, token_placeholder_id=_TOKEN_PLACEHOLDER
7876
)
79-
positions = result['positions']
80-
attention_mask = result['attention_mask']
8177

82-
expected_positions = jnp.array([[0, 0, 1, 1, 1]])
8378
expected_attention_mask = jnp.array(
8479
[[
8580
[1, 0, 0, 0, 0],
@@ -90,7 +85,6 @@ def test_get_positions_and_attention_mask_precomputed_mask(self):
9085
]],
9186
dtype=jnp.bool_,
9287
)
93-
np.testing.assert_array_equal(positions, expected_positions)
9488
np.testing.assert_array_equal(attention_mask, expected_attention_mask)
9589

9690

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import dataclasses
2+
import os
3+
import shutil
4+
import tempfile
5+
from absl.testing import absltest
6+
from absl.testing import parameterized
7+
import numpy as np
8+
from PIL import Image
9+
from tunix.processors import image_processor
10+
11+
12+
@dataclasses.dataclass(slots=True, kw_only=True)
13+
class DummyConfig:
14+
15+
image_height: int = 32
16+
image_width: int = 32
17+
image_channels: int = 3
18+
image_mean: tuple[float, ...] = (127.5, 127.5, 127.5)
19+
image_std: tuple[float, ...] = (127.5, 127.5, 127.5)
20+
21+
22+
class ImageProcessorTest(parameterized.TestCase):
23+
24+
def setUp(self):
25+
super().setUp()
26+
self.height = 32
27+
self.width = 32
28+
self.channels = 3
29+
config = DummyConfig(
30+
image_height=self.height,
31+
image_width=self.width,
32+
image_channels=self.channels,
33+
)
34+
self.processor = image_processor.ImageProcessor(config)
35+
36+
def _create_dummy_image_file(self, filename='test_image.png'):
37+
img_array = np.zeros((100, 100, 3), dtype=np.uint8)
38+
img = Image.fromarray(img_array)
39+
40+
temp_dir = tempfile.mkdtemp()
41+
self.addCleanup(lambda: shutil.rmtree(temp_dir))
42+
43+
temp_file = os.path.join(temp_dir, filename)
44+
img.save(temp_file)
45+
return temp_file
46+
47+
def test_process_none_image(self):
48+
processed_image = self.processor.preprocess_image(None)
49+
self.assertEqual(
50+
processed_image.shape, (self.height, self.width, self.channels)
51+
)
52+
np.testing.assert_array_equal(processed_image, np.zeros((32, 32, 3)))
53+
54+
def test_path_input(self):
55+
img_path = self._create_dummy_image_file()
56+
processed_image = self.processor.preprocess_image(img_path)
57+
self.assertEqual(
58+
processed_image.shape, (self.height, self.width, self.channels)
59+
)
60+
np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3)))
61+
62+
def test_array_input(self):
63+
img_array = np.zeros((100, 100, 3), dtype=np.uint8)
64+
processed_image = self.processor.preprocess_image(img_array)
65+
self.assertEqual(
66+
processed_image.shape, (self.height, self.width, self.channels)
67+
)
68+
np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3)))
69+
70+
@parameterized.named_parameters(
71+
dict(testcase_name='array', input_type='array'),
72+
dict(testcase_name='path', input_type='path'),
73+
)
74+
def test_call_one_image(self, input_type):
75+
if input_type == 'array':
76+
images = [np.zeros((100, 100, 3), dtype=np.uint8)]
77+
elif input_type == 'path':
78+
images = [self._create_dummy_image_file()]
79+
80+
processed_images = self.processor(images=images) # pylint: disable=undefined-variable
81+
self.assertLen(processed_images, 1)
82+
self.assertLen(processed_images[0], 1)
83+
self.assertEqual(
84+
processed_images[0][0].shape, (self.height, self.width, self.channels) # pytype: disable=attribute-error
85+
)
86+
np.testing.assert_allclose(
87+
processed_images[0][0], -1.0 * np.ones((32, 32, 3))
88+
)
89+
90+
def test_padding(self):
91+
img1 = np.zeros((100, 100, 3), dtype=np.uint8)
92+
img2 = np.zeros((50, 50, 3), dtype=np.uint8)
93+
images = [[img1], [img1, img2]]
94+
processed_images = self.processor(images=images)
95+
self.assertLen(processed_images, 2)
96+
self.assertLen(processed_images[0], 2) # Padded to 2
97+
self.assertLen(processed_images[1], 2)
98+
np.testing.assert_allclose(
99+
processed_images[0][0], -1.0 * np.ones((32, 32, 3))
100+
)
101+
# Padded image should be zeros
102+
np.testing.assert_allclose(processed_images[0][1], np.zeros((32, 32, 3)))
103+
np.testing.assert_allclose(
104+
processed_images[1][0], -1.0 * np.ones((32, 32, 3))
105+
)
106+
np.testing.assert_allclose(
107+
processed_images[1][1], -1.0 * np.ones((32, 32, 3))
108+
)
109+
110+
def test_call_with_none_in_batch(self):
111+
images = [None, [np.zeros((100, 100, 3), dtype=np.uint8)]]
112+
processed_images = self.processor(images=images)
113+
self.assertLen(processed_images, 2)
114+
self.assertLen(processed_images[0], 1)
115+
self.assertLen(processed_images[1], 1)
116+
np.testing.assert_allclose(processed_images[0][0], np.zeros((32, 32, 3)))
117+
np.testing.assert_allclose(
118+
processed_images[1][0], -1.0 * np.ones((32, 32, 3))
119+
)
120+
121+
122+
if __name__ == '__main__':
123+
absltest.main()

tunix/models/gemma3/model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ModelConfig:
115115
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
116116
)
117117

118-
siglip_config: vision.SigLIPConfig | None = None
118+
vision_config: vision.SigLIPConfig | None = None
119119

120120
shd_config: ShardingConfig = ShardingConfig.get_default_sharding()
121121
remat_config: RematConfig = RematConfig.NONE
@@ -203,7 +203,7 @@ def _gemma3_4b(
203203
local_base_frequency=10_000,
204204
global_base_frequency=1_000_000,
205205
global_scale_factor=8.0,
206-
siglip_config=None if text_only else vision.SigLIPConfig(),
206+
vision_config=None if text_only else vision.SigLIPConfig(),
207207
shd_config=sharding_config,
208208
)
209209

@@ -245,7 +245,7 @@ def _gemma3_12b(
245245
local_base_frequency=10_000,
246246
global_base_frequency=1_000_000,
247247
global_scale_factor=8.0,
248-
siglip_config=None if text_only else vision.SigLIPConfig(),
248+
vision_config=None if text_only else vision.SigLIPConfig(),
249249
shd_config=sharding_config,
250250
)
251251

@@ -287,7 +287,7 @@ def _gemma3_27b(
287287
local_base_frequency=10_000,
288288
global_base_frequency=1_000_000,
289289
global_scale_factor=8.0,
290-
siglip_config=None if text_only else vision.SigLIPConfig(),
290+
vision_config=None if text_only else vision.SigLIPConfig(),
291291
shd_config=sharding_config,
292292
)
293293

@@ -911,9 +911,9 @@ class Gemma3(nnx.Module):
911911
def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs):
912912
self.config = config
913913

914-
if config.siglip_config is not None:
914+
if config.vision_config is not None:
915915
self.vision_encoder = vision.SigLiP(
916-
config=config.siglip_config,
916+
config=config.vision_config,
917917
shd_config=config.shd_config.siglip,
918918
rngs=rngs,
919919
)
@@ -1009,7 +1009,7 @@ def _encode_and_get_inputs(
10091009
images: jaxtyping.Array | None = None, # (B, H, W, C) or (B, N, H, W, C)
10101010
) -> jaxtyping.Array:
10111011
"""Encode the text tokens, eventually including the vision embeddings."""
1012-
if images is not None:
1012+
if self.config.vision_config is not None and images is not None:
10131013
self._assert_support_mm()
10141014
if len(images.shape) == 4: # If num_images is 1, add an axis.
10151015
images = einops.rearrange(images, 'b h w c -> b 1 h w c')
@@ -1048,7 +1048,7 @@ def _merge_mm_embeddings(
10481048
merged_embeddings = merge_embeddings_lib.merge_embeddings(
10491049
text_embeddings=embeddings,
10501050
vision_embeddings=soft_embeddings,
1051-
mask=tokens == vision.TOKEN_PLACEHOLDER,
1051+
mask=tokens == self.config.vision_config.soft_token_placeholder_id, # pytype: disable=attribute-error
10521052
)
10531053

10541054
return merged_embeddings

tunix/models/gemma3/params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def create_model_from_checkpoint(
9191
)
9292
params = ocp.StandardCheckpointer().restore(checkpoint_path)
9393
params = map_from_upstream_checkpoint(
94-
params, text_only=model_config.siglip_config is None
94+
params, text_only=model_config.vision_config is None
9595
)
9696

9797
if mesh is not None:

tunix/models/gemma3/params_safetensors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _get_key_and_transform_mapping(cfg: model_lib.ModelConfig):
113113
}
114114

115115
# Vision Tower (SigLIP).
116-
if cfg.siglip_config is not None:
116+
if cfg.vision_config is not None:
117117
mapping.update({
118118
r"vision_tower\.vision_model\.embeddings\.patch_embedding\.weight": (
119119
"vision_encoder.siglip_encoder.embedding.kernel",

tunix/models/gemma3/utils.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,29 @@
1818

1919
import jax.numpy as jnp
2020
import jaxtyping
21-
from tunix.models.gemma3 import vision
2221

2322
_PADDING_ID = 0
2423

2524

26-
def get_positions_and_attention_mask(
25+
def get_attention_mask(
2726
tokens: jaxtyping.ArrayLike, # (B, L)
2827
*,
2928
inputs_mask: jaxtyping.ArrayLike | None = None, # (B, L, L')
29+
token_placeholder_id: int = 219,
3030
):
31-
"""Returns the positions and attention mask for the transformer."""
31+
"""Returns the attention mask for the transformer."""
3232
# Compute the mask
3333
if inputs_mask is None:
3434
inputs_mask = tokens != _PADDING_ID
35-
positions = _build_positions_from_mask(inputs_mask)
3635

3736
# The image tokens have bidirectional attention within themselves.
38-
bidirectional_mask = tokens == vision.TOKEN_PLACEHOLDER
37+
bidirectional_mask = tokens == token_placeholder_id
3938
attention_mask = make_causal_bidirectional_attention_mask(
4039
inputs_mask,
4140
bidirectional_mask=bidirectional_mask,
4241
)
4342

44-
return {
45-
'positions': positions,
46-
'attention_mask': attention_mask,
47-
}
43+
return attention_mask
4844

4945

5046
def make_causal_bidirectional_attention_mask(
@@ -153,21 +149,3 @@ def _add_bidirectional_mask(
153149
& (q_block_indices[..., None] > 0)
154150
)
155151
return attn_mask
156-
157-
158-
def _build_positions_from_mask(
159-
input_mask: jaxtyping.ArrayLike,
160-
) -> jaxtyping.ArrayLike:
161-
"""Computes the `positions` from the `input_mask`.
162-
163-
Args:
164-
input_mask: The tokens `input_mask`, True for non-padded tokens only.
165-
166-
Returns:
167-
The indices to use for RoPE and absolute position encodings for the given
168-
input mask.
169-
"""
170-
positions = jnp.cumsum(input_mask, axis=-1)
171-
# Subtract one for all positions from the first valid one as they are
172-
# 0-indexed
173-
return positions - (positions >= 1)

tunix/models/gemma3/vision.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@
2929
from tunix.utils import compat
3030
from tunix.utils import sharding_utils
3131

32-
TOKEN_PLACEHOLDER = 262144
33-
3432

3533
@dataclasses.dataclass(slots=True, frozen=True)
3634
class SigLIPShardingConfig:
@@ -84,9 +82,14 @@ class SigLIPConfig:
8482

8583
num_mm_tokens_per_image_prepool: int = 4096
8684
num_mm_tokens_per_image: int = 256
85+
86+
# Processor args
8787
image_height: int = 896
8888
image_width: int = 896
8989
image_channels: int = 3
90+
image_mean: tuple[float, ...] = (127.5, 127.5, 127.5)
91+
image_std: tuple[float, ...] = (127.5, 127.5, 127.5)
92+
soft_token_placeholder: int = 219
9093

9194
patch_size: tuple[int, int] = (14, 14)
9295
width: int = 1152

0 commit comments

Comments (0)