Skip to content

Commit 4cc631d

Browse files
authored
[Feature] improve random_sample (#359)
Signed-off-by: yuejun <yuejun@baidu.com>
1 parent 09869ee commit 4cc631d

3 files changed

Lines changed: 3 additions & 12 deletions

File tree

vllm_kunlun/ops/fla/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,12 +121,12 @@ def get_available_device() -> str:
121121

122122

123123
@functools.cache
124-
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
124+
def _check_platform() -> Literal["nvidia", "amd", "kunlun"]:
125125
device = get_available_device()
126126
mapping = {
127127
"cuda": "nvidia",
128128
"hip": "amd",
129-
"xpu": "intel",
129+
"xpu": "kunlun",
130130
}
131131
# return the mapped value, or the original if not found
132132
return mapping.get(device, device)

vllm_kunlun/v1/attention/backends/kunlun_attn.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616
#
1717
import copy
1818
from dataclasses import dataclass
19-
from itertools import accumulate
2019
from typing import (
2120
TYPE_CHECKING,
2221
Any,
@@ -603,13 +602,6 @@ def build(
603602
seq_lens = common_attn_metadata.seq_lens
604603
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
605604

606-
seq_start_loc = list(accumulate(seq_lens, initial=0))
607-
608-
seq_start_loc_tensor = torch.empty(
609-
len(seq_start_loc), dtype=torch.int32, device=self.device
610-
)
611-
seq_start_loc_tensor.copy_(torch.as_tensor(seq_start_loc, dtype=torch.int32))
612-
613605
kv_lod_cpu = torch.zeros(num_reqs + 1, dtype=torch.int32, device="cpu")
614606
kv_lod_cpu[1:] = seq_lens_cpu.to(torch.int32).cumsum(dim=0)
615607
kv_lod_xpu = kv_lod_cpu.to(self.device)

vllm_kunlun/v1/sample/ops/topk_topp_sampler.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -169,8 +169,7 @@ def random_sample(
169169
q = q.clamp(min=1e-12)
170170
else:
171171
for i, generator in generators.items():
172-
q[i].exponential_(generator=generator)
173-
172+
torch.ops.xspeedgate_ops.inplace_exponential(q[i], generator=generator)
174173
return probs.div_(q).argmax(dim=-1).view(-1)
175174

176175

0 commit comments

Comments
 (0)