Skip to content

Commit 5efbb83

Browse files
authored
[habana_main]enable padding_aware_scheduler for speculative decoding (#1264)
Performance numbers are at https://jira.habana-labs.com/browse/SW-211730 Tested with MTP deepseek R1 with batch_size == 4, see 1.6x improvement w/wo MTP in throughput. ``` Method: no_spec Acceptance rate: [0, 0] Method: mtp Acceptance rate: [2180, 1918] ``` with batch_size == 32, see 1.5x improvement ``` Method: no_spec Acceptance rate: [0, 0] Method: mtp Acceptance rate: [17408, 15371] ``` with batch_size == 128, see 1.35x improvement ``` Method: no_spec Acceptance rate:  [0, 0] Method: mtp Acceptance rate:  [69988, 61143] ``` with batch_size == 256, see no obvious improvement (which is reasonable because we hit maximum batch_size) ``` method: no_spec Acceptance rate: [0, 0] Method: mtp Acceptance rate: [139286, 122973] ``` Signed-off-by: Chendi Xue <[email protected]>
1 parent 1b40abb commit 5efbb83

File tree

5 files changed

+8
-13
lines changed

5 files changed

+8
-13
lines changed

tests/spec_decode/e2e/test_eagle_correctness.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@
5252
5353
# Main model
5454
"model_name": MAIN_MODEL,
55-
56-
# schedule
57-
"use_padding_aware_scheduling": False,
5855
}])
5956
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
6057
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,6 @@
5555
5656
# Main model
5757
"model_name": MAIN_MODEL,
58-
59-
# schedule
60-
"use_padding_aware_scheduling": False,
6158
}])
6259
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
6360
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@
5757
5858
# Main model
5959
"model_name": MAIN_MODEL,
60-
61-
# schedule
62-
"use_padding_aware_scheduling": False,
6360
}])
6461
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
6562
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

tests/spec_decode/e2e/test_mtp_correctness.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@
5252
5353
# GPU memory utilization
5454
"gpu_memory_utilization": 0.85,
55-
56-
# scheduler
57-
"use_padding_aware_scheduling": False,
5855
}])
5956
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
6057
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])

vllm/sequence.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1336,7 +1336,14 @@ def prune(self,
13361336
seq_ids = get_all_seq_ids(seq_group_metadata_list)
13371337
if seq_ids != self._seq_ids:
13381338
# Batch contents changed - prune removed sequences.
1339-
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
1339+
if len(seq_ids) < len(self._seq_ids):
1340+
index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
1341+
else:
1342+
# This path is added for use_padding_aware_scheduling
1343+
index = [
1344+
self._seq_ids.index(seq_id)
1345+
if seq_id in self._seq_ids else 0 for seq_id in seq_ids
1346+
]
13401347
self.hidden_states = self.hidden_states[index]
13411348
if self.second_last_token_hidden_states is not None:
13421349
self.second_last_token_hidden_states = self\

0 commit comments

Comments (0)