@@ -821,6 +821,7 @@ def forward(
         use_cache: bool = False,
         inbatch_pack_offset: Optional[Tuple[paddle.Tensor]] = None,
         token_type_ids: Optional[Tuple[paddle.Tensor]] = None,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         if token_type_ids is not None:
             token_type_ids = token_type_ids[:, :-1]
@@ -901,6 +902,7 @@ def forward(
                 past_key_value,
                 use_cache,
                 inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 use_reentrant=False,
             )
         else:
@@ -915,6 +917,7 @@ def forward(
                 past_key_value=past_key_value,
                 use_cache=use_cache,
                 inbatch_pack_offset=inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )
         if self.config.sequence_parallel:
             attn_output = attn_output.reshape([-1, attn_output.shape[-1]])
@@ -1152,6 +1155,7 @@ def forward(
         use_cache: Optional[bool] = False,
         inbatch_pack_offset: Optional[paddle.Tensor] = None,
         output_gate_logits=True,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
     ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]:
         residual = hidden_states
         if token_type_ids is not None:
@@ -1178,6 +1182,7 @@ def forward(
             use_cache=use_cache,
             inbatch_pack_offset=inbatch_pack_offset,
             token_type_ids=token_type_ids,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )

         if self.use_linear_residual_norm_recompute is True:
@@ -1660,6 +1665,7 @@ def forward(
         output_hidden_states=None,
         return_dict=False,
         inbatch_pack_offset=None,
+        attn_mask_startend_row_indices=None,
         **kwargs,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1719,6 +1725,12 @@ def forward(
             )
         hidden_states = inputs_embeds

+        attn_mask_startend_row_indices_ori = attn_mask_startend_row_indices
+        if attn_mask_startend_row_indices is not None:
+            attn_mask_startend_row_indices = attn_mask_startend_row_indices[
+                :, :, : -self.config.multi_token_pred_depth
+            ]
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = () if use_cache else None
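A note on the trimming above: if `attn_mask_startend_row_indices` follows the usual FlashMask-style layout of one start-row index per key position (a `[batch, num_masks, seq_len]` shape is assumed here, not stated in the diff), dropping the last `multi_token_pred_depth` columns keeps the mask aligned with the trunk inputs, which are shortened by the same amount for multi-token prediction. A minimal shape sketch with made-up sizes:

```python
# Shape sketch only; the [batch, num_masks, seq_len] layout is an assumption.
import paddle

batch_size, num_masks, seq_len, mtp_depth = 2, 1, 10, 1
attn_mask_startend_row_indices = paddle.zeros([batch_size, num_masks, seq_len], dtype="int32")

# The trunk processes seq_len - mtp_depth tokens, so the mask is trimmed to match.
trunk_indices = attn_mask_startend_row_indices[:, :, :-mtp_depth]
assert trunk_indices.shape == [batch_size, num_masks, seq_len - mtp_depth]
```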
@@ -1743,6 +1755,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     inbatch_pack_offset,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1754,6 +1767,7 @@ def forward(
                     past_key_value,
                     use_cache,
                     inbatch_pack_offset,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                 )

             if isinstance(layer_outputs, (tuple, list)):
@@ -1786,6 +1800,11 @@ def forward(
                 ],
                 axis=1,
             )
+            attn_mask_startend_row_indices_cur_depth = None
+            if attn_mask_startend_row_indices is not None:
+                attn_mask_startend_row_indices_cur_depth = attn_mask_startend_row_indices_ori[
+                    :, :, (depth + 1) : inputs_embeds_ori.shape[1] + (depth + 1)
+                ] - (depth + 1)

             inputs_embeds_cur_depth_norm = self.mtp_emb_norm[depth](inputs_embeds_cur_depth)
             hidden_states_norm = self.mtp_hidden_norm[depth](hidden_states)
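The per-depth re-indexing above can be sanity-checked outside the model. Assuming the convention that, for packed documents, position j's entry is the row at which causal masking starts (the end offset of j's document), slicing the window forward by `depth + 1` and subtracting `depth + 1` reproduces the indices one would compute directly for the shifted sequence. A small worked example with made-up document lengths:

```python
# Hedged sketch; the packed-document convention below is assumed, not taken from this PR.
import numpy as np

doc_lens = [4, 5]                                  # two packed documents, total length 9
ends = np.repeat(np.cumsum(doc_lens), doc_lens)    # per-position document end offset
ori = ends[None, None, :]                          # stands in for attn_mask_startend_row_indices_ori

depth, trunk_len = 0, ori.shape[-1] - 1            # multi_token_pred_depth == 1 here
cur = ori[:, :, (depth + 1) : trunk_len + (depth + 1)] - (depth + 1)

print(cur[0, 0])  # [3 3 3 8 8 8 8 8]: each boundary moved down by depth + 1, as expected
```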
@@ -1809,6 +1828,7 @@ def forward(
                 past_key_value,
                 use_cache,
                 inbatch_pack_offset,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices_cur_depth,
             )

         if isinstance(layer_outputs, (tuple, list)):
@@ -2132,6 +2152,8 @@ def forward(
         data_id=None,
         src_id=None,
         inbatch_pack_offset=None,
+        attn_mask_startend_row_indices=None,
+        **kwargs,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -2151,6 +2173,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=True,
             inbatch_pack_offset=inbatch_pack_offset,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
         )

         hidden_states = outputs.last_hidden_state
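For callers, a hedged end-to-end sketch of how the new argument might be built and passed. The `[batch, 1, seq_len]` layout, the helper name, and the construction from packed document lengths are illustrative assumptions; only the keyword argument itself comes from this change.

```python
# Hypothetical caller-side usage; nothing below is defined by this PR except the kwarg name.
import numpy as np
import paddle

def build_startend_row_indices(doc_lens, dtype="int32"):
    """For packed documents, position j may only attend to rows before the end
    offset of its own document, so that offset is the row where masking starts."""
    ends = np.repeat(np.cumsum(doc_lens), doc_lens)
    return paddle.to_tensor(ends[None, None, :], dtype=dtype)

attn_mask_startend_row_indices = build_startend_row_indices([4, 5])  # seq_len == 9

# outputs = model(
#     input_ids=input_ids,
#     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
# )
```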