2323import torch
2424import xspeedgate_ops # noqa
2525from vllm .logger import init_logger
26+ from vllm .v1 .worker .workspace import current_workspace_manager
2627
2728logger = init_logger (__name__ )
2829
@@ -494,24 +495,60 @@ def fused_moe(
494495 )
495496 return out .sum (1 )
496497
498+ # Allocate two shared workspaces for the large temporary buffers
499+ # used by the preprocess, W1, activation, and W2 stages.
500+ y_numel = M * moe_top_k * w1 .shape [1 ]
501+ out_numel = M * moe_top_k * w2 .shape [1 ]
502+ out1_numel = M * moe_top_k * (w1 .shape [1 ] // 2 )
503+ moe_expand_numel = M * moe_top_k * N
504+
505+ # Live ranges:
506+ # M * moe_top_k <= 768, M >= 1024 (unreachable when moe_top_k >= 1):
507+ # workspace_a: out
508+ # workspace_b: y
509+ # M * moe_top_k <= 768, M < 1024:
510+ # workspace_a: out1
511+ # workspace_b: y -> out
512+ # M * moe_top_k > 768, M >= 1024:
513+ # workspace_a: moe_expand -> out
514+ # workspace_b: y
515+ # M * moe_top_k > 768, M < 1024:
516+ # workspace_a: moe_expand -> out1
517+ # workspace_b: y -> out
518+ workspace_a_numel = out_numel
519+ workspace_b_numel = y_numel
520+
521+ if M < 1024 :
522+ workspace_a_numel = out1_numel
523+ workspace_b_numel = max (y_numel , out_numel )
524+
525+ if M * moe_top_k > 768 :
526+ workspace_a_numel = max (workspace_a_numel , moe_expand_numel )
527+
528+ workspace_a , workspace_b = current_workspace_manager ().get_simultaneous (
529+ ((workspace_a_numel ,), hidden_states .dtype ),
530+ ((workspace_b_numel ,), hidden_states .dtype ),
531+ )
532+
497533 if M * moe_top_k > 768 :
498- moe_expand = torch .empty (
499- (M * moe_top_k , N ),
500- dtype = hidden_states .dtype ,
501- device = hidden_states .device ,
502- ) # [M*top_k, N], float
503534 expert_m = torch .zeros (
504- global_num_experts , dtype = torch .int32 , device = hidden_states .device
535+ global_num_experts ,
536+ dtype = torch .int32 ,
537+ device = hidden_states .device ,
505538 ) # [E]
506539 sorted_tokens_num_lod = torch .zeros (
507540 global_num_experts + 1 ,
508541 dtype = torch .int32 ,
509542 device = hidden_states .device ,
510543 ) # [E+1]
511544 sorted_tokens_idx = torch .zeros (
512- M * moe_top_k , dtype = torch .int32 , device = hidden_states .device
545+ M * moe_top_k ,
546+ dtype = torch .int32 ,
547+ device = hidden_states .device ,
513548 )
514549
550+ moe_expand = workspace_a [:moe_expand_numel ].view (M * moe_top_k , N )
551+
515552 torch .ops ._C .gen_block_statistic (topk_ids , block_statistic )
516553
517554 torch .ops ._C .moe_pre_sorted (
@@ -534,15 +571,8 @@ def fused_moe(
534571 )
535572 )
536573
537- y = torch .empty (
538- M ,
539- moe_top_k ,
540- w1 .shape [1 ],
541- dtype = hidden_states .dtype ,
542- device = hidden_states .device ,
543- )
544-
545- moe_expand = moe_expand .view (M * moe_top_k , hidden_dim )
574+ moe_expand = moe_expand .reshape (M * moe_top_k , hidden_dim )
575+ y = workspace_b [:y_numel ].view (M , moe_top_k , w1 .shape [1 ])
546576
547577 if M < 1024 :
548578 torch .ops ._C .moe_fc (
@@ -553,13 +583,14 @@ def fused_moe(
553583 moe_topk = moe_top_k ,
554584 y = y ,
555585 )
556-
557- d = y .shape [- 1 ] // 2
558- output_shape = y .shape [:- 1 ] + (d ,)
559- out1 = torch .empty (output_shape , dtype = y .dtype , device = y .device )
586+ # Reuse `workspace_a` for `out1` after `moe_expand` is no longer
587+ # needed.
588+ out1 = workspace_a [:out1_numel ].view (M , moe_top_k , w1 .shape [1 ] // 2 )
560589 torch .ops ._C .silu_and_mul (out1 , y )
561-
562590 out1 = out1 .reshape (- 1 , out1 .shape [- 1 ])
591+ # Reuse `workspace_b` for `out` after `y` has been consumed by
592+ # the activation.
593+ out = workspace_b [:out_numel ].view (M , moe_top_k , w2 .shape [1 ])
563594 else :
564595 torch .ops ._C .moe_fc (
565596 x = moe_expand ,
@@ -573,13 +604,12 @@ def fused_moe(
573604
574605 y = y [..., : y .shape [- 1 ] // 2 ]
575606 out1 = y .reshape (- 1 , y .shape [- 1 ])
607+ # Reuse `workspace_a` for `out` after `moe_expand` is no longer
608+ # needed.
609+ out = workspace_a [:out_numel ].view (M , moe_top_k , w2 .shape [1 ])
576610
577- out = torch .empty (
578- M ,
579- moe_top_k ,
580- w2 .shape [1 ],
581- dtype = hidden_states .dtype ,
582- device = hidden_states .device ,
611+ dequant_scale = torch .ones (
612+ (M , moe_top_k ), dtype = torch .float32 , device = hidden_states .device
583613 )
584614
585615 torch .ops ._C .moe_fc (
@@ -591,9 +621,6 @@ def fused_moe(
591621 y = out ,
592622 )
593623
594- dequant_scale = torch .ones (
595- [M , moe_top_k ], dtype = torch .float32 , device = out .device
596- )
597624 output = torch .empty (
598625 [M , N ], dtype = hidden_states .dtype , device = hidden_states .device
599626 )
@@ -629,7 +656,6 @@ def fused_moe_ep(
629656 batch , hidden_size = x .shape
630657 num_local_experts , up_gate_size , _ = w13_weight .shape
631658
632-
633659 topk_weights = torch .empty (
634660 batch , top_k , dtype = router_logits .dtype , device = router_logits .device
635661 )
0 commit comments