
Commit 1f11e3f

Version bump to 0.18.0 and cherry pick (#2002)
* Adding PaliGemma2 to KerasHub (#1998)
  * Add PaliGemma2 (#96)
  * Add PaliGemma2 arch
  * Enable mixed precision check for PaliGemma
  * Add conversion script
  * Revert ImageConverter and reduce mem usage in the conversion script
  * Remove `compute_output_spec`
  * Fix `compute_output_shape` issue for keras 3.1
  * Add model cards and update conversion script
  * update presets
  * Update pali_gemma_presets.py - remove mix presets
  * Update pali_gemma_presets.py
  * Update convert_pali_gemma2_checkpoints.py
* Version bump to 0.18.0
* Update pali_gemma_presets.py (#2003)
  * Update pali_gemma_presets.py
  * code reformat

Co-authored-by: divyashreepathihalli <[email protected]>
Co-authored-by: james77777778 <[email protected]>
1 parent bdb7478 commit 1f11e3f

File tree

7 files changed: +851 -47 lines changed

Diff for: keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

+61 -11

@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
         vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
             transformer encoder.
         vit_hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         vit_num_layers: int. The number of vision transformer layers.
         vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-        vit_pooling: string. The encoded vision embeddings are pooled using the
-            specified polling setting. The accepted values are `"map"`, `"gap"`,
-            `"0"` or `"none"`. Defaults to `"none"`.
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified polling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
+            Defaults to `None`.
         vit_name: string. The name used for vision transformer layers.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the attention
+            block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always

@@ -119,6 +137,13 @@ def __init__(
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,

@@ -136,6 +161,7 @@ def __init__(
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
            name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.

@@ -155,12 +181,19 @@ def __init__(
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
            layer = PaliGemmaDecoderBlock(
                hidden_dim=hidden_dim,
                intermediate_dim=intermediate_dim,
-               num_query_heads=num_query_heads,
                head_dim=head_dim,
+               num_query_heads=num_query_heads,
                num_key_value_heads=num_key_value_heads,
+               query_head_dim_normalize=query_head_dim_normalize,
+               use_post_ffw_norm=use_post_ffw_norm,
+               use_post_attention_norm=use_post_attention_norm,
+               logit_soft_cap=attention_logit_soft_cap,
+               use_sliding_window_attention=sliding_window,
+               sliding_window_size=sliding_window_size,
                dropout=dropout,
                dtype=dtype,
                name=f"decoder_block_{i}",

@@ -173,7 +206,9 @@ def __init__(
         )

         # === Functional Model ===
-        image_input = self.vit_encoder.inputs[0]
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )

@@ -219,7 +254,15 @@ def __init__(
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        # VIT Params
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim

@@ -243,8 +286,6 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,

@@ -253,6 +294,15 @@ def get_config(self):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
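The construction below is a minimal usage sketch, not part of the commit: it mirrors the toy configuration exercised by the new test in the next file and assumes the public `keras_hub.models.PaliGemmaBackbone` export together with the constructor signature shown in the diff above. The tiny dimensions are purely illustrative.

    # Sketch: instantiating the backbone with the Gemma2-style options added above.
    import numpy as np
    from keras_hub.models import PaliGemmaBackbone  # assumed public export path

    backbone = PaliGemmaBackbone(
        vocabulary_size=256,
        image_size=16,
        num_layers=2,
        num_query_heads=2,
        num_key_value_heads=1,
        hidden_dim=8,
        intermediate_dim=16,
        head_dim=4,
        vit_patch_size=4,
        vit_num_layers=2,
        vit_num_heads=2,
        vit_hidden_dim=8,
        vit_intermediate_dim=16,
        # Gemma2-style options introduced in this commit.
        query_head_dim_normalize=True,
        use_post_ffw_norm=True,
        use_post_attention_norm=True,
        attention_logit_soft_cap=50,
        final_logit_soft_cap=30,
        use_sliding_window_attention=True,
        sliding_window_size=4096,
    )

    # Same input dict layout as the new backbone test below.
    outputs = backbone(
        {
            "images": np.random.rand(1, 16, 16, 3),
            "token_ids": np.zeros((1, 8), dtype="int32"),
            "padding_mask": np.ones((1, 8), dtype="int32"),
            "response_mask": np.zeros((1, 8), dtype="int32"),
        }
    )
    print(outputs.shape)  # (1, 8 text tokens + (16 // 4) ** 2 image patches, 8) = (1, 24, 8)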

Diff for: keras_hub/src/models/pali_gemma/pali_gemma_backbone_test.py

+72 -1

@@ -61,7 +61,6 @@ def test_backbone_basics(self):
                 8,
             ),
             variable_length_data=[self.input_data],
-            run_mixed_precision_check=False,  # TODO: Set to `True`
         )

     @pytest.mark.large

@@ -98,3 +97,75 @@ def test_all_presets(self):
                preset=preset,
                input_data=self.input_data,
            )
+
+
+class PaliGemma2BackboneTest(TestCase):
+    def setUp(self):
+        self.batch_size = 2
+        self.vocabulary_size = 256
+        self.text_sequence_length = 64
+        self.image_size = 16
+        self.image_sequence_length = int((self.image_size / 4) ** 2)
+        self.init_kwargs = {
+            "vocabulary_size": self.vocabulary_size,
+            "image_size": self.image_size,
+            "num_layers": 2,
+            "num_query_heads": 2,
+            "num_key_value_heads": 1,
+            "hidden_dim": 8,
+            "intermediate_dim": 16,
+            "head_dim": 4,
+            "vit_patch_size": 4,
+            "vit_num_layers": 2,
+            "vit_num_heads": 2,
+            "vit_hidden_dim": 8,
+            "vit_intermediate_dim": 16,
+            # Gemma2
+            "query_head_dim_normalize": True,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "final_logit_soft_cap": 30,
+            "attention_logit_soft_cap": 50,
+            "use_sliding_window_attention": True,
+            "sliding_window_size": 4096,
+        }
+
+        dummy_images = np.random.rand(
+            self.batch_size, self.image_size, self.image_size, 3
+        )
+        dummy_text_token_ids = np.random.rand(
+            self.batch_size, self.text_sequence_length
+        )
+        self.input_data = {
+            "token_ids": dummy_text_token_ids,
+            "images": dummy_images,
+            "padding_mask": np.ones(
+                (self.batch_size, self.text_sequence_length),
+                dtype="int32",
+            ),
+            "response_mask": np.zeros(
+                (self.batch_size, self.text_sequence_length),
+                dtype="int32",
+            ),
+        }
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=PaliGemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(
+                self.batch_size,
+                self.text_sequence_length + self.image_sequence_length,
+                8,
+            ),
+            variable_length_data=[self.input_data],
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=PaliGemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
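A short arithmetic sketch of the `expected_output_shape` used in the new test: the backbone's output sequence is the image patches followed by the text tokens, so the sequence length is the text length plus `(image_size / vit_patch_size) ** 2`. The variable names below are illustrative and are not part of the test file.

    # Shape arithmetic behind expected_output_shape in PaliGemma2BackboneTest.
    image_size = 16
    vit_patch_size = 4
    text_sequence_length = 64
    batch_size = 2

    image_sequence_length = (image_size // vit_patch_size) ** 2  # 16 patches
    total_sequence_length = text_sequence_length + image_sequence_length  # 80

    print((batch_size, total_sequence_length, 8))  # (2, 80, 8), matching the test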

Diff for: keras_hub/src/models/pali_gemma/pali_gemma_decoder_block.py

+21 -23

@@ -31,33 +31,25 @@ class PaliGemmaDecoderBlock(GemmaDecoderBlock):
             the attention layer.
         num_key_value_heads: int. The number of heads for the key and value
             projections in the attention layer.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the
+            attention block. Defaults to `False`.
+        logit_soft_cap: `None` or int. Soft cap for the attention logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon hyperparameter used for layer
-            normalization.
+            normalization. Defaults to `1e-6`.
         dropout: float. The dropout rate for the transformer attention layer.
+            Defaults to `0`.
     """

-    def __init__(
-        self,
-        hidden_dim,
-        intermediate_dim,
-        head_dim,
-        num_query_heads,
-        num_key_value_heads,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        **kwargs,
-    ):
-        super().__init__(
-            hidden_dim=hidden_dim,
-            intermediate_dim=intermediate_dim,
-            head_dim=head_dim,
-            num_query_heads=num_query_heads,
-            num_key_value_heads=num_key_value_heads,
-            layer_norm_epsilon=layer_norm_epsilon,
-            dropout=dropout,
-            **kwargs,
-        )
-
     def call(
         self,
         x,

@@ -83,6 +75,9 @@ def call(
             attention_mask=attention_mask,
         )

+        if self.use_post_attention_norm:
+            attention = self.post_attention_norm(attention)
+
         if self.dropout:
             attention = self.attention_dropout(attention)

@@ -94,6 +89,9 @@ def call(
         x = keras.activations.gelu(x1, approximate=True) * x2
         x = self.ffw_linear(x)

+        if self.use_post_ffw_norm:
+            x = self.post_ffw_norm(x)
+
         x = x + attention_x

         if cache is not None: