@@ -84,8 +84,21 @@ class FlaxT5DenseReluDense(nn.Module):
    dtype: jnp.dtype = jnp.float32

    def setup(self):
-        self.wi = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(self, hidden_states, deterministic=True):
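For context, a minimal standalone sketch (not part of the diff) of what the scaled-normal kernel initializer above produces; `d_model=512`, `d_ff=2048` and `initializer_factor=1.0` are assumed example values matching the t5-small defaults.

```python
# Standalone sketch of the kernel initialization used in FlaxT5DenseReluDense.setup
# (illustrative values only, not taken from the diff).
import jax
import jax.numpy as jnp
import flax.linen as nn

d_model, d_ff, initializer_factor = 512, 2048, 1.0    # assumed example config values
wi_init_std = initializer_factor * (d_model ** -0.5)  # ~0.0442

wi = nn.Dense(d_ff, use_bias=False, kernel_init=jax.nn.initializers.normal(wi_init_std))
params = wi.init(jax.random.PRNGKey(0), jnp.ones((1, d_model)))
print(params["params"]["kernel"].shape)         # (512, 2048)
print(float(params["params"]["kernel"].std()))  # close to wi_init_std
```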
@@ -101,9 +114,27 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
-        self.wi_0 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wi_1 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi_0 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wi_1 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
        self.dropout = nn.Dropout(self.config.dropout_rate)
        self.gelu_act = ACT2FN["gelu_new"]

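The gated variant above feeds the same input through two parallel projections and multiplies them before the output projection. A rough standalone sketch of that wiring, with `jax.nn.gelu` standing in for `ACT2FN["gelu_new"]` and dropout omitted:

```python
# Rough sketch of the gated-GELU feed-forward wiring (not the module from the diff).
import jax
import jax.numpy as jnp
import flax.linen as nn


class GatedGeluFF(nn.Module):
    d_model: int = 512   # assumed example sizes
    d_ff: int = 2048

    @nn.compact
    def __call__(self, hidden_states):
        hidden_gelu = jax.nn.gelu(nn.Dense(self.d_ff, use_bias=False)(hidden_states))
        hidden_linear = nn.Dense(self.d_ff, use_bias=False)(hidden_states)
        hidden_states = hidden_gelu * hidden_linear  # the gate
        return nn.Dense(self.d_model, use_bias=False)(hidden_states)


x = jnp.ones((1, 4, 512))
params = GatedGeluFF().init(jax.random.PRNGKey(0), x)
y = GatedGeluFF().apply(params, x)  # shape (1, 4, 512)
```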
@@ -154,14 +185,40 @@ def setup(self):
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

-        self.q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.k = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.v = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.o = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype)
+        inner_dim_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        d_model_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+
+        self.q = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.k = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.v = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.o = nn.Dense(
+            self.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(inner_dim_init_std, self.dtype),
+            dtype=self.dtype,
+        )

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
-                self.relative_attention_num_buckets, self.n_heads, dtype=self.dtype
+                self.relative_attention_num_buckets,
+                self.n_heads,
+                embedding_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+                dtype=self.dtype,
            )

    @staticmethod
@@ -246,7 +303,8 @@ def _concatenate_to_cache(self, key, value, query, attention_mask):
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
-            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions
+            # that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
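A small numeric illustration of the `pad_mask` built above, with assumed sizes (8 cache slots, cache index 3, one new query position):

```python
# Numeric illustration of the cached-decoder pad mask (assumed sizes, not from the diff).
import jax.numpy as jnp

max_length, cur_index, num_updated_cache_vectors = 8, 3, 1
batch_dims = (2,)  # e.g. batch size 2

pad_mask = jnp.broadcast_to(
    jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
    tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
)
print(pad_mask.shape)     # (2, 1, 1, 8)
print(pad_mask[0, 0, 0])  # [ True  True  True  True False False False False]
```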
@@ -488,7 +546,6 @@ def __call__(
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
        output_attentions=False,
        return_dict=True,
        deterministic=True,
@@ -527,7 +584,9 @@ def __call__(

        outputs = outputs + attention_outputs

-        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+        # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+        # (cross-attention position bias), (cross-attention weights)
+        return outputs


class FlaxT5LayerCollection(nn.Module):
@@ -548,7 +607,6 @@ def __call__(
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
        output_attentions=False,
        return_dict=True,
        deterministic=True,
@@ -713,7 +771,7 @@ def __call__(

T5_ENCODE_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

@@ -723,23 +781,13 @@ def __call__(

            To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
            <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
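A hedged usage sketch for these encoder inputs; the `encode` method and the `t5-small` checkpoint are assumptions based on this docstring and the surrounding API, not something this diff adds:

```python
# Hedged usage sketch: preparing encoder inputs as numpy arrays for the Flax model.
# The `encode` method and the 't5-small' checkpoint are assumptions, not part of the diff.
from transformers import T5Tokenizer, FlaxT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5Model.from_pretrained("t5-small")

inputs = tokenizer(
    ["Studies have shown that owning a dog is good for you"],
    padding=True,
    return_tensors="np",
)
# attention_mask: 1 for real tokens, 0 for padding, as described above
encoder_outputs = model.encode(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
```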
@@ -838,7 +886,7 @@ def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
-        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
@@ -853,6 +901,11 @@ def __call__(
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

+        if decoder_input_ids is None:
+            raise ValueError(
+                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+            )
+
        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
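Since `decoder_input_ids` is now required here, a hedged sketch of one common way to build it for training: shift the target ids one position to the right and use the pad token id as the decoder start token (the usual T5 convention). `shift_right` is an illustrative helper name, not a function from this file:

```python
# Hedged sketch: building `decoder_input_ids` from target ids by shifting right.
# `shift_right` is an illustrative helper, not defined in modeling_flax_t5.py.
import numpy as np


def shift_right(labels: np.ndarray, pad_token_id: int = 0) -> np.ndarray:
    decoder_input_ids = np.zeros_like(labels)
    decoder_input_ids[:, 1:] = labels[:, :-1]
    decoder_input_ids[:, 0] = pad_token_id  # T5 uses the pad token as decoder start token
    return decoder_input_ids


labels = np.array([[37, 423, 215, 1]])  # example target ids ending in </s>
print(shift_right(labels))              # [[  0  37 423 215]]
```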
@@ -1078,24 +1131,31 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
    Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text
    denoising generative setting.

-    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-    pruning heads etc.)
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its model (such as downloading or saving, resizing the input
+    embeddings, pruning heads etc.)

-    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-    general usage and behavior.
+    This model is also a Flax Linen `flax.nn.Module
+    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
            should be able to pad the inputs on both the right and the left.

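As a hedged illustration of the JIT bullet above, a forward pass can be wrapped in `jax.jit`; the checkpoint name and the keyword call signature are assumptions, and the model parameters are captured as compile-time constants in this simple form:

```python
# Hedged sketch of jit-compiling a forward pass (checkpoint and call signature assumed).
import jax
from transformers import T5Tokenizer, FlaxT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5Model.from_pretrained("t5-small")


@jax.jit
def forward(input_ids, decoder_input_ids):
    # model params are closed over and baked in as constants here; fine for a demo
    return model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state


enc = tokenizer(["Studies have been shown that owning a dog is good for you"], return_tensors="np")
dec = tokenizer(["Studies show that"], return_tensors="np")
hidden = forward(enc["input_ids"], dec["input_ids"])  # compiled on first call
```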
@@ -1107,14 +1167,14 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs

            To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
            <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
-        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
@@ -1129,53 +1189,20 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs

            To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
            <./t5.html#training>`__.
-        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
            Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
            also be used by default.
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
-            ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
+        encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
            Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
            `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
            sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
            the decoder.
-        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
            instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
-        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
-            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
-            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
-            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
-
-            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
-            takes the value of :obj:`inputs_embeds`.

        use_cache (:obj:`bool`, `optional`):
            If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
@@ -1242,7 +1269,7 @@ def __call__(

        Example::

-            >>> from transformers import T5Tokenizer, T5Model
+            >>> from transformers import T5Tokenizer, FlaxT5Model

            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = FlaxT5Model.from_pretrained('t5-small')
@@ -1310,7 +1337,11 @@ def _get_decoder_module(self):
    def setup(self):
        self.model_dim = self.config.d_model

-        self.shared = nn.Embed(self.config.vocab_size, self.config.d_model)
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+        )

        encoder_config = copy.deepcopy(self.config)
        encoder_config.causal = False
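A standalone sketch (illustrative sizes, not from the diff) of what the `embedding_init` above does for the shared token embedding:

```python
# Standalone sketch of the shared embedding initialization (illustrative sizes).
import jax
import jax.numpy as jnp
import flax.linen as nn

vocab_size, d_model, initializer_factor = 32128, 512, 1.0  # assumed example values

shared = nn.Embed(
    vocab_size,
    d_model,
    embedding_init=jax.nn.initializers.normal(initializer_factor),
)
params = shared.init(jax.random.PRNGKey(0), jnp.array([[0, 1, 2]]))
print(params["params"]["embedding"].shape)  # (32128, 512)
```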
@@ -1324,13 +1355,12 @@ def setup(self):
        decoder_config.num_layers = self.config.num_decoder_layers
        self.decoder = FlaxT5Stack(decoder_config, self.shared)

-        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
+        self.lm_head = nn.Dense(
+            self.config.vocab_size,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+            dtype=self.dtype,
+        )

    @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1361,12 +1391,12 @@ def __call__(
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            >>> logits = outputs.logits

-            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="np").input_ids  # Batch size 1
+            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="np").input_ids
            >>> outputs = model.generate(input_ids)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # Encode if needed (training, first prediction pass)
+        # Encode
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,