@@ -48,22 +48,40 @@ class PaliGemmaBackbone(Backbone):
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
         vit_patch_size: int. The size of each square patch in the input image.
-        vit_num_heads: int. The number of attention heads for the vision(image)
+        vit_num_heads: int. The number of attention heads for the vision (image)
             transformer encoder.
         vit_hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         vit_num_layers: int. The number of vision transformer layers.
         vit_intermediate_dim: int. The output dimension of the first Dense layer
-            in a two-layer feedforward network for vision transformer.
-        vit_pooling: string. The encoded vision embeddings are pooled using the
-            specified polling setting. The accepted values are `"map"`, `"gap"`,
-            `"0"` or `"none"`. Defaults to `"none"`.
+            in a two-layer feedforward network for vision transformer. Defaults
+            to `4304`.
+        vit_pooling: `None` or string. The encoded vision embeddings are pooled
+            using the specified pooling setting. The accepted values are
+            `"map"`, `"gap"`, `"0"` or `None`. Defaults to `None`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
+            Defaults to `None`.
         vit_name: string. The name used for vision transformer layers.
+        query_head_dim_normalize: boolean. If `True` normalize the query before
+            attention with `head_dim`. If `False`, normalize the query with
+            `hidden_dim / num_query_heads`. Defaults to `True`.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to `False`.
+        use_post_attention_norm: boolean. Whether to normalize after the attention
+            block. Defaults to `False`.
+        attention_logit_soft_cap: `None` or int. Soft cap for the attention
+            logits. Defaults to `None`.
+        final_logit_soft_cap: `None` or int. Soft cap for the final logits.
+            Defaults to `None`.
+        use_sliding_window_attention: boolean. Whether to use sliding local
+            window attention. Defaults to `False`.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            `4096`.
         layer_norm_epsilon: float. The epsilon value used for every layer norm
-            in all transformer blocks.
+            in all transformer blocks. Defaults to `1e-6`.
         dropout: float. Dropout probability for the Transformer decoder blocks.
+            Defaults to `0`.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the model's computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -119,6 +137,13 @@ def __init__(
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
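Editor's note: for reference, a minimal construction sketch showing where the new Gemma2-style arguments slot into the existing signature. The hyperparameter values are illustrative toy values, not a released preset, and the `keras_hub.models` export path (plus the `vocabulary_size` argument, which is outside this hunk) is an assumption about the surrounding package.

```python
import keras_hub

# Illustrative, deliberately tiny configuration -- not a released preset.
backbone = keras_hub.models.PaliGemmaBackbone(
    vocabulary_size=256,   # assumed required arg (not shown in this diff), toy value
    image_size=224,
    num_layers=2,
    num_query_heads=2,
    num_key_value_heads=1,
    hidden_dim=8,
    intermediate_dim=16,
    head_dim=4,
    vit_patch_size=14,
    vit_num_heads=2,
    vit_hidden_dim=8,
    vit_num_layers=2,
    # New Gemma2-style options introduced by this change.
    query_head_dim_normalize=True,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    use_sliding_window_attention=True,
    sliding_window_size=4096,
)
```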
@@ -136,6 +161,7 @@ def __init__(
                 seed=None,
             ),
             dtype=dtype,
+            logit_soft_cap=final_logit_soft_cap,
             name="token_embedding",
         )
         # TODO Remove this. Work around for previous serialization bug.
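Editor's note: the `logit_soft_cap` argument enables Gemma-2-style tanh soft-capping of the final logits. As a hedged, standalone illustration of what that capping computes (the actual capping is implemented inside the token embedding layer this argument is passed to, not in this file):

```python
import numpy as np

def soft_cap(logits, cap):
    # Squash logits smoothly into (-cap, cap) rather than clipping hard.
    return cap * np.tanh(logits / cap)

x = np.array([-100.0, -10.0, 0.0, 10.0, 100.0])
print(soft_cap(x, 30.0))  # large-magnitude values saturate near +/-30
```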
@@ -155,12 +181,19 @@ def __init__(
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = PaliGemmaDecoderBlock(
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
-                num_query_heads=num_query_heads,
                 head_dim=head_dim,
+                num_query_heads=num_query_heads,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -173,7 +206,9 @@ def __init__(
         )

         # === Functional Model ===
-        image_input = self.vit_encoder.inputs[0]
+        image_input = keras.Input(
+            shape=(image_size, image_size, 3), name="images"
+        )
         token_id_input = keras.Input(
             shape=(None,), dtype="int32", name="token_ids"
         )
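Editor's note: with the image input now built explicitly, the functional model exposes a named `"images"` input of shape `(image_size, image_size, 3)` alongside `"token_ids"`. A hedged call sketch, reusing the toy backbone from the construction example above and assuming the backbone keeps its usual `"padding_mask"` input (not shown in this hunk):

```python
import numpy as np

batch, seq_len, image_size = 1, 8, 224
features = backbone(
    {
        "images": np.zeros((batch, image_size, image_size, 3), "float32"),
        "token_ids": np.ones((batch, seq_len), "int32"),
        # Assumed: the pre-existing padding_mask input, unchanged by this diff.
        "padding_mask": np.ones((batch, seq_len), "int32"),
    }
)
```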
@@ -219,7 +254,15 @@ def __init__(
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
-        # VIT Params
+        # Gemma2 params
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention
+        # ViT params
         self.vit_patch_size = vit_patch_size
         self.vit_num_heads = vit_num_heads
         self.vit_hidden_dim = vit_hidden_dim
@@ -243,8 +286,6 @@ def get_config(self):
                 "hidden_dim": self.hidden_dim,
                 "intermediate_dim": self.intermediate_dim,
                 "head_dim": self.head_dim,
-                "layer_norm_epsilon": self.layer_norm_epsilon,
-                "dropout": self.dropout,
                 "vit_patch_size": self.vit_patch_size,
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
@@ -253,6 +294,15 @@ def get_config(self):
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": self.use_sliding_window_attention,
+                "layer_norm_epsilon": self.layer_norm_epsilon,
+                "dropout": self.dropout,
             }
         )
         return config
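Editor's note: because the new fields are included in `get_config()`, a backbone configured with the Gemma2 options should survive a config round trip. A quick hedged check, again reusing the toy backbone from the construction sketch:

```python
config = backbone.get_config()
assert config["use_sliding_window_attention"] is True
assert config["final_logit_soft_cap"] == 30

# Standard Keras round trip: rebuild an equivalent (unweighted) backbone.
restored = keras_hub.models.PaliGemmaBackbone.from_config(config)
```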