Skip to content

Commit 7e98095

Browse files
Marshall-Ge and Copilot authored
[Kernel]Reuse vLLM workspace manager for Kunlun MoE scratch tensors (#283)
Signed-off-by: Marshall-Ge <1004083966@qq.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 583a830 commit 7e98095

1 file changed

Lines changed: 57 additions & 31 deletions

File tree

vllm_kunlun/ops/_kunlun_ops.py

Lines changed: 57 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import torch
2424
import xspeedgate_ops # noqa
2525
from vllm.logger import init_logger
26+
from vllm.v1.worker.workspace import current_workspace_manager
2627

2728
logger = init_logger(__name__)
2829

@@ -494,24 +495,60 @@ def fused_moe(
494495
)
495496
return out.sum(1)
496497

498+
# Allocate two shared workspaces for the large temporary buffers
499+
# used by the preprocess, W1, activation, and W2 stages.
500+
y_numel = M * moe_top_k * w1.shape[1]
501+
out_numel = M * moe_top_k * w2.shape[1]
502+
out1_numel = M * moe_top_k * (w1.shape[1] // 2)
503+
moe_expand_numel = M * moe_top_k * N
504+
505+
# Live ranges:
506+
# M * moe_top_k <= 768, M >= 1024:
507+
# workspace_a: out
508+
# workspace_b: y
509+
# M * moe_top_k <= 768, M < 1024:
510+
# workspace_a: out1
511+
# workspace_b: y -> out
512+
# M * moe_top_k > 768, M >= 1024:
513+
# workspace_a: moe_expand -> out
514+
# workspace_b: y
515+
# M * moe_top_k > 768, M < 1024:
516+
# workspace_a: moe_expand -> out1
517+
# workspace_b: y -> out
518+
workspace_a_numel = out_numel
519+
workspace_b_numel = y_numel
520+
521+
if M < 1024:
522+
workspace_a_numel = out1_numel
523+
workspace_b_numel = max(y_numel, out_numel)
524+
525+
if M * moe_top_k > 768:
526+
workspace_a_numel = max(workspace_a_numel, moe_expand_numel)
527+
528+
workspace_a, workspace_b = current_workspace_manager().get_simultaneous(
529+
((workspace_a_numel,), hidden_states.dtype),
530+
((workspace_b_numel,), hidden_states.dtype),
531+
)
532+
497533
if M * moe_top_k > 768:
498-
moe_expand = torch.empty(
499-
(M * moe_top_k, N),
500-
dtype=hidden_states.dtype,
501-
device=hidden_states.device,
502-
) # [M*top_k, N], float
503534
expert_m = torch.zeros(
504-
global_num_experts, dtype=torch.int32, device=hidden_states.device
535+
global_num_experts,
536+
dtype=torch.int32,
537+
device=hidden_states.device,
505538
) # [E]
506539
sorted_tokens_num_lod = torch.zeros(
507540
global_num_experts + 1,
508541
dtype=torch.int32,
509542
device=hidden_states.device,
510543
) # [E+1]
511544
sorted_tokens_idx = torch.zeros(
512-
M * moe_top_k, dtype=torch.int32, device=hidden_states.device
545+
M * moe_top_k,
546+
dtype=torch.int32,
547+
device=hidden_states.device,
513548
)
514549

550+
moe_expand = workspace_a[:moe_expand_numel].view(M * moe_top_k, N)
551+
515552
torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
516553

517554
torch.ops._C.moe_pre_sorted(
@@ -534,15 +571,8 @@ def fused_moe(
534571
)
535572
)
536573

537-
y = torch.empty(
538-
M,
539-
moe_top_k,
540-
w1.shape[1],
541-
dtype=hidden_states.dtype,
542-
device=hidden_states.device,
543-
)
544-
545-
moe_expand = moe_expand.view(M * moe_top_k, hidden_dim)
574+
moe_expand = moe_expand.reshape(M * moe_top_k, hidden_dim)
575+
y = workspace_b[:y_numel].view(M, moe_top_k, w1.shape[1])
546576

547577
if M < 1024:
548578
torch.ops._C.moe_fc(
@@ -553,13 +583,14 @@ def fused_moe(
553583
moe_topk=moe_top_k,
554584
y=y,
555585
)
556-
557-
d = y.shape[-1] // 2
558-
output_shape = y.shape[:-1] + (d,)
559-
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
586+
# Reuse `workspace_a` for `out1` after `moe_expand` is no longer
587+
# needed.
588+
out1 = workspace_a[:out1_numel].view(M, moe_top_k, w1.shape[1] // 2)
560589
torch.ops._C.silu_and_mul(out1, y)
561-
562590
out1 = out1.reshape(-1, out1.shape[-1])
591+
# Reuse `workspace_b` for `out` after `y` has been consumed by
592+
# the activation.
593+
out = workspace_b[:out_numel].view(M, moe_top_k, w2.shape[1])
563594
else:
564595
torch.ops._C.moe_fc(
565596
x=moe_expand,
@@ -573,13 +604,12 @@ def fused_moe(
573604

574605
y = y[..., : y.shape[-1] // 2]
575606
out1 = y.reshape(-1, y.shape[-1])
607+
# Reuse `workspace_a` for `out` after `moe_expand` is no longer
608+
# needed.
609+
out = workspace_a[:out_numel].view(M, moe_top_k, w2.shape[1])
576610

577-
out = torch.empty(
578-
M,
579-
moe_top_k,
580-
w2.shape[1],
581-
dtype=hidden_states.dtype,
582-
device=hidden_states.device,
611+
dequant_scale = torch.ones(
612+
(M, moe_top_k), dtype=torch.float32, device=hidden_states.device
583613
)
584614

585615
torch.ops._C.moe_fc(
@@ -591,9 +621,6 @@ def fused_moe(
591621
y=out,
592622
)
593623

594-
dequant_scale = torch.ones(
595-
[M, moe_top_k], dtype=torch.float32, device=out.device
596-
)
597624
output = torch.empty(
598625
[M, N], dtype=hidden_states.dtype, device=hidden_states.device
599626
)
@@ -629,7 +656,6 @@ def fused_moe_ep(
629656
batch, hidden_size = x.shape
630657
num_local_experts, up_gate_size, _ = w13_weight.shape
631658

632-
633659
topk_weights = torch.empty(
634660
batch, top_k, dtype=router_logits.dtype, device=router_logits.device
635661
)

0 commit comments

Comments
 (0)