
Commit a5be9f2

Author: Xu Xiong (committed)
Commit message: optimize code
Parent: 5b186f0

8 files changed

Lines changed: 90 additions & 1310 deletions


benchmarks/benchmark_speculative_decoding.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def benchmark_inference(process_idx, args, result_pipe):
     drafter = MultiSSMDrafter(
         ssm_model_name="JackFram/llama-68m",
-        num_workers=2,
+        num_workers=1,
         device="cuda"
     )
     model = AutoDistributedSpeculativeModel.from_pretrained(

src/bloombee/models/llama/spe_dec_tree.py

Lines changed: 42 additions & 28 deletions
@@ -195,7 +195,11 @@ def prepare_incremental_tree_batch(
 ) -> Tuple[torch.Tensor, torch.Tensor, List[List[List[TreeNode]]]]:
     """
     Prepare an incremental tree batch, supporting different sequence lengths.
+    The attention mask is emitted directly as float scores: 0.0 means attendable, -65504.0 means masked.
     """
+    MASKED = -65504.0
+    ATTEND = 0.0
+
     batch_size = len(trees)
 
     if not trees or all(tree.total_nodes <= 1 for tree in trees):
@@ -228,64 +232,71 @@ def prepare_incremental_tree_batch(
         batch_tree_tokens.append(padded_tokens)
 
         tree_len = len(tree_token_ids)  # excludes root
-        inputs_len = tree_len + 1  # root + tree tokens
+        inputs_len = tree_len + 1  # root + tree tokens
 
         if is_prefill:
-            # ============ Prefill phase (unchanged) ============
+            # ============ Prefill phase ============
             past_len = input_ids.shape[1]
             total_len = past_len + tree_len
-            mask = torch.zeros(1, total_len, total_len, dtype=torch.bool, device=device)
+            mask = torch.full((1, total_len, total_len), MASKED, dtype=torch.float, device=device)
 
             prompt_len = curr_seq_len - 1 if curr_seq_len > 0 else 0
             root_pos = prompt_len
 
+            # Prompt part: causal mask
             if prompt_len > 0:
                 row_idx = torch.arange(prompt_len, device=device).view(-1, 1)
                 col_idx = torch.arange(prompt_len, device=device).view(1, -1)
-                causal_mask = row_idx >= col_idx
-                mask[0, :prompt_len, :prompt_len] = causal_mask
+                causal_mask = row_idx >= col_idx  # bool
+                mask[0, :prompt_len, :prompt_len] = torch.where(causal_mask, ATTEND, MASKED)
 
+            # Root attends to prompt + itself
             if prompt_len > 0:
-                mask[0, root_pos, :prompt_len] = True
-            mask[0, root_pos, root_pos] = True
+                mask[0, root_pos, :prompt_len] = ATTEND
+            mask[0, root_pos, root_pos] = ATTEND
 
+            # Tree tokens attend to prompt + root
             if tree_len > 0:
                 if prompt_len > 0:
-                    mask[0, past_len:past_len + tree_len, :prompt_len] = True
-                mask[0, past_len:past_len + tree_len, root_pos] = True
+                    mask[0, past_len:past_len + tree_len, :prompt_len] = ATTEND
+                mask[0, past_len:past_len + tree_len, root_pos] = ATTEND
 
+            # Among the tree tokens themselves
             if tree_len > 0:
                 tree_mask = build_tree_attention_mask_with_root(tree_len, parent_indices, device)
                 mask[0, past_len:past_len + tree_len, past_len:past_len + tree_len] = tree_mask
 
+            # Padding
             if tree_len < max_tree_size:
                 total_padded_len = past_len + max_tree_size
-                padded_mask = torch.zeros(1, total_padded_len, total_padded_len, dtype=torch.bool, device=device)
+                padded_mask = torch.full(
+                    (1, total_padded_len, total_padded_len), MASKED, dtype=torch.float, device=device
+                )
                 padded_mask[0, :total_len, :total_len] = mask[0]
+                # Padding rows attend to the prompt (to avoid NaN)
                 if curr_seq_len > 0:
-                    padded_mask[0, total_len:, :curr_seq_len] = True
+                    padded_mask[0, total_len:, :curr_seq_len] = ATTEND
                 mask = padded_mask
 
         else:
             # ============ Generation phase ============
-            # total length = cache + this round's inputs
             total_len = cache_len + inputs_len
+            mask = torch.full((1, inputs_len, total_len), MASKED, dtype=torch.float, device=device)
 
-            mask = torch.zeros(1, inputs_len, total_len, dtype=torch.bool, device=device)
-
-            # Compute the valid positions in the cache
-            cache_valid_mask = _compute_single_cache_valid_mask(
+            # Compute the valid positions in the cache (bool), then map them to scores
+            cache_valid_bool = _compute_single_cache_valid_mask(
                 kv_cache_position_ids[i], cache_len, device
-            )
-
+            )  # shape: (cache_len,), dtype: bool
+            cache_scores = torch.where(cache_valid_bool, ATTEND, MASKED)  # float scores
 
             # 1. Root attends to cache + itself
-            mask[0, 0, :cache_len] = cache_valid_mask
-            mask[0, 0, cache_len] = True  # root attends to itself
+            mask[0, 0, :cache_len] = cache_scores
+            mask[0, 0, cache_len] = ATTEND  # root attends to itself
 
             # 2. Tree tokens attend to cache + root
             if tree_len > 0:
-                mask[0, 1:inputs_len, :cache_len] = cache_valid_mask.unsqueeze(0).expand(tree_len, cache_len)
-                mask[0, 1:inputs_len, cache_len] = True  # tree tokens attend to root
+                mask[0, 1:inputs_len, :cache_len] = cache_scores.unsqueeze(0).expand(tree_len, cache_len)
+                mask[0, 1:inputs_len, cache_len] = ATTEND  # tree tokens attend to root
 
             # 3. Among the tree tokens themselves
             if tree_len > 0:
@@ -297,10 +308,12 @@ def prepare_incremental_tree_batch(
             if inputs_len < max_inputs_len:
                 pad_len = max_inputs_len - inputs_len
                 total_padded_len = cache_len + max_inputs_len
-                padded_mask = torch.zeros(1, max_inputs_len, total_padded_len, dtype=torch.bool, device=device)
+                padded_mask = torch.full(
+                    (1, max_inputs_len, total_padded_len), MASKED, dtype=torch.float, device=device
+                )
                 padded_mask[0, :inputs_len, :total_len] = mask[0]
                 # Padding rows attend to the cache (to avoid NaN)
-                padded_mask[0, inputs_len:, :cache_len] = cache_valid_mask.unsqueeze(0).expand(pad_len, cache_len)
+                padded_mask[0, inputs_len:, :cache_len] = cache_scores.unsqueeze(0).expand(pad_len, cache_len)
                 mask = padded_mask
 
         batch_attention_masks.append(mask)
@@ -358,16 +371,17 @@ def build_tree_attention_mask_with_root(
 ) -> torch.Tensor:
     """
     Build the attention mask among the tree tokens (root not included).
+    Returns a float score mask directly: 0.0 means attendable, -65504.0 means masked.
     """
-    mask = torch.zeros(tree_len, tree_len, dtype=torch.bool, device=device)
+    mask = torch.full((tree_len, tree_len), -65504.0, dtype=torch.float, device=device)
 
     for i in range(tree_len):
-        mask[i, i] = True
+        mask[i, i] = 0.0
         current = i
         while current >= 0:
             parent = parent_indices[current]
             if parent >= 0:
-                mask[i, parent] = True
+                mask[i, parent] = 0.0
                 current = parent
             else:
                 break
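
Note on the new masking convention: the masks above are no longer boolean; they hold additive float scores that are summed with the raw attention logits, so a masked position receives a very negative logit and effectively vanishes after softmax. Below is a minimal, self-contained sketch of the idea (illustrative only, not code from this repository; toy_tree_score_mask and the example parent_indices are made up, though the 0.0 / -65504.0 sentinel values match the diff):

import torch

MASKED, ATTEND = -65504.0, 0.0  # same sentinel values as in the diff above

def toy_tree_score_mask(parent_indices, device="cpu"):
    """Illustrative ancestor-path mask: token i may attend to itself and to
    every ancestor reachable through parent_indices (-1 marks a root child)."""
    n = len(parent_indices)
    mask = torch.full((n, n), MASKED, dtype=torch.float, device=device)
    for i in range(n):
        mask[i, i] = ATTEND
        current = i
        while parent_indices[current] >= 0:
            mask[i, parent_indices[current]] = ATTEND
            current = parent_indices[current]
    return mask

# Tiny draft tree: node 0 hangs off the root; nodes 1 and 2 are children of node 0.
score_mask = toy_tree_score_mask([-1, 0, 0])
print(score_mask)

# The mask is simply added to the raw attention logits; masked entries end up
# with ~0 probability after softmax, so no bool-to-score conversion is needed.
logits = torch.randn(3, 3)
print(torch.softmax(logits + score_mask, dim=-1))

Because the mask is already in score form, downstream consumers can add it to the logits directly instead of converting a boolean mask first, which is what the backend.py change further below relies on.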

src/bloombee/models/llama/speculative_model.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def _sample_with_session(
         has_printed_first_reach = False  # make sure this is printed only once
         sample_finish_times = [None] * batch_size
         sample_finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
-        while not finished and (seq_lengths - initial_seq_lengths).min().item() < max_new_tokens:
+        while not finished and (seq_lengths - initial_seq_lengths).max().item() < max_new_tokens:
             # 1. Build speculative trees using SSM - pass in seq_lengths
             t1 = time.perf_counter()
             spec_trees = drafter.build_trees_parallel(
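
The one-line change above flips the loop's stopping rule: with .min() the batch kept decoding until the slowest sample had produced max_new_tokens new tokens, whereas with .max() the loop exits as soon as the fastest sample reaches the budget (per-sample completion is presumably still tracked via sample_finished). A toy illustration with made-up numbers:

import torch

generated = torch.tensor([5, 2, 7])   # hypothetical new-token counts per sample
max_new_tokens = 6

print(generated.min().item() < max_new_tokens)  # True  -> old condition: keep decoding
print(generated.max().item() < max_new_tokens)  # False -> new condition: exit the loop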

src/bloombee/server/backend.py

Lines changed: 25 additions & 7 deletions
@@ -141,7 +141,7 @@ def __init__(
             BatchTensorDescriptor((), dtype=self.dtype),
             BatchTensorDescriptor((), dtype=torch.int64),
             BatchTensorDescriptor(
-                1, 64, 64, dtype=self.dtype
+                1, 64, 64, dtype=torch.float
             ),  # tree_attention_mask
             BatchTensorDescriptor(
                 128, dtype=torch.int64
@@ -284,7 +284,7 @@ def inference_step( # Each block will execute once
 
         self._ensure_model_on_device()
 
-        # t0 = time.perf_counter()
+        t0 = time.perf_counter()
         with self.cache_manager.use_cache(
             *inference_info.cache_handles  # Use cache to reduce memory requirements
         ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):  # Use adapter for inference
@@ -319,6 +319,8 @@ def _flag_to_bool(value) -> bool:
                         f"micro_batch_size={inference_info.micro_batch_size}, "
                         f"full_batch_size={inference_info.full_batch_size}")
 
+            t1 = time.perf_counter()
+
             if kv_cache_position_ids is not None and kv_cache_position_ids.numel() > 0:
                 k_pkv, v_pkv, cache_len = self.cache_manager.select_cache_without_reorder(
                     kv_cache_position_ids,
@@ -339,8 +341,8 @@ def _flag_to_bool(value) -> bool:
                 )
                 cache_len = k_pkv.shape[2] if k_pkv is not None else 0
 
-            # t2 = time.perf_counter()
-            # logger.info(f"inference_step: cache reorder (if needed) and selection took {t2 - t1:.4f} seconds")
+            t2 = time.perf_counter()
+            logger.info(f"inference_step: cache reorder (if needed) and selection took {t2 - t1:.4f} seconds")
 
             layer_past = (k_pkv, v_pkv) if k_pkv is not None else None
 
@@ -349,11 +351,15 @@ def _flag_to_bool(value) -> bool:
 
             if self._is_spec_decoding:
                 full_mask = inference_info.tree_attention_mask.to(device)
-                attention_mask = self.convert_mask_to_scores(full_mask) if full_mask is not None else None
+                attention_mask = full_mask
                 if full_mask == None:
                     full_mask = self._create_causal_attention_mask(batch_size, (seq_len + cache_len), cache_len, hidden_states.device)
                     attention_mask = self.convert_mask_to_scores(full_mask) if full_mask is not None else None
 
+            t3 = time.perf_counter()
+            logger.info(f"convert_mask_to_scores took {t3 - t2:.4f} seconds")
+
+
             for offset in range(0, seq_len, max_chunk_length):  # Iterate through sequence to process hidden states in chunks only run offset=0
                 hidden_states_chunk = hidden_states[:, offset : offset + max_chunk_length, :]  # Get current hidden states chunk
                 # print('transformer backend inference step() offset ', offset )
@@ -378,6 +384,9 @@ def _flag_to_bool(value) -> bool:
                         target_seq_len=seq_len)
                 else:
                     rotary_position_ids = None
+
+                t4 = time.perf_counter()
+                logger.info(f"_create_tree_position_ids_with_invalid_cache took {t4 - t3:.4f} seconds")
 
                 try:
                     # Fixed: Properly handle forward method return values with position_ids
@@ -391,8 +400,8 @@ def _flag_to_bool(value) -> bool:
                         rotary_position_ids=rotary_position_ids,
                     )
 
-                    # t5 = time.perf_counter()
-                    # logger.info(f"inference_step: module.forward call took {t5 - t4:.4f} seconds")
+                    t5 = time.perf_counter()
+                    logger.info(f"inference_step: module.forward call took {t5 - t4:.4f} seconds")
 
                     if forward_result is None:
                         logger.info(f" ERROR: module.forward returned None!")
@@ -438,6 +447,10 @@ def _flag_to_bool(value) -> bool:
                     batch_offset=inference_info.batch_offset,
                     full_batch_size=inference_info.full_batch_size,
                     micro_batch_size=inference_info.micro_batch_size,)
+
+            t6 = time.perf_counter()
+            logger.info(f"update_cache_and_async_reorder took {t6 - t5:.4f} seconds")
+
 
             keep_indices = self._normalize_keep_indices(
                 inference_info.keep_indices,
@@ -471,6 +484,11 @@ def _flag_to_bool(value) -> bool:
                 norm_hidden_states = self.module.rms_norm(output_hidden_states)
                 keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
                 keep_indices = keep_indices
+                t7 = time.perf_counter()
+                logger.info(f"prune_draft_tree took {t7 - t6:.4f} seconds")
+
+
 
             if not training_mode and self._is_spec_decoding and self._is_last_block:
                 original_hidden_states = output_hidden_states
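
Most of the backend.py changes re-enable the previously commented-out timing probes: a time.perf_counter() checkpoint is recorded after each stage (cache selection, mask handling, position-id construction, module.forward, cache update, draft-tree pruning) and the per-stage delta is logged. A minimal sketch of that pattern, with hypothetical stage names and a simulate_stage stand-in that are not part of the repository:

import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("inference_step")

def simulate_stage(seconds: float) -> None:
    """Stand-in for a real stage such as cache selection or module.forward."""
    time.sleep(seconds)

# Record a perf_counter checkpoint after each stage and log the delta to the
# previous checkpoint, mirroring the t0..t7 checkpoints added in the diff above.
t_prev = time.perf_counter()
for stage_name, cost in [("cache selection", 0.010),
                         ("module.forward", 0.020),
                         ("prune_draft_tree", 0.005)]:
    simulate_stage(cost)
    t_now = time.perf_counter()
    logger.info(f"{stage_name} took {t_now - t_prev:.4f} seconds")
    t_prev = t_now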