Skip to content

Commit 7905fc8

Browse files
abheesht17 and The tunix Authors
authored and committed
Allow sampler to take in images
Verification: https://colab.research.google.com/gist/abheesht17/e3e31d7ff5bb302928494dcf48b77e5c/tunix-vlm-text-generation.ipynb In order to allow the text sampler to take in images, we only need to take care of the pre-fill phase, because the sampler is text+image-in, text-out. We do two things: - Call the image processor inside the sampler __call__ method so as to process the images. - We add a method to the Gemma 3 model class - `get_positions_and_attention_mask`. If the model has this method, it will be used inside the sampler during the pre-fill phase. It is necessary for vision models to have this method if any custom token processing is needed. PiperOrigin-RevId: 870890814
1 parent a3389dc commit 7905fc8

File tree

11 files changed

+517
-45
lines changed

11 files changed

+517
-45
lines changed

tests/generate/sampler_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,66 @@ def test_samples_padding_output(self, max_prompt_length, echo, return_logits):
111111
result_padded.tokens[i].shape[0], max_generation_steps
112112
)
113113

114+
def test_multimodal_samples(self):
  """Samples from prompts that interleave text with image placeholder tokens.

  Builds a toy multimodal transformer, wires a stub image processor into the
  sampler, and checks that generation with `images=` produces the expected
  deterministic token ids (seeded rngs make the output reproducible).
  """
  vocab = tc.MockVocab(is_multimodal=True)
  transformer = tc.ToyTransformer(
      config=tc.ModelConfig(
          vocab_size=vocab.GetPieceSize(), vision_config=tc.VisionConfig()
      ),
      # Fixed seed so the asserted token ids below are deterministic.
      rngs=nnx.Rngs(42),
  )

  class DummyImageProcessor:
    # Stub standing in for a real image processor; shape/dtype only.

    def __call__(self, images):
      # returns dummy processed images
      return np.ones((len(images), 1, 32, 32, 3), dtype=np.float32)

  image_processor = DummyImageProcessor()

  sampler = sampler_lib.Sampler(
      transformer=transformer,
      tokenizer=vocab,
      cache_config=sampler_lib.CacheConfig(
          cache_size=64,
          num_layers=4,
          num_kv_heads=4,
          head_dim=16,
      ),
      # The sampler invokes this on the raw `images` argument of __call__.
      image_processor=image_processor,
  )

  max_generation_steps = 8

  # We pass in 2 strings and 2 corresponding dummy images
  images = [
      np.zeros((32, 32, 3)),
      np.zeros((32, 32, 3)),
  ]

  result = sampler(
      [
          'quantization <soi> <img> <img> Tunix',
          '<soi> <img> <img> Parallax distributed',
      ],
      max_generation_steps=max_generation_steps,
      return_logits=True,
      max_prompt_length=8,
      # echo=True returns the prompt tokens followed by generated tokens.
      echo=True,
      images=images,
  )

  self.assertIsNotNone(result)
  self.assertReasonableTensor(result.tokens)
  self.assertReasonableTensor(result.logits)
  # Golden token ids: prompt (echoed) + 8 generated steps per sequence.
  np.testing.assert_allclose(
      result.tokens,
      np.array([
          [1, 21, 23, 22, 22, 14, 8, 25, 8, 25, 8, 25, 8, 25],
          [1, 23, 22, 22, 15, 18, 8, 25, 8, 25, 8, 25, 8, 25],
      ]),
  )
114174
@parameterized.named_parameters(
115175
dict(
116176
testcase_name='case1',

tests/models/gemma3/utils_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
import jax.numpy as jnp
1818
import numpy as np
1919
from tunix.models.gemma3 import utils
20-
from tunix.models.gemma3 import vision
20+
21+
22+
_TOKEN_PLACEHOLDER = 219
2123

2224

2325
class UtilsTest(parameterized.TestCase):
@@ -46,12 +48,14 @@ def test_get_positions_and_attention_mask_multimodal(self):
4648
tokens = jnp.array([[
4749
1,
4850
2,
49-
vision.TOKEN_PLACEHOLDER,
50-
vision.TOKEN_PLACEHOLDER,
51+
_TOKEN_PLACEHOLDER,
52+
_TOKEN_PLACEHOLDER,
5153
3,
5254
utils._PADDING_ID,
5355
]])
54-
result = utils.get_positions_and_attention_mask(tokens)
56+
result = utils.get_positions_and_attention_mask(
57+
tokens, token_placeholder_id=_TOKEN_PLACEHOLDER
58+
)
5559
positions = result['positions']
5660
attention_mask = result['attention_mask']
5761

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import dataclasses
2+
import os
3+
import tempfile
4+
from absl.testing import absltest
5+
from absl.testing import parameterized
6+
import numpy as np
7+
from PIL import Image
8+
from tunix.processors import image_processor
9+
10+
11+
@dataclasses.dataclass(slots=True, kw_only=True)
class DummyConfig:
  """Minimal stand-in for a model config consumed by `ImageProcessor`.

  Provides only the image-related fields the processor reads.
  """

  # Target output height, width (pixels) and channel count after resizing.
  image_height: int = 32
  image_width: int = 32
  image_channels: int = 3
  # Per-channel normalization constants; with mean == std == 127.5, uint8
  # pixel 0 maps to -1.0 (the tests below rely on exactly this mapping).
  image_mean: tuple[float, ...] = (127.5, 127.5, 127.5)
  image_std: tuple[float, ...] = (127.5, 127.5, 127.5)
20+
21+
class ImageProcessorTest(parameterized.TestCase):
  """Tests for `image_processor.ImageProcessor` preprocessing and batching."""

  def setUp(self):
    super().setUp()
    self.height = 32
    self.width = 32
    self.channels = 3
    config = DummyConfig(
        image_height=self.height,
        image_width=self.width,
        image_channels=self.channels,
    )
    self.processor = image_processor.ImageProcessor(config)

  def _create_dummy_image_file(self, filename='test_image.png'):
    """Saves an all-black 100x100 RGB PNG to a temp dir and returns its path."""
    img_array = np.zeros((100, 100, 3), dtype=np.uint8)
    img = Image.fromarray(img_array)
    try:
      # absltest's create_tempdir registers cleanup with the test case.
      temp_path = self.create_tempdir().full_path
    except Exception:  # pylint: disable=broad-except
      # Fallback for non-absltest runners. Bug fix: the previous code used
      # `tempfile.TemporaryDirectory().name`, whose finalizer deletes the
      # directory as soon as the object is garbage-collected (immediately
      # under CPython refcounting), so the save below could hit a missing
      # directory. mkdtemp() creates a directory with no auto-cleanup.
      temp_path = tempfile.mkdtemp()
    temp_file = os.path.join(temp_path, filename)
    img.save(temp_file)
    return temp_file

  def test_process_none_image(self):
    # A None image yields an all-zero placeholder of the configured shape.
    processed_image = self.processor.preprocess_image(None)
    self.assertEqual(
        processed_image.shape, (self.height, self.width, self.channels)
    )
    np.testing.assert_array_equal(processed_image, np.zeros((32, 32, 3)))

  def test_path_input(self):
    # Black pixels normalize to -1.0 under mean == std == 127.5.
    img_path = self._create_dummy_image_file()
    processed_image = self.processor.preprocess_image(img_path)
    self.assertEqual(
        processed_image.shape, (self.height, self.width, self.channels)
    )
    np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3)))

  def test_array_input(self):
    img_array = np.zeros((100, 100, 3), dtype=np.uint8)
    processed_image = self.processor.preprocess_image(img_array)
    self.assertEqual(
        processed_image.shape, (self.height, self.width, self.channels)
    )
    np.testing.assert_allclose(processed_image, -1.0 * np.ones((32, 32, 3)))

  @parameterized.named_parameters(
      dict(testcase_name='array', input_type='array'),
      dict(testcase_name='path', input_type='path'),
  )
  def test_call_one_image(self, input_type):
    if input_type == 'array':
      images = [np.zeros((100, 100, 3), dtype=np.uint8)]
    elif input_type == 'path':
      images = [self._create_dummy_image_file()]

    processed_images = self.processor(images=images)  # pylint: disable=undefined-variable
    self.assertLen(processed_images, 1)
    self.assertLen(processed_images[0], 1)
    self.assertEqual(
        processed_images[0][0].shape, (self.height, self.width, self.channels)  # pytype: disable=attribute-error
    )
    np.testing.assert_allclose(
        processed_images[0][0], -1.0 * np.ones((32, 32, 3))
    )

  def test_padding(self):
    # Ragged batches are padded so every row has the same number of images.
    img1 = np.zeros((100, 100, 3), dtype=np.uint8)
    img2 = np.zeros((50, 50, 3), dtype=np.uint8)
    images = [[img1], [img1, img2]]
    processed_images = self.processor(images=images)
    self.assertLen(processed_images, 2)
    self.assertLen(processed_images[0], 2)  # Padded to 2
    self.assertLen(processed_images[1], 2)
    np.testing.assert_allclose(
        processed_images[0][0], -1.0 * np.ones((32, 32, 3))
    )
    # Padded image should be zeros
    np.testing.assert_allclose(processed_images[0][1], np.zeros((32, 32, 3)))
    np.testing.assert_allclose(
        processed_images[1][0], -1.0 * np.ones((32, 32, 3))
    )
    np.testing.assert_allclose(
        processed_images[1][1], -1.0 * np.ones((32, 32, 3))
    )

  def test_call_with_none_in_batch(self):
    # None entries in a batch become all-zero placeholder images.
    images = [None, [np.zeros((100, 100, 3), dtype=np.uint8)]]
    processed_images = self.processor(images=images)
    self.assertLen(processed_images, 2)
    self.assertLen(processed_images[0], 1)
    self.assertLen(processed_images[1], 1)
    np.testing.assert_allclose(processed_images[0][0], np.zeros((32, 32, 3)))
    np.testing.assert_allclose(
        processed_images[1][0], -1.0 * np.ones((32, 32, 3))
    )
119+
120+
121+
if __name__ == '__main__':
  # Run the test suite via absltest when this file is executed directly.
  absltest.main()

0 commit comments

Comments
 (0)