Skip to content

Commit c12ec40

Browse files
author
Xu Xiong
committed
test max throughput
1 parent 821cbd6 commit c12ec40

6 files changed

Lines changed: 37 additions & 21 deletions

File tree

benchmarks/benchmark_speculative_decoding.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -127,31 +127,47 @@ def benchmark_inference(process_idx, args, result_pipe):
127127
elapsed_time = perf_counter() - start_time
128128

129129
original_output_ids = result
130+
131+
# Strip the input portion; keep only the generated tokens
132+
input_len = input_ids.shape[1]
133+
134+
# padding token ids are 0, 1, 2
135+
PADDING_IDS = {0, 1, 2}
130136

131-
total_generated = 128 * batch_size
137+
input_lengths = [
138+
sum(1 for tok in sample.tolist() if tok not in PADDING_IDS)
139+
for sample in input_ids
140+
]
141+
142+
# Count the tokens actually generated per sample (excluding padding)
143+
per_sample_generated = [
144+
sum(1 for tok in result[i, input_lengths[i]:].tolist() if tok not in PADDING_IDS)
145+
for i in range(len(input_lengths))
146+
]
147+
148+
total_generated = sum(per_sample_generated)
132149
throughput = total_generated / elapsed_time
133-
150+
134151
logger.info(f"Group {group_idx} | "
135152
f"Total time: {elapsed_time:.4f}s | "
136153
f"Throughput: {throughput:.2f} tokens/s | "
137-
f"Generated tokens per sample: {total_generated}")
154+
f"Per-sample generated: {per_sample_generated} | "
155+
f"Total generated: {total_generated}")
138156

139-
# Save results
140-
result_label = "pruned" if getattr(args, 'pruning', False) else "unpruned"
141-
output_file = f"results_{result_label}_group_{group_idx}.json"
157+
logger.info(f"input_ids: {input_ids}")
158+
logger.info(f"original_output_ids: {original_output_ids}")
159+
142160
result_data = {
143161
"group_idx": group_idx,
144162
"pruning": getattr(args, 'pruning', False),
145163
"throughput": throughput,
146164
"elapsed_time": elapsed_time,
147165
"total_generated": total_generated,
148166
"generated_tokens_nums": total_generated,
167+
"per_sample_generated": per_sample_generated,
149168
"batch_size": batch_size,
150169
"max_new_tokens": max_new_tokens,
151170
}
152-
with open(output_file, "w") as f:
153-
json.dump(result_data, f, indent=2)
154-
logger.info(f"Results saved to {output_file}")
155171

156172
result_pipe.send(throughput)
157173

benchmarks/promps_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
batch_size = 32
77
num_groups = 10
88

9-
random.seed(42)
9+
random.seed(1)
1010
groups = []
1111
for i in range(num_groups):
1212
indices = random.sample(range(len(dataset)), batch_size)

eval_indices.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
[[41905, 7296, 1639, 48598, 18024, 16049, 14628, 9144, 48265, 6717, 44348, 48540, 35741, 5697, 38698, 27651, 2082, 1952, 6140, 14328, 15247, 33118, 39453, 1739, 36781, 13031, 46925, 42590, 45962, 35713, 27493, 14446], [29439, 38618, 18231, 425, 49729, 10463, 45753, 27696, 22298, 18210, 10189, 14110, 50036, 22059, 6698, 6078, 24898, 6338, 23526, 22541, 39565, 17335, 2847, 47823, 30108, 35142, 8180, 24807, 5164, 36178, 19213, 41198], [40535, 23700, 37837, 12601, 46174, 4558, 3003, 43336, 14935, 50663, 18965, 5229, 15256, 6619, 24911, 18217, 29714, 41660, 23909, 10659, 24260, 23283, 13730, 43920, 17496, 45994, 44796, 42469, 4679, 39920, 41613, 11215], [35005, 47784, 16043, 10708, 30294, 24867, 17691, 41943, 45099, 36500, 14392, 44866, 21252, 50352, 50855, 3665, 15010, 2103, 20673, 26290, 17546, 4337, 13826, 37170, 47049, 20622, 13934, 42954, 32717, 25928, 42129, 30071], [9363, 17359, 9150, 16162, 48823, 36789, 35322, 17219, 48956, 38311, 28077, 38242, 26175, 23723, 14373, 9065, 33392, 32343, 5957, 49530, 3087, 7185, 10016, 41120, 10484, 51909, 44596, 27666, 39086, 4163, 25216, 25009], [39052, 30674, 34676, 16476, 36256, 752, 44583, 47233, 7507, 44676, 35190, 49209, 17486, 50370, 42006, 22293, 7310, 19234, 28492, 10365, 29735, 212, 47323, 47164, 17261, 32806, 49935, 11708, 33271, 6973, 40979, 19558], [41874, 33270, 39909, 13035, 10016, 24504, 49971, 10587, 35348, 51028, 34757, 37, 39252, 21243, 32021, 1276, 7331, 23788, 20153, 15692, 3796, 15785, 37182, 5161, 5613, 47966, 31849, 4535, 49846, 34911, 50189, 8241], [8414, 43237, 31148, 36031, 10821, 17370, 34581, 39753, 27730, 13880, 35343, 49497, 47836, 45211, 13182, 46723, 20428, 26148, 44019, 42590, 24472, 28711, 33919, 29588, 7930, 16246, 14725, 4196, 22156, 1378, 38555, 36301], [15080, 38564, 14432, 471, 4652, 46389, 41359, 3858, 15003, 4417, 2058, 21654, 4643, 33695, 15597, 18250, 43842, 31812, 14040, 35339, 8671, 47405, 37423, 37762, 30976, 15925, 51420, 30996, 26677, 12478, 6181, 6352], [43187, 28249, 23219, 
27759, 26941, 30606, 47780, 3550, 44129, 42824, 42348, 6449, 3972, 26386, 47724, 22236, 7161, 16295, 12556, 12465, 35146, 29400, 9186, 27648, 12025, 18254, 30318, 16371, 4940, 29041, 36066, 6416]]
1+
[[8805, 37303, 50054, 4135, 16716, 7727, 32468, 49870, 29457, 30949, 42702, 24878, 51689, 13759, 6151, 31972, 1857, 25546, 28361, 39809, 49956, 50276, 138, 45602, 29188, 17454, 47286, 14992, 38741, 6699, 20803, 2004], [1462, 1667, 42568, 35482, 603, 24982, 44989, 14195, 27663, 47569, 1903, 34578, 14528, 50049, 28697, 32493, 36232, 15275, 22655, 15130, 44357, 14338, 49869, 30120, 18991, 1408, 27274, 36467, 42093, 6553, 12183, 41245], [47424, 19424, 7922, 48702, 21803, 47283, 46608, 32820, 27663, 33273, 43929, 12441, 19881, 18622, 38507, 32726, 33114, 25778, 38600, 2262, 31472, 15908, 48741, 26495, 27152, 43564, 11338, 24059, 35966, 46074, 50845, 44203], [48379, 24556, 5666, 28767, 43500, 33320, 7073, 51016, 10728, 34140, 25772, 24282, 32092, 48022, 1938, 30757, 2849, 20219, 46096, 40292, 38874, 37891, 25794, 42412, 11164, 11048, 32914, 14872, 806, 50497, 13075, 35364], [35935, 15215, 26506, 33670, 22532, 37866, 23152, 30089, 17647, 43202, 35913, 39907, 47801, 374, 25145, 51357, 48529, 33587, 8470, 33992, 50947, 36789, 13466, 27924, 3678, 31529, 23903, 37355, 36333, 13096, 33077, 27092], [31780, 23382, 27159, 22680, 103, 35289, 35396, 40861, 51540, 40137, 21701, 30025, 39312, 1833, 15047, 41639, 11613, 36094, 38303, 11847, 6003, 36112, 16730, 2127, 44113, 4617, 5454, 1093, 29687, 954, 49423, 49518], [18428, 16355, 17605, 7175, 40947, 12098, 22572, 19024, 4555, 10975, 10461, 16725, 34562, 11019, 43034, 17885, 42480, 46634, 19299, 29799, 46047, 21102, 32538, 31049, 7483, 1548, 20447, 25333, 22501, 27585, 12323, 16935], [7127, 16610, 47851, 33430, 13702, 39691, 28288, 1364, 14770, 1170, 26038, 9598, 2315, 47109, 10500, 29207, 46177, 33181, 44444, 27961, 35697, 14457, 41338, 45550, 33855, 29546, 14627, 34334, 42500, 2011, 25880, 44230], [37738, 21053, 43242, 41349, 27937, 3852, 48329, 19569, 8236, 13902, 3109, 20079, 4635, 5009, 20339, 19521, 48748, 10368, 27274, 37023, 16538, 8545, 555, 36747, 2484, 38704, 14260, 37373, 30202, 11240, 51119, 46138], [40826, 33349, 2452, 
24770, 13133, 22736, 6489, 13484, 37577, 44181, 28373, 38758, 12721, 32266, 6843, 43644, 25563, 19403, 33037, 32754, 1127, 21321, 40116, 26366, 18438, 1185, 10286, 13163, 21478, 36919, 51295, 8856]]

src/bloombee/models/llama/speculative_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def generate(
5555
# Keep the argument for API compatibility, but ignore runtime overrides.
5656
if "session_max_length" in model_kwargs:
5757
model_kwargs.pop("session_max_length", None)
58-
session_max_length = 350
58+
session_max_length = 170
5959
logger.info("Speculative session_max_length=%s (hardcoded)", session_max_length)
6060

6161
# Use inference session for proper distributed caching
@@ -134,7 +134,7 @@ def _sample_with_session(
134134
has_printed_first_reach = False # ensure this is printed only once
135135
sample_finish_times = [None] * batch_size
136136
sample_finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
137-
while not finished and (seq_lengths - initial_seq_lengths).min().item() < max_new_tokens:
137+
while not finished and (seq_lengths - initial_seq_lengths).max().item() < max_new_tokens:
138138
# 1. Build speculative trees using SSM - 传入 seq_lengths
139139
t1 = time.perf_counter()
140140
spec_trees = drafter.build_trees_parallel(

src/bloombee/server/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,8 +481,8 @@ def _flag_to_bool(value) -> bool:
481481

482482
is_prefill = kv_cache_position_ids is None or kv_cache_position_ids.numel() == 0
483483
if not training_mode and self._is_spec_decoding and self._need_pruning and self._is_last_block and not is_prefill:
484-
norm_hidden_states = self.module.rms_norm(output_hidden_states)
485-
keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
484+
# norm_hidden_states = self.module.rms_norm(output_hidden_states)
485+
# keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
486486
keep_indices = keep_indices
487487
# t7 = time.perf_counter()
488488
# logger.info(f"prune_draft_tree took {t7 - t6:.4f} seconds")

src/bloombee/server/handler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -718,12 +718,12 @@ def _flag_to_bool(value: Any) -> bool:
718718
# print_time_now('')
719719
step_+=1 ###
720720
can_push_case_time=perf_counter() ###
721-
normalized_outputs = self._normalize_serialized_tensors(output_tensors)
722-
next_tensors = normalized_outputs + list(request.tensors[6:])
723-
push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
724-
NETWORK_SPEED_BYTES_PER_SEC = 62.5 * 1024 * 1024
725-
transfer_delay = push_tensor_bytes / NETWORK_SPEED_BYTES_PER_SEC + 0.025
726-
await asyncio.sleep(transfer_delay)
721+
# normalized_outputs = self._normalize_serialized_tensors(output_tensors)
722+
# next_tensors = normalized_outputs + list(request.tensors[6:])
723+
# push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
724+
# NETWORK_SPEED_BYTES_PER_SEC = 62.5 * 1024 * 1024
725+
# transfer_delay = push_tensor_bytes / NETWORK_SPEED_BYTES_PER_SEC + 0.025
726+
# await asyncio.sleep(transfer_delay)
727727
start_ExpertResponse_time=perf_counter() ###
728728
push_schedule_ms = (start_ExpertResponse_time - can_push_case_time) * 1000.0
729729
push_time.append(push_schedule_ms) ###

0 commit comments

Comments
 (0)