Commit 6910316

committed: "this wont work, will fix step by step"
1 parent 676c828

29 files changed, +137 -155 lines

src/transformers/models/bark/modeling_bark.py

Lines changed: 11 additions & 24 deletions
@@ -30,7 +30,7 @@
     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
@@ -464,7 +464,6 @@ def forward(
             raise ValueError("You have to specify either input_ids or input_embeds")

         input_shape = input_embeds.size()[:-1]
-        batch_size = input_embeds.shape[0]
         seq_length = input_shape[-1]

         device = input_ids.device if input_ids is not None else input_embeds.device
@@ -487,17 +486,11 @@ def forward(

         position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            else:
-                attention_mask = attention_mask.view(batch_size, -1)
-                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
-                # from_seq_length is 1 to easily broadcast
-                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=input_embeds,
+            attention_mask=attention_mask,
+        )

         hidden_states = self.drop(input_embeds + position_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
@@ -1074,7 +1067,6 @@ def forward(
         input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)

         input_shape = input_embeds.size()[:-1]
-        batch_size = input_embeds.shape[0]
         seq_length = input_shape[1]

         device = input_ids.device if input_ids is not None else input_embeds.device
@@ -1085,16 +1077,11 @@ def forward(

         position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            else:
-                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
-                # from_seq_length is 1 to easily broadcast
-                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=input_embeds,
+            attention_mask=attention_mask,
+        )

         hidden_states = self.drop(input_embeds + position_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
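
For context, the branch removed above expanded a 2D padding mask into the additive 4D form consumed by eager/SDPA attention. The following is a minimal, self-contained sketch of that old expansion; it only documents the behavior being replaced and is not the implementation of create_bidirectional_mask (whose internals are not part of this diff). Names such as expand_padding_mask are illustrative, not from the commit.

# Sketch of what the removed _prepare_4d_attention_mask call produced:
# a [bsz, 1, tgt_len, src_len] additive mask with 0.0 at visible positions
# and the dtype minimum at padded ones.
import torch

def expand_padding_mask(mask_2d, dtype, tgt_len=None):
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

pad_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding
print(expand_padding_mask(pad_mask, torch.float32, tgt_len=1).shape)  # torch.Size([1, 1, 1, 4])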

src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

Lines changed: 5 additions & 2 deletions
@@ -27,7 +27,6 @@
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -1686,7 +1685,11 @@ def forward(
         # expand attention_mask
         if self.attention_type == "original_full":
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )
             blocked_encoder_mask = band_mask = from_mask = to_mask = None
         elif self.attention_type == "block_sparse":
             blocked_encoder_mask, band_mask, from_mask, to_mask = self.create_masks_for_block_sparse_attn(

src/transformers/models/clvp/modeling_clvp.py

Lines changed: 6 additions & 4 deletions
@@ -29,7 +29,7 @@
 from ...activations import ACT2FN, get_activation
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationConfig, GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...masking_utils import create_bidirectional_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -911,9 +911,11 @@ def forward(
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         # expand attention_mask and create position_ids if needed
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device

src/transformers/models/conditional_detr/modeling_conditional_detr.py

Lines changed: 11 additions & 12 deletions
@@ -23,7 +23,7 @@

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -1068,10 +1068,11 @@ def forward(
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

-        # expand attention_mask
-        if attention_mask is not None:
-            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1202,14 +1203,12 @@ def forward(

         if inputs_embeds is not None:
             hidden_states = inputs_embeds
-            input_shape = inputs_embeds.size()[:-1]

-        # expand encoder attention mask
-        if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            encoder_attention_mask = _prepare_4d_attention_mask(
-                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-            )
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         # optional intermediate hidden states
         intermediate = () if self.config.auxiliary_loss else None
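
For reference, the decoder branch removed in the last hunk broadcast the encoder padding mask to the decoder's query length for cross-attention. A shape-only sketch of that broadcast, using illustrative sizes that are not part of the commit:

# [batch_size, encoder_len] -> [batch_size, 1, decoder_len, encoder_len]
import torch

batch_size, decoder_len, encoder_len = 2, 300, 36  # illustrative sizes only
encoder_attention_mask = torch.ones(batch_size, encoder_len)
expanded = encoder_attention_mask[:, None, None, :].expand(batch_size, 1, decoder_len, encoder_len)
print(expanded.shape)  # torch.Size([2, 1, 300, 36])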

src/transformers/models/dab_detr/modeling_dab_detr.py

Lines changed: 6 additions & 5 deletions
@@ -23,7 +23,7 @@

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -921,10 +921,11 @@ def forward(

         hidden_states = inputs_embeds

-        # expand attention_mask
-        if attention_mask is not None:
-            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

src/transformers/models/deformable_detr/modeling_deformable_detr.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from ... import initialization as init
 from ...activations import ACT2FN
 from ...integrations import use_kernel_forward_from_hub
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput
 from ...modeling_utils import PreTrainedModel

src/transformers/models/detr/modeling_detr.py

Lines changed: 6 additions & 5 deletions
@@ -23,7 +23,7 @@

 from ... import initialization as init
 from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput
 from ...modeling_utils import PreTrainedModel
@@ -824,10 +824,11 @@ def forward(
         hidden_states = inputs_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

-        # expand attention_mask
-        if attention_mask is not None:
-            # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None

src/transformers/models/git/modeling_git.py

Lines changed: 6 additions & 4 deletions
@@ -27,7 +27,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1087,9 +1087,11 @@ def forward(
         if attention_mask is not None:
             # if the user provides an attention mask, we add it to the default one
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _prepare_4d_attention_mask(
-                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
-            ).to(embedding_output.device)
+            expanded_attn_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )
             if past_key_values_length > 0:
                 expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
             else:
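
The context kept by this hunk then slices the expanded mask when a KV cache is present, retaining only the last past_key_values_length query rows. A shape-only sketch of that slice, with illustrative sizes not taken from the commit:

import torch

expanded_attn_mask = torch.zeros(1, 1, 10, 10)  # [bsz, 1, tgt_len, src_len]
past_key_values_length = 3
expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
print(expanded_attn_mask.shape)  # torch.Size([1, 1, 3, 10])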

src/transformers/models/idefics2/modeling_idefics2.py

Lines changed: 6 additions & 8 deletions
@@ -25,7 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
@@ -485,13 +485,11 @@ def forward(
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
-        # The call to `_upad_input` in `_flash_attention_forward` is expensive
-        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
-        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
-        elif self.config._attn_implementation != "flash_attention_2":
-            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
+        patch_attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )

         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
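
For context, the removed fast path skipped materializing a mask entirely when every patch is attended to, because un-padding a fully visible sequence in flash attention is wasted work. A minimal sketch of that check, using illustrative tensors only:

import torch

patch_attention_mask = torch.ones(2, 16, dtype=torch.bool)  # all patches visible
if not torch.any(~patch_attention_mask):                    # no masked-out patches anywhere
    patch_attention_mask = None                             # same as attending to the full sequence
print(patch_attention_mask)                                 # None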

src/transformers/models/informer/modeling_informer.py

Lines changed: 6 additions & 5 deletions
@@ -30,7 +30,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_bidirectional_mask, create_causal_mask
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -915,10 +915,11 @@ def forward(
         hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

-        # expand attention_mask
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+        )

         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
