Skip to content

Commit 58bc37f

Browse files
committed
[Feature] Add Eagle next-token prepare op
1 parent a9bf223 commit 58bc37f

4 files changed

Lines changed: 280 additions & 52 deletions

File tree

setup.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
ext_modules = [
1414
CppExtension(
1515
name="vllm_kunlun._kunlun",
16-
sources=["vllm_kunlun/csrc/utils.cpp"],
16+
sources=[
17+
"vllm_kunlun/csrc/utils.cpp",
18+
"vllm_kunlun/csrc/eagle_prepare_next_token_ids.cpp",
19+
],
1720
include_dirs=[
1821
"vllm_kunlun/csrc",
1922
"/usr/local/cuda/include",
@@ -30,8 +33,9 @@ def run(self):
3033
for ext in self.extensions:
3134
ext_path = self.get_ext_fullpath(ext.name)
3235
file_name = os.path.basename(ext_path)
33-
target_path = os.path.join("vllm_kunlun", file_name)
36+
target_path = os.path.join(ROOT_DIR, "vllm_kunlun", file_name)
3437

38+
os.makedirs(os.path.dirname(target_path), exist_ok=True)
3539
if os.path.exists(target_path):
3640
os.remove(target_path)
3741
shutil.copyfile(ext_path, target_path)

tests/ut/test_eagle_cpp_ops.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
"""
2+
Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
3+
4+
This file is a part of the vllm-kunlun project.
5+
6+
Licensed under the Apache License, Version 2.0 (the "License");
7+
you may not use this file except in compliance with the License.
8+
You may obtain a copy of the License at
9+
10+
http://www.apache.org/licenses/LICENSE-2.0
11+
12+
Unless required by applicable law or agreed to in writing, software
13+
distributed under the License is distributed on an "AS IS" BASIS,
14+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
See the License for the specific language governing permissions and
16+
limitations under the License.
17+
"""
18+
19+
from __future__ import annotations
20+
21+
from types import SimpleNamespace
22+
from unittest.mock import patch
23+
24+
import numpy as np
25+
import torch
26+
27+
import vllm_kunlun._kunlun # noqa: F401
28+
29+
30+
def _reference_prepare_next_token_ids(
31+
sampled_token_ids: torch.Tensor,
32+
discard_request_indices: torch.Tensor,
33+
num_discarded_requests: int,
34+
backup_next_token_ids: torch.Tensor,
35+
vocab_size: int,
36+
) -> tuple[torch.Tensor, torch.Tensor]:
37+
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
38+
if num_discarded_requests > 0:
39+
idx = discard_request_indices[:num_discarded_requests]
40+
if idx.device != valid_sampled_token_ids_gpu.device:
41+
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
42+
if idx.dtype != torch.long:
43+
idx = idx.to(torch.long)
44+
if idx.numel() > 0:
45+
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
46+
47+
max_gen_len = sampled_token_ids.shape[-1]
48+
if max_gen_len == 1:
49+
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
50+
else:
51+
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
52+
valid_sampled_token_ids_gpu < vocab_size
53+
)
54+
55+
valid_sampled_tokens_count = valid_mask.sum(dim=1)
56+
last_valid_indices = valid_sampled_tokens_count - 1
57+
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
58+
selected_tokens = torch.gather(
59+
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
60+
).squeeze(1)
61+
next_token_ids = torch.where(
62+
last_valid_indices != -1,
63+
selected_tokens,
64+
backup_next_token_ids[: valid_sampled_token_ids_gpu.shape[0]],
65+
)
66+
return next_token_ids, valid_sampled_tokens_count
67+
68+
69+
def test_eagle_prepare_next_token_ids_without_discards_matches_reference():
70+
sampled = torch.tensor([[1, 3, 4], [5, 8, 9]], dtype=torch.int64)
71+
discard = torch.tensor([1, 0], dtype=torch.int32)
72+
backup = torch.tensor([10, 11], dtype=torch.int64)
73+
74+
result = torch.ops._C.eagle_prepare_next_token_ids_padded(
75+
sampled, discard, 0, backup, 100
76+
)
77+
78+
expected = _reference_prepare_next_token_ids(sampled, discard, 0, backup, 100)
79+
torch.testing.assert_close(result[0], expected[0])
80+
torch.testing.assert_close(result[1], expected[1])
81+
82+
83+
def test_eagle_prepare_next_token_ids_handles_partial_discards():
84+
sampled = torch.tensor([[3, 4], [7, 8], [9, 10]], dtype=torch.int64)
85+
discard = torch.tensor([1, 2, 0], dtype=torch.int32)
86+
backup = torch.tensor([20, 21, 22], dtype=torch.int64)
87+
88+
result = torch.ops._C.eagle_prepare_next_token_ids_padded(
89+
sampled, discard, 2, backup, 100
90+
)
91+
92+
expected = _reference_prepare_next_token_ids(sampled, discard, 2, backup, 100)
93+
torch.testing.assert_close(result[0], expected[0])
94+
torch.testing.assert_close(result[1], expected[1])
95+
96+
97+
def test_eagle_prepare_next_token_ids_handles_all_discards():
98+
sampled = torch.tensor([[3, 4], [7, 8]], dtype=torch.int64)
99+
discard = torch.tensor([0, 1], dtype=torch.int32)
100+
backup = torch.tensor([30, 31], dtype=torch.int64)
101+
102+
result = torch.ops._C.eagle_prepare_next_token_ids_padded(
103+
sampled, discard, 2, backup, 100
104+
)
105+
106+
expected = _reference_prepare_next_token_ids(sampled, discard, 2, backup, 100)
107+
torch.testing.assert_close(result[0], expected[0])
108+
torch.testing.assert_close(result[1], expected[1])
109+
110+
111+
def test_eagle_prepare_next_token_ids_treats_single_token_rows_as_valid():
112+
sampled = torch.tensor([[-1], [999]], dtype=torch.int64)
113+
discard = torch.tensor([], dtype=torch.int32)
114+
backup = torch.tensor([40, 41], dtype=torch.int64)
115+
116+
result = torch.ops._C.eagle_prepare_next_token_ids_padded(
117+
sampled, discard, 0, backup, 100
118+
)
119+
120+
expected = _reference_prepare_next_token_ids(sampled, discard, 0, backup, 100)
121+
torch.testing.assert_close(result[0], expected[0])
122+
torch.testing.assert_close(result[1], expected[1])
123+
124+
125+
def test_eagle_prepare_next_token_ids_filters_invalid_tokens_and_falls_back():
126+
sampled = torch.tensor([[-1, 2, 3], [101, 105, 2], [-1, -1, -1]], dtype=torch.int64)
127+
discard = torch.tensor([2], dtype=torch.int32)
128+
backup = torch.tensor([50, 51, 52], dtype=torch.int64)
129+
130+
result = torch.ops._C.eagle_prepare_next_token_ids_padded(
131+
sampled, discard, 1, backup, 100
132+
)
133+
134+
expected = _reference_prepare_next_token_ids(sampled, discard, 1, backup, 100)
135+
torch.testing.assert_close(result[0], expected[0])
136+
torch.testing.assert_close(result[1], expected[1])
137+
138+
139+
class _BackupNextTokenIds:
140+
def __init__(self, size: int):
141+
self.np = np.zeros(size, dtype=np.int64)
142+
self.gpu = torch.zeros(size, dtype=torch.int64)
143+
144+
def copy_to_gpu(self, size: int) -> None:
145+
self.gpu[:size] = torch.from_numpy(self.np[:size]).to(self.gpu.dtype)
146+
147+
148+
class _Request:
149+
def __init__(self, token_id: int):
150+
self.token_id = token_id
151+
152+
def get_token_id(self, _: int) -> int:
153+
return self.token_id
154+
155+
156+
def test_prepare_next_token_ids_padded_uses_cpp_op():
157+
from vllm_kunlun.v1.sample.spec_decode import eagle as eagle_module
158+
159+
sampled = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int64)
160+
discard = torch.tensor([1], dtype=torch.int32)
161+
expected_next = torch.tensor([3, 77], dtype=torch.int64)
162+
expected_counts = torch.tensor([3, 0], dtype=torch.int64)
163+
164+
proposer = SimpleNamespace(backup_next_token_ids=_BackupNextTokenIds(2))
165+
common_attn_metadata = SimpleNamespace(seq_lens_cpu=torch.tensor([4, 5]))
166+
gpu_input_batch = SimpleNamespace(num_reqs=2, req_ids=["a", "b"], vocab_size=100)
167+
requests = {"a": _Request(70), "b": _Request(77)}
168+
169+
with patch.object(
170+
torch.ops._C,
171+
"eagle_prepare_next_token_ids_padded",
172+
return_value=(expected_next, expected_counts),
173+
create=True,
174+
) as mocked:
175+
result = eagle_module.prepare_next_token_ids_padded(
176+
proposer,
177+
common_attn_metadata,
178+
sampled,
179+
requests,
180+
gpu_input_batch,
181+
discard,
182+
1,
183+
)
184+
185+
assert proposer.backup_next_token_ids.np[:2].tolist() == [70, 77]
186+
mocked.assert_called_once()
187+
torch.testing.assert_close(result[0], expected_next)
188+
torch.testing.assert_close(result[1], expected_counts)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
3+
*
4+
* This file is a part of the vllm-kunlun project.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
#include <torch/extension.h>
20+
21+
namespace {
22+
23+
std::tuple<torch::Tensor, torch::Tensor> eagle_prepare_next_token_ids_padded(
24+
const torch::Tensor& sampled_token_ids,
25+
const torch::Tensor& discard_request_indices,
26+
int64_t num_discarded_requests,
27+
const torch::Tensor& backup_next_token_ids,
28+
int64_t vocab_size) {
29+
TORCH_CHECK(sampled_token_ids.dim() == 2, "sampled_token_ids must be 2D");
30+
TORCH_CHECK(
31+
backup_next_token_ids.dim() == 1,
32+
"backup_next_token_ids must be 1D");
33+
TORCH_CHECK(
34+
backup_next_token_ids.size(0) >= sampled_token_ids.size(0),
35+
"backup_next_token_ids must have at least batch_size elements");
36+
37+
auto valid_sampled_token_ids_gpu = sampled_token_ids.clone();
38+
39+
if (num_discarded_requests > 0) {
40+
auto discard_indices =
41+
discard_request_indices.slice(0, 0, num_discarded_requests);
42+
discard_indices = discard_indices.to(
43+
valid_sampled_token_ids_gpu.device(), torch::kLong);
44+
if (discard_indices.numel() > 0) {
45+
valid_sampled_token_ids_gpu.index_fill_(0, discard_indices, -1);
46+
}
47+
}
48+
49+
torch::Tensor valid_mask;
50+
if (sampled_token_ids.size(1) == 1) {
51+
valid_mask = torch::ones_like(
52+
valid_sampled_token_ids_gpu,
53+
valid_sampled_token_ids_gpu.options().dtype(torch::kBool));
54+
} else {
55+
valid_mask =
56+
valid_sampled_token_ids_gpu.ne(-1) &
57+
valid_sampled_token_ids_gpu.lt(vocab_size);
58+
}
59+
60+
auto valid_sampled_tokens_count = valid_mask.sum(1);
61+
auto last_valid_indices = valid_sampled_tokens_count - 1;
62+
auto last_valid_indices_safe = torch::clamp_min(last_valid_indices, 0);
63+
auto selected_tokens =
64+
valid_sampled_token_ids_gpu.gather(1, last_valid_indices_safe.unsqueeze(1))
65+
.squeeze(1);
66+
auto next_token_ids = torch::where(
67+
last_valid_indices.ne(-1),
68+
selected_tokens,
69+
backup_next_token_ids.slice(0, 0, sampled_token_ids.size(0)));
70+
71+
return std::make_tuple(next_token_ids, valid_sampled_tokens_count);
72+
}
73+
74+
} // namespace
75+
76+
TORCH_LIBRARY_FRAGMENT(_C, m) {
77+
m.def(
78+
"eagle_prepare_next_token_ids_padded",
79+
&eagle_prepare_next_token_ids_padded);
80+
}

vllm_kunlun/v1/sample/spec_decode/eagle.py

Lines changed: 6 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -284,58 +284,14 @@ def prepare_next_token_ids_padded(
284284
)
285285
self.backup_next_token_ids.copy_to_gpu(num_reqs)
286286

287-
# Mask out the sampled tokens indices that should not be sampled.
288-
discard_sampled_tokens_req_indices = discard_request_indices[
289-
:num_discarded_requests
290-
]
291-
292-
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
293-
# valid_sampled_token_ids_gpu.index_fill_(
294-
# 0, discard_sampled_tokens_req_indices, -1)
295-
# ---- FIX START ----
296-
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
297-
if num_discarded_requests > 0:
298-
# make sure index is on same device and is int64
299-
idx = discard_sampled_tokens_req_indices
300-
if idx.device != valid_sampled_token_ids_gpu.device:
301-
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
302-
if idx.dtype != torch.long:
303-
idx = idx.to(torch.long)
304-
if idx.numel() > 0:
305-
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
306-
# ---- FIX END ----
307-
# Generate a mask for all valid tokens within those requests
308-
max_gen_len = sampled_token_ids.shape[-1]
309-
if max_gen_len == 1:
310-
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
311-
else:
312-
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
313-
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
314-
)
315-
316-
# Count the number of valid tokens in each request
317-
valid_sampled_tokens_count = valid_mask.sum(dim=1)
318-
319-
# Get the rightmost valid index per row
320-
last_valid_indices = valid_sampled_tokens_count - 1
321-
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
322-
323-
# Get last valid token from each row
324-
# (assume undefined state where there is no valid token)
325-
selected_tokens = torch.gather(
326-
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
327-
).squeeze(1)
328-
329-
# Use last token if valid, pre-computed backup if not
330-
batch_size = valid_sampled_token_ids_gpu.shape[0]
331-
next_token_ids = torch.where(
332-
last_valid_indices != -1,
333-
selected_tokens,
334-
self.backup_next_token_ids.gpu[:batch_size],
287+
return torch.ops._C.eagle_prepare_next_token_ids_padded(
288+
sampled_token_ids,
289+
discard_request_indices,
290+
num_discarded_requests,
291+
self.backup_next_token_ids.gpu[:num_reqs],
292+
gpu_input_batch.vocab_size,
335293
)
336294

337-
return next_token_ids, valid_sampled_tokens_count
338-
339295

340296
EagleProposer.propose = propose
341297
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded

0 commit comments

Comments
 (0)