@@ -47,11 +47,12 @@ def prepare_fa2_from_position_ids(
     query = query.contiguous().view(-1, query.size(-2), query.size(-1))
     key = key.contiguous().view(-1, key.size(-2), key.size(-1))
     value = value.contiguous().view(-1, value.size(-2), value.size(-1))
+    tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
     position_ids = position_ids.view(-1)
     cu_seqlens = torch.cat(
         (
-            (position_ids == 0).nonzero().view(-1).to(torch.int32),
-            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+            (position_ids == 0).nonzero().view(-1).to(**tensor_kwargs),
+            torch.tensor(position_ids.size(), **tensor_kwargs),
         )
     )
     max_length = cu_seqlens.diff().max()  # use cu_seqlens to infer max_length for qwen2vl mrope
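For readers new to the varlen layout: `(position_ids == 0).nonzero()` marks the start offset of every packed sequence, and appending the total token count turns those offsets into cumulative sequence lengths. A minimal CPU-only sketch of the same arithmetic on a toy tensor (values are illustrative, not taken from the repository):

```python
import torch

# two packed sequences of lengths 3 and 2, flattened as in prepare_fa2_from_position_ids
position_ids = torch.tensor([0, 1, 2, 0, 1])

tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
cu_seqlens = torch.cat(
    (
        (position_ids == 0).nonzero().view(-1).to(**tensor_kwargs),  # start offsets: [0, 3]
        torch.tensor(position_ids.size(), **tensor_kwargs),          # total tokens:  [5]
    )
)

print(cu_seqlens)               # tensor([0, 3, 5], dtype=torch.int32)
print(cu_seqlens.diff())        # per-sequence lengths: tensor([3, 2], dtype=torch.int32)
print(cu_seqlens.diff().max())  # max_length handed to the varlen kernel: 3
```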
@@ -90,12 +91,9 @@ def _custom_flash_attention_forward(
         query_states, key_states, value_states, target_dtype=torch.bfloat16
     )

-    if position_ids is not None:
-        assert position_ids.ndim == 2  # (batch_size, seq_length)
-
     sp_size = get_ulysses_sequence_parallel_world_size()
     if sp_size > 1:
-        # qkv: (batch_size, seq_length, num_head, head_size)
+        # qkv: (batch_size, seq_length / sp_size, num_head, head_size)
         query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
         key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
         value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)
@@ -105,19 +103,17 @@ def _custom_flash_attention_forward(

     if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
-        query_states, key_states, value_states, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+        q, k, v, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
         )
-        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
         attn_output = flash_attn_varlen_func(
-            query_states,
-            key_states,
-            value_states,
+            q,
+            k,
+            v,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=max_seqlen_in_batch_q,
-            max_seqlen_k=max_seqlen_in_batch_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             dropout_p=kwargs.pop("dropout", 0.0),
             softmax_scale=kwargs.pop("softmax_scale", None),
             causal=is_causal,
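The new unpacking above implies `prepare_fa2_from_position_ids` now returns the flattened tensors plus two tuples, `(cu_seqlens_q, cu_seqlens_k)` and `(max_seqlen_q, max_seqlen_k)`. Below is a hedged, self-contained sketch of the whole varlen path, inlining the cu_seqlens construction from the first hunk instead of importing the helper. It assumes flash-attn and a CUDA device, and the final reshape back to the padded layout is an assumption about what the saved `batch_size` is used for; it is not shown in this hunk.

```python
import torch
from flash_attn import flash_attn_varlen_func  # requires a CUDA device

batch_size, seq_length, num_heads, head_dim = 2, 4, 8, 64
q = torch.randn(batch_size, seq_length, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
# position_ids restart at 0 where a new packed sequence begins (lengths 3, 1, 2, 2 here)
position_ids = torch.tensor([[0, 1, 2, 0], [0, 1, 0, 1]], device="cuda")

# flatten batch and sequence dims, mirroring prepare_fa2_from_position_ids
q, k, v = (t.contiguous().view(-1, num_heads, head_dim) for t in (q, k, v))
flat_pos = position_ids.view(-1)
tensor_kwargs = {"dtype": torch.int32, "device": flat_pos.device}
cu_seqlens = torch.cat(
    (
        (flat_pos == 0).nonzero().view(-1).to(**tensor_kwargs),
        torch.tensor(flat_pos.size(), **tensor_kwargs),
    )
)
max_seqlen = cu_seqlens.diff().max().item()

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,
)
# assumed follow-up: restore the padded (batch_size, seq_length, num_heads, head_dim) layout
out = out.view(batch_size, seq_length, num_heads, head_dim)
```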
@@ -132,14 +128,15 @@ def _custom_flash_attention_forward(
             attention_mask,
             query_length,
             is_causal=is_causal,
+            position_ids=position_ids,
             sliding_window=sliding_window,
             use_top_left_mask=use_top_left_mask,
             deterministic=deterministic,
             **kwargs,
-        )  # do not pass position_ids to old flash_attention_forward
+        )

     if sp_size > 1:
-        # (batch_size, seq_length, num_head, head_size)
+        # output: (batch_size, seq_length / sp_size, num_head, head_size)
         attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

     return attn_output
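The two shape comments touched in this diff describe the Ulysses sequence-parallel exchange: before attention each rank trades its sequence shard (holding all heads) for the full sequence with a shard of the heads, and after attention the exchange is reversed. A single-process illustration of that redistribution, using plain chunking in place of the distributed `gather_seq_scatter_heads` / `gather_heads_scatter_seq` helpers (shapes and `sp_size` are made up for this sketch):

```python
import torch

batch, seq_len, num_heads, head_dim, sp_size = 2, 8, 4, 16, 2
full = torch.randn(batch, seq_len, num_heads, head_dim)  # the logical, unsharded tensor

# before attention each rank holds: (batch, seq_len / sp_size, num_heads, head_dim)
seq_shards = full.chunk(sp_size, dim=1)

# gather_seq_scatter_heads(..., seq_dim=1, head_dim=2) leaves each rank with the full
# sequence but only 1/sp_size of the heads: (batch, seq_len, num_heads / sp_size, head_dim)
head_shards = full.chunk(sp_size, dim=2)

print(seq_shards[0].shape)   # torch.Size([2, 4, 4, 16])
print(head_shards[0].shape)  # torch.Size([2, 8, 2, 16])

# gather_heads_scatter_seq(..., head_dim=2, seq_dim=1) on the attention output reverses
# the exchange: both layouts reassemble into the same logical tensor
assert torch.equal(torch.cat(head_shards, dim=2), torch.cat(seq_shards, dim=1))
```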