@@ -195,7 +195,11 @@ def prepare_incremental_tree_batch(
195195) -> Tuple [torch .Tensor , torch .Tensor , List [List [List [TreeNode ]]]]:
196196 """
197197 准备增量 tree batch,支持不同序列长度
198+ attention mask 直接输出 float score:0.0 表示可attend,-65504.0 表示被遮蔽
198199 """
200+ MASKED = - 65504.0
201+ ATTEND = 0.0
202+
199203 batch_size = len (trees )
200204
201205 if not trees or all (tree .total_nodes <= 1 for tree in trees ):
@@ -228,64 +232,71 @@ def prepare_incremental_tree_batch(
228232 batch_tree_tokens .append (padded_tokens )
229233
230234 tree_len = len (tree_token_ids ) # 不包含 root
231- inputs_len = tree_len + 1 # root + tree tokens
235+ inputs_len = tree_len + 1 # root + tree tokens
232236
233237 if is_prefill :
234- # ============ Prefill 阶段(不变) ============
238+ # ============ Prefill 阶段 ============
235239 past_len = input_ids .shape [1 ]
236240 total_len = past_len + tree_len
237- mask = torch .zeros ( 1 , total_len , total_len , dtype = torch .bool , device = device )
241+ mask = torch .full (( 1 , total_len , total_len ), MASKED , dtype = torch .float , device = device )
238242
239243 prompt_len = curr_seq_len - 1 if curr_seq_len > 0 else 0
240244 root_pos = prompt_len
241245
246+ # Prompt 部分:causal mask
242247 if prompt_len > 0 :
243248 row_idx = torch .arange (prompt_len , device = device ).view (- 1 , 1 )
244249 col_idx = torch .arange (prompt_len , device = device ).view (1 , - 1 )
245- causal_mask = row_idx >= col_idx
246- mask [0 , :prompt_len , :prompt_len ] = causal_mask
250+ causal_mask = row_idx >= col_idx # bool
251+ mask [0 , :prompt_len , :prompt_len ] = torch . where ( causal_mask , ATTEND , MASKED )
247252
253+ # Root attend to prompt + 自己
248254 if prompt_len > 0 :
249- mask [0 , root_pos , :prompt_len ] = True
250- mask [0 , root_pos , root_pos ] = True
255+ mask [0 , root_pos , :prompt_len ] = ATTEND
256+ mask [0 , root_pos , root_pos ] = ATTEND
251257
258+ # Tree tokens attend to prompt + root
252259 if tree_len > 0 :
253260 if prompt_len > 0 :
254- mask [0 , past_len :past_len + tree_len , :prompt_len ] = True
255- mask [0 , past_len :past_len + tree_len , root_pos ] = True
261+ mask [0 , past_len :past_len + tree_len , :prompt_len ] = ATTEND
262+ mask [0 , past_len :past_len + tree_len , root_pos ] = ATTEND
256263
264+ # Tree tokens 之间
257265 if tree_len > 0 :
258266 tree_mask = build_tree_attention_mask_with_root (tree_len , parent_indices , device )
259267 mask [0 , past_len :past_len + tree_len , past_len :past_len + tree_len ] = tree_mask
260268
269+ # Padding
261270 if tree_len < max_tree_size :
262271 total_padded_len = past_len + max_tree_size
263- padded_mask = torch .zeros (1 , total_padded_len , total_padded_len , dtype = torch .bool , device = device )
272+ padded_mask = torch .full (
273+ (1 , total_padded_len , total_padded_len ), MASKED , dtype = torch .float , device = device
274+ )
264275 padded_mask [0 , :total_len , :total_len ] = mask [0 ]
276+ # Padding 行 attend to prompt(避免 NaN)
265277 if curr_seq_len > 0 :
266- padded_mask [0 , total_len :, :curr_seq_len ] = True
278+ padded_mask [0 , total_len :, :curr_seq_len ] = ATTEND
267279 mask = padded_mask
268-
280+
269281 else :
270282 # ============ Generation 阶段 ============
271- # 总长度 = cache + 本轮输入
272283 total_len = cache_len + inputs_len
284+ mask = torch .full ((1 , inputs_len , total_len ), MASKED , dtype = torch .float , device = device )
273285
274- mask = torch .zeros (1 , inputs_len , total_len , dtype = torch .bool , device = device )
275-
276- # 计算 cache 中的有效位置
277- cache_valid_mask = _compute_single_cache_valid_mask (
286+ # 计算 cache 中的有效位置(bool),再映射为 score
287+ cache_valid_bool = _compute_single_cache_valid_mask (
278288 kv_cache_position_ids [i ], cache_len , device
279- )
280-
289+ ) # shape: (cache_len,), dtype: bool
290+ cache_scores = torch .where (cache_valid_bool , ATTEND , MASKED ) # float scores
291+
281292 # 1. Root attend to cache + 自己
282- mask [0 , 0 , :cache_len ] = cache_valid_mask
283- mask [0 , 0 , cache_len ] = True # root attend 自己
293+ mask [0 , 0 , :cache_len ] = cache_scores
294+ mask [0 , 0 , cache_len ] = ATTEND # root attend 自己
284295
285296 # 2. Tree tokens attend to cache + root
286297 if tree_len > 0 :
287- mask [0 , 1 :inputs_len , :cache_len ] = cache_valid_mask .unsqueeze (0 ).expand (tree_len , cache_len )
288- mask [0 , 1 :inputs_len , cache_len ] = True # tree tokens attend to root
298+ mask [0 , 1 :inputs_len , :cache_len ] = cache_scores .unsqueeze (0 ).expand (tree_len , cache_len )
299+ mask [0 , 1 :inputs_len , cache_len ] = ATTEND # tree tokens attend to root
289300
290301 # 3. Tree tokens 之间
291302 if tree_len > 0 :
@@ -297,10 +308,12 @@ def prepare_incremental_tree_batch(
297308 if inputs_len < max_inputs_len :
298309 pad_len = max_inputs_len - inputs_len
299310 total_padded_len = cache_len + max_inputs_len
300- padded_mask = torch .zeros (1 , max_inputs_len , total_padded_len , dtype = torch .bool , device = device )
311+ padded_mask = torch .full (
312+ (1 , max_inputs_len , total_padded_len ), MASKED , dtype = torch .float , device = device
313+ )
301314 padded_mask [0 , :inputs_len , :total_len ] = mask [0 ]
302315 # Padding 行 attend to cache(避免 NaN)
303- padded_mask [0 , inputs_len :, :cache_len ] = cache_valid_mask .unsqueeze (0 ).expand (pad_len , cache_len )
316+ padded_mask [0 , inputs_len :, :cache_len ] = cache_scores .unsqueeze (0 ).expand (pad_len , cache_len )
304317 mask = padded_mask
305318
306319 batch_attention_masks .append (mask )
@@ -358,16 +371,17 @@ def build_tree_attention_mask_with_root(
358371) -> torch .Tensor :
359372 """
360373 构建 tree tokens 之间的 attention mask(不包含 root)
374+ 直接返回 float score mask:0.0 表示可attend,-65504.0 表示被遮蔽
361375 """
362- mask = torch .zeros ( tree_len , tree_len , dtype = torch .bool , device = device )
376+ mask = torch .full (( tree_len , tree_len ), - 65504.0 , dtype = torch .float , device = device )
363377
364378 for i in range (tree_len ):
365- mask [i , i ] = True
379+ mask [i , i ] = 0.0
366380 current = i
367381 while current >= 0 :
368382 parent = parent_indices [current ]
369383 if parent >= 0 :
370- mask [i , parent ] = True
384+ mask [i , parent ] = 0.0
371385 current = parent
372386 else :
373387 break
0 commit comments