     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
@@ -464,7 +464,6 @@ def forward(
             raise ValueError("You have to specify either input_ids or input_embeds")
 
         input_shape = input_embeds.size()[:-1]
-        batch_size = input_embeds.shape[0]
         seq_length = input_shape[-1]
 
         device = input_ids.device if input_ids is not None else input_embeds.device
@@ -487,17 +486,11 @@ def forward(
 
         position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)
 
-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            else:
-                attention_mask = attention_mask.view(batch_size, -1)
-                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
-                # from_seq_length is 1 to easily broadcast
-                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=input_embeds,
+            attention_mask=attention_mask,
+        )
 
         hidden_states = self.drop(input_embeds + position_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
@@ -1074,7 +1067,6 @@ def forward(
             input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)
 
         input_shape = input_embeds.size()[:-1]
-        batch_size = input_embeds.shape[0]
         seq_length = input_shape[1]
 
         device = input_ids.device if input_ids is not None else input_embeds.device
@@ -1085,16 +1077,11 @@ def forward(
 
         position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)
 
-        # Attention mask.
-        if attention_mask is not None:
-            if batch_size <= 0:
-                raise ValueError("batch_size has to be defined and > 0")
-            if self.config._attn_implementation == "flash_attention_2":
-                attention_mask = attention_mask if 0 in attention_mask else None
-            else:
-                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
-                # from_seq_length is 1 to easily broadcast
-                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
+        attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=input_embeds,
+            attention_mask=attention_mask,
+        )
 
         hidden_states = self.drop(input_embeds + position_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
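For reference, here is a minimal pure-PyTorch sketch of the mask the removed branches built by hand: a 2D [bsz, src_len] padding mask expanded into the 4D additive bias that the old _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1) call returned. The helper name expand_padding_mask is hypothetical and only for illustration; the new create_bidirectional_mask utility is expected to cover the same bidirectional (padding-only) case while also dispatching per attention implementation, e.g. the flash-attention short-circuit seen in the removed code.

import torch

def expand_padding_mask(attention_mask, dtype, tgt_len=1):
    # Hypothetical illustration, not the library implementation:
    # [bsz, src_len] with 1 = attend, 0 = padding  ->  [bsz, 1, tgt_len, src_len] additive bias.
    bsz, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    # Attended positions become 0.0; padded positions get a large negative bias
    # that drives their softmax weight to zero.
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])              # one padded slot at the end
bias = expand_padding_mask(mask, torch.float32)  # shape (1, 1, 1, 4); last entry ~ -3.4e38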