
Commit 468cda2

[Flax T5] Fix weight initialization and fix docs (#12327)
* finish t5 flax fixes
* improve naming
1 parent 12a4457 commit 468cda2

1 file changed: +115, -85 lines changed
src/transformers/models/t5/modeling_flax_t5.py

@@ -84,8 +84,21 @@ class FlaxT5DenseReluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
-        self.wi = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
         self.dropout = nn.Dropout(self.config.dropout_rate)
 
     def __call__(self, hidden_states, deterministic=True):
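Note: the standard deviations introduced above follow T5's scaled initialization, std = initializer_factor * fan_in ** -0.5 for the feed-forward projections (fan_in is d_model for wi and d_ff for wo). A minimal sketch of what jax.nn.initializers.normal produces under that scheme; the sizes are illustrative only (t5-small happens to use d_model=512, d_ff=2048):

    import jax
    import jax.numpy as jnp

    # Illustrative sizes; initializer_factor defaults to 1.0 in T5Config.
    d_model, d_ff = 512, 2048
    initializer_factor = 1.0
    wi_init_std = initializer_factor * (d_model ** -0.5)

    # jax.nn.initializers.normal(stddev) returns an init function (key, shape, dtype) -> array.
    init_fn = jax.nn.initializers.normal(wi_init_std)
    kernel = init_fn(jax.random.PRNGKey(0), (d_model, d_ff), jnp.float32)
    print(kernel.std())  # close to wi_init_std ~ 0.044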
@@ -101,9 +114,27 @@ class FlaxT5DenseGatedGeluDense(nn.Module):
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
 
     def setup(self):
-        self.wi_0 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wi_1 = nn.Dense(self.config.d_ff, use_bias=False, dtype=self.dtype)
-        self.wo = nn.Dense(self.config.d_model, use_bias=False, dtype=self.dtype)
+        wi_init_std = self.config.initializer_factor * (self.config.d_model ** -0.5)
+        wo_init_std = self.config.initializer_factor * (self.config.d_ff ** -0.5)
+
+        self.wi_0 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wi_1 = nn.Dense(
+            self.config.d_ff,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wi_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.wo = nn.Dense(
+            self.config.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(wo_init_std, self.dtype),
+            dtype=self.dtype,
+        )
         self.dropout = nn.Dropout(self.config.dropout_rate)
         self.gelu_act = ACT2FN["gelu_new"]
 
@@ -154,14 +185,40 @@ def setup(self):
         self.dropout = self.config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim
 
-        self.q = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.k = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.v = nn.Dense(self.inner_dim, use_bias=False, dtype=self.dtype)
-        self.o = nn.Dense(self.d_model, use_bias=False, dtype=self.dtype)
+        inner_dim_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+        d_model_init_std = self.config.initializer_factor * (self.inner_dim ** -0.5)
+
+        self.q = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.k = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.v = nn.Dense(
+            self.inner_dim,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+            dtype=self.dtype,
+        )
+        self.o = nn.Dense(
+            self.d_model,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(inner_dim_init_std, self.dtype),
+            dtype=self.dtype,
+        )
 
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embed(
-                self.relative_attention_num_buckets, self.n_heads, dtype=self.dtype
+                self.relative_attention_num_buckets,
+                self.n_heads,
+                embedding_init=jax.nn.initializers.normal(d_model_init_std, self.dtype),
+                dtype=self.dtype,
             )
 
     @staticmethod
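These kernel_init/embedding_init callables only matter when parameters are created from scratch (for example when a model is built from a config rather than from a checkpoint); weights loaded with from_pretrained overwrite whatever the initializer drew. A small sketch, assuming only flax and jax are installed, of how a Dense layer configured like the ones above materializes its kernel (toy size, not a real T5 dimension):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    inner_dim = 64  # toy value for illustration
    init_std = 1.0 * (inner_dim ** -0.5)

    dense = nn.Dense(inner_dim, use_bias=False, kernel_init=jax.nn.initializers.normal(init_std))

    # Module.init builds the parameter tree; the kernel is drawn by the configured initializer.
    variables = dense.init(jax.random.PRNGKey(0), jnp.ones((1, inner_dim)))
    kernel = variables["params"]["kernel"]
    print(kernel.shape, kernel.std())  # (64, 64), roughly 0.125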
@@ -246,7 +303,8 @@ def _concatenate_to_cache(self, key, value, query, attention_mask):
             cached_value.value = value
             num_updated_cache_vectors = query.shape[1]
             cache_index.value = cache_index.value + num_updated_cache_vectors
-            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
+            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions
+            # that have already been generated and cached, not the remaining zero elements.
             pad_mask = jnp.broadcast_to(
                 jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                 tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
@@ -488,7 +546,6 @@ def __call__(
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
         output_attentions=False,
         return_dict=True,
         deterministic=True,
@@ -527,7 +584,9 @@ def __call__(
 
         outputs = outputs + attention_outputs
 
-        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+        # returns hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
+        # (cross-attention position bias), (cross-attention weights)
+        return outputs
 
 
 class FlaxT5LayerCollection(nn.Module):
@@ -548,7 +607,6 @@ def __call__(
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
-        cross_attn_layer_head_mask=None,
         output_attentions=False,
         return_dict=True,
         deterministic=True,
@@ -713,7 +771,7 @@ def __call__(
 
 T5_ENCODE_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
             should be able to pad the inputs on both the right and the left.
 
@@ -723,23 +781,13 @@ def __call__(
 
             To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
             <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
 
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
         output_attentions (:obj:`bool`, `optional`):
             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
             tensors for more detail.
@@ -838,7 +886,7 @@ def __call__(
         self,
         input_ids: jnp.ndarray,
         attention_mask: Optional[jnp.ndarray] = None,
-        decoder_input_ids: Optional[jnp.ndarray] = None,
+        decoder_input_ids: jnp.ndarray = None,
         decoder_attention_mask: Optional[jnp.ndarray] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -853,6 +901,11 @@ def __call__(
         )
         return_dict = return_dict if return_dict is not None else self.config.return_dict
 
+        if decoder_input_ids is None:
+            raise ValueError(
+                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
+            )
+
         # prepare encoder inputs
         if attention_mask is None:
             attention_mask = jnp.ones_like(input_ids)
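With this guard, a seq2seq forward pass that is missing decoder_input_ids now fails immediately with a clear message instead of erroring deep inside the decoder stack. A hedged usage sketch (it assumes the t5-small checkpoint can be downloaded, and it simply reuses the encoder tokens as decoder inputs for illustration):

    from transformers import FlaxT5Model, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = FlaxT5Model.from_pretrained("t5-small")

    input_ids = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="np").input_ids

    # model(input_ids=input_ids) alone would now raise the ValueError added above.
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, d_model)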
@@ -1078,24 +1131,31 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
     Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text
     denoising generative setting.
 
-    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
-    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
-    pruning heads etc.)
+    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
+    generic methods the library implements for all its model (such as downloading or saving, resizing the input
+    embeddings, pruning heads etc.)
 
-    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
-    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
-    general usage and behavior.
+    This model is also a Flax Linen `flax.nn.Module
+    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
+    Module and refer to the Flax documentation for all matter related to general usage and behavior.
+
+    Finally, this model supports inherent JAX features such as:
+
+    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
+    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
+    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
+    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__
 
     Parameters:
         config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
-            weights.
+            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
+            model weights.
 """
 
 T5_INPUTS_DOCSTRING = r"""
     Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
+        input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
             should be able to pad the inputs on both the right and the left.
 
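The JIT/autodiff/vmap/pmap bullets added to the docstring above apply because the Flax module's forward pass is a pure function of its inputs and parameters. A minimal sketch of jit-compiling the encoder side, again assuming the t5-small weights are available (model.params are simply captured by the closure as constants):

    import jax
    from transformers import FlaxT5Model, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = FlaxT5Model.from_pretrained("t5-small")
    inputs = tokenizer("translate English to German: hello", return_tensors="np")

    @jax.jit
    def encode(input_ids, attention_mask):
        # FlaxT5Model.encode runs only the encoder stack and returns its last hidden states.
        return model.encode(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

    hidden = encode(inputs.input_ids, inputs.attention_mask)
    print(hidden.shape)  # (1, sequence_length, d_model)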
@@ -1107,14 +1167,14 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 
             To know more on how to prepare :obj:`input_ids` for pretraining take a look a `T5 Training
             <./t5.html#training>`__.
-        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+        attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
 
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
 
             `What are attention masks? <../glossary.html#attention-mask>`__
-        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_input_ids (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Indices of decoder input sequence tokens in the vocabulary.
 
             Indices can be obtained using :class:`~transformers.T5Tokenizer`. See
@@ -1129,53 +1189,20 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs
 
             To know more on how to prepare :obj:`decoder_input_ids` for pretraining take a look at `T5 Training
             <./t5.html#training>`__.
-        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
+        decoder_attention_mask (:obj:`jnp.ndarray` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
             Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
             also be used by default.
-        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        decoder_head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in ``[0,
-            1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
-            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
-            ``[0, 1]``:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
+        encoder_outputs (:obj:`tuple(tuple(jnp.ndarray)`, `optional`):
             Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`:
             `attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a
             sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of
             the decoder.
-        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`tuple(tuple(jnp.ndarray))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
 
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
-        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
-            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
-            vectors than the model's internal embedding lookup matrix.
-        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
-            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
-            representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds`
-            have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert
-            :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
-
-            If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds`
-            takes the value of :obj:`inputs_embeds`.
 
         use_cache (:obj:`bool`, `optional`):
             If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
@@ -1242,7 +1269,7 @@ def __call__(
 
         Example::
 
-            >>> from transformers import T5Tokenizer, T5Model
+            >>> from transformers import T5Tokenizer, FlaxT5Model
 
             >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
             >>> model = FlaxT5Model.from_pretrained('t5-small')
@@ -1310,7 +1337,11 @@ def _get_decoder_module(self):
     def setup(self):
         self.model_dim = self.config.d_model
 
-        self.shared = nn.Embed(self.config.vocab_size, self.config.d_model)
+        self.shared = nn.Embed(
+            self.config.vocab_size,
+            self.config.d_model,
+            embedding_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+        )
 
         encoder_config = copy.deepcopy(self.config)
         encoder_config.causal = False
@@ -1324,13 +1355,12 @@ def setup(self):
         decoder_config.num_layers = self.config.num_decoder_layers
         self.decoder = FlaxT5Stack(decoder_config, self.shared)
 
-        self.lm_head = nn.Dense(self.config.vocab_size, use_bias=False)
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
+        self.lm_head = nn.Dense(
+            self.config.vocab_size,
+            use_bias=False,
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_factor, self.dtype),
+            dtype=self.dtype,
+        )
 
     @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1361,12 +1391,12 @@ def __call__(
             >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
             >>> logits = outputs.logits
 
-            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids  # Batch size 1
+            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="np").input_ids
             >>> outputs = model.generate(input_ids)
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # Encode if needed (training, first prediction pass)
+        # Encode
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
