Skip to content

Commit 5cbbe85

Browse files
author
Xu Xiong
committed
evaluate_sd_mb
1 parent 9e97c90 commit 5cbbe85

9 files changed

Lines changed: 39 additions & 39 deletions

File tree

benchmarks/benchmark_speculative_decoding.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,25 +74,9 @@ def benchmark_inference(process_idx, args, result_pipe):
7474
indices = random.sample(range(len(dataset)), batch_size)
7575
sampled = dataset.select(indices)
7676
test_prompts = []
77-
# for item in sampled:
78-
# test_prompts.append(item["instruction"])
79-
80-
base_prompt = (
81-
"Quantum mechanics explains the behavior of particles at very small scales. "
82-
"Neural networks learn patterns by adjusting weights through backpropagation. "
83-
"Distributed systems require robust consensus mechanisms to maintain state. "
84-
"Optimization algorithms like gradient descent are fundamental to machine learning. "
85-
"Transformer architectures rely on attention mechanisms to capture dependencies. "
86-
"Reinforcement learning optimizes actions by maximizing cumulative rewards. "
87-
"Bayesian inference updates beliefs based on observed evidence and prior knowledge. "
88-
"Convex optimization problems guarantee global minima under certain conditions. "
89-
"Signal processing extracts meaningful information from noisy measurements. "
90-
)
91-
prompts = [
92-
f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
93-
for i in range(batch_size)
94-
]
95-
test_prompts = prompts
77+
for item in sampled:
78+
test_prompts.append(item["instruction"])
79+
9680

9781
tokenizer.pad_token = tokenizer.eos_token
9882
input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]

eval_indices.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[[41905, 7296, 1639, 48598, 18024, 16049, 14628, 9144, 48265, 6717, 44348, 48540, 35741, 5697, 38698, 27651, 2082, 1952, 6140, 14328, 15247, 33118, 39453, 1739, 36781, 13031, 46925, 42590, 45962, 35713, 27493, 14446], [29439, 38618, 18231, 425, 49729, 10463, 45753, 27696, 22298, 18210, 10189, 14110, 50036, 22059, 6698, 6078, 24898, 6338, 23526, 22541, 39565, 17335, 2847, 47823, 30108, 35142, 8180, 24807, 5164, 36178, 19213, 41198], [40535, 23700, 37837, 12601, 46174, 4558, 3003, 43336, 14935, 50663, 18965, 5229, 15256, 6619, 24911, 18217, 29714, 41660, 23909, 10659, 24260, 23283, 13730, 43920, 17496, 45994, 44796, 42469, 4679, 39920, 41613, 11215], [35005, 47784, 16043, 10708, 30294, 24867, 17691, 41943, 45099, 36500, 14392, 44866, 21252, 50352, 50855, 3665, 15010, 2103, 20673, 26290, 17546, 4337, 13826, 37170, 47049, 20622, 13934, 42954, 32717, 25928, 42129, 30071], [9363, 17359, 9150, 16162, 48823, 36789, 35322, 17219, 48956, 38311, 28077, 38242, 26175, 23723, 14373, 9065, 33392, 32343, 5957, 49530, 3087, 7185, 10016, 41120, 10484, 51909, 44596, 27666, 39086, 4163, 25216, 25009], [39052, 30674, 34676, 16476, 36256, 752, 44583, 47233, 7507, 44676, 35190, 49209, 17486, 50370, 42006, 22293, 7310, 19234, 28492, 10365, 29735, 212, 47323, 47164, 17261, 32806, 49935, 11708, 33271, 6973, 40979, 19558], [41874, 33270, 39909, 13035, 10016, 24504, 49971, 10587, 35348, 51028, 34757, 37, 39252, 21243, 32021, 1276, 7331, 23788, 20153, 15692, 3796, 15785, 37182, 5161, 5613, 47966, 31849, 4535, 49846, 34911, 50189, 8241], [8414, 43237, 31148, 36031, 10821, 17370, 34581, 39753, 27730, 13880, 35343, 49497, 47836, 45211, 13182, 46723, 20428, 26148, 44019, 42590, 24472, 28711, 33919, 29588, 7930, 16246, 14725, 4196, 22156, 1378, 38555, 36301], [15080, 38564, 14432, 471, 4652, 46389, 41359, 3858, 15003, 4417, 2058, 21654, 4643, 33695, 15597, 18250, 43842, 31812, 14040, 35339, 8671, 47405, 37423, 37762, 30976, 15925, 51420, 30996, 26677, 12478, 6181, 6352], [43187, 28249, 23219, 
27759, 26941, 30606, 47780, 3550, 44129, 42824, 42348, 6449, 3972, 26386, 47724, 22236, 7161, 16295, 12556, 12465, 35146, 29400, 9186, 27648, 12025, 18254, 30318, 16371, 4940, 29041, 36066, 6416]]

src/bloombee/models/llama/speculative_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def generate(
3535
logits_processor: Optional[LogitsProcessorList] = None,
3636
stopping_criteria: Optional[StoppingCriteriaList] = None,
3737
streamer: Optional["BaseStreamer"] = None,
38-
beam_width: int = 1,
39-
max_tree_depth: int = 4,
38+
beam_width: int = 2,
39+
max_tree_depth: int = 3,
4040
use_kv_cache: bool = True,
4141
kv_cache_window: int = 2048,
4242
max_new_tokens: int = 128,

src/bloombee/server/backend.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ def _flag_to_bool(value) -> bool:
369369
position_ids = self._position_ids_cache[cache_key] + (cache_len + offset)
370370
if self._is_spec_decoding:
371371
rotary_position_ids = self._create_tree_position_ids_with_invalid_cache(
372-
width=1,
373-
depth=4,
372+
width=2,
373+
depth=3,
374374
prefill_length=inference_info.prefill_length - 1,
375375
kv_cache_position_ids=kv_cache_position_ids,
376376
batch_offset=inference_info.batch_offset,
@@ -472,15 +472,7 @@ def _flag_to_bool(value) -> bool:
472472
keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
473473
keep_indices = keep_indices
474474

475-
if not training_mode and self._is_spec_decoding and self._is_last_block:
476-
original_hidden_states = output_hidden_states
477-
batch_size, seq_len, hidden_size = original_hidden_states.shape
478-
device = original_hidden_states.device
479-
valid_mask = keep_indices >= 0
480-
batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
481-
valid_hidden_states = original_hidden_states[batch_idx[valid_mask], keep_indices[valid_mask], :]
482-
output_hidden_states = valid_hidden_states.unsqueeze(0)
483-
475+
484476
self._last_keep_indices = keep_indices + cache_len
485477
return (output_hidden_states, keep_indices) # Return output hidden states
486478

src/bloombee/server/block_functions.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1571,8 +1571,21 @@ async def process_microbatch_merged(mb_idx: int, mb_start: int, mb_end: int, tot
15711571
micro_hidden_list = [r[0] for r in results]
15721572
micro_keep_list = [r[1] for r in results]
15731573

1574+
1575+
padded_keep_list = []
1576+
for keep_indices in micro_keep_list:
1577+
current_len = keep_indices.shape[1] # size along dim 1 (currently 1)
1578+
pad_size = length_increment - current_len
1579+
if pad_size > 0:
1580+
# pad shape: [batch_size, pad_size]
1581+
pad_shape = (keep_indices.shape[0], pad_size)
1582+
padding = torch.full(pad_shape, -1, dtype=keep_indices.dtype, device=keep_indices.device)
1583+
keep_indices = torch.cat([keep_indices, padding], dim=1) # → [16, length_increment]
1584+
padded_keep_list.append(keep_indices)
1585+
1586+
15741587
hidden_states = merge_microbatch_outputs(micro_hidden_list, dim=0)
1575-
keep_indices = merge_microbatch_outputs(micro_keep_list, dim=0)
1588+
keep_indices = merge_microbatch_outputs(padded_keep_list, dim=0)
15761589

15771590
# Calculate overlap statistics
15781591
total_pipeline_time = (pipeline_end_time - pipeline_start_time) * 1000 # ms
@@ -1860,6 +1873,16 @@ async def process_microbatch(mb_idx: int, mb_start: int, mb_end: int):
18601873
dtype=torch.int64,
18611874
device=hidden_states.device
18621875
).unsqueeze(0).expand(hidden_states.shape[0], -1)
1876+
1877+
if is_spec_dec:
1878+
original_hidden_states = hidden_states
1879+
batch_size, seq_len, hidden_size = original_hidden_states.shape
1880+
device = original_hidden_states.device
1881+
valid_mask = keep_indices >= 0
1882+
batch_idx = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(keep_indices)
1883+
valid_hidden_states = original_hidden_states[batch_idx[valid_mask], keep_indices[valid_mask], :]
1884+
hidden_states = valid_hidden_states.unsqueeze(0)
1885+
18631886

18641887
serialize_start = perf_counter()
18651888
need_pruning_next = torch.tensor(0)

src/bloombee/server/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def __init__(
292292

293293
self.policy = Policy(
294294
gpu_batch_size, 1, # gpu_batch_size controls GPU cache allocation
295-
50, 50, # w_gpu_percent, w_cpu_percent
295+
100, 0, # w_gpu_percent, w_cpu_percent
296296
100, 0, # cache_gpu_percent=100% (GPU cache only holds micro_batch_size slots)
297297
100, 0, # act_gpu_percent, act_cpu_percent (activations on GPU)
298298
overlap=False, sep_layer=True, pin_weight=True,
@@ -324,13 +324,13 @@ def __init__(
324324
self.weight_home = array_1d(self.num_blocks, ValueHolder)
325325
self.path = os.path.join(tempfile.gettempdir(), 'data', 'llama_weights')
326326

327-
hidden_size = 4096
327+
hidden_size = 6656
328328
vocab_size = 32000
329329

330330
# Create configuration
331331
config = PruningConfig(
332332
method=PruningMethod.ADAPTIVE_NEURAL,
333-
neural_threshold=0.9,
333+
neural_threshold=0.5,
334334
simple_threshold=0.1
335335
)
336336

src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self.lm_head = MidLMHead(hidden_size=hidden_size, vocab_size=vocab_size).to("cuda")
5656
lm_head_weights_path = hf_hub_download(
5757
repo_id="xxiong59/lm-head-for-speculative-pruning",
58-
filename="lm_head_weights_15.pt",
58+
filename="lm_head_llama30B-15.pt",
5959
cache_dir="./cache"
6060
)
6161
lm_head_checkpoint = torch.load(lm_head_weights_path, map_location="cuda")

src/bloombee/server/speculative_pruner/pruner_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
self.iteration = 0
4343
self.middle_states = None
4444

45-
train_lm_head_mode = True
45+
train_lm_head_mode = False
4646
self.lm_head_trainer = LM_head_trainer(hidden_size, vocab_size, device, config) if train_lm_head_mode else None
4747

4848
def switch_method(self, method: Union[str, PruningMethod], keep_stats: bool = False):

src/bloombee/utils/microbatch_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
# Default values
2626
# Micro-batch size for pipeline overlap. Each micro-batch writes to its own slice of the KV cache.
27-
DEFAULT_MICRO_BATCH_SIZE = 0 # Default micro-batch size for pipeline overlap
27+
DEFAULT_MICRO_BATCH_SIZE = 16 # Default micro-batch size for pipeline overlap
2828

2929

3030
def _is_microbatch_flag_enabled() -> bool:

0 commit comments

Comments
 (0)