From c8ff361224d590746025c226b14c9744b960c7bf Mon Sep 17 00:00:00 2001 From: Marshall-Ge <1004083966@qq.com> Date: Wed, 8 Apr 2026 21:54:45 +0800 Subject: [PATCH 1/5] optim memory using workspace manager in quantization Signed-off-by: Marshall-Ge <1004083966@qq.com> --- .../compressed_tensors_moe.py | 59 +++++++++++++------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py index 2c944380..a7019b94 100644 --- a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py @@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, ) +from vllm.v1.worker.workspace import current_workspace_manager from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm_kunlun.quantization.kernels.quant_ops import dequant_int4_native @@ -169,9 +170,6 @@ def apply_monolithic( scale=routed_scaling_factor, ) - moe_expand = torch.empty( - (M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device - ) # [M, top_k, N], float expert_m = torch.zeros( global_num_experts, dtype=torch.int32, device=hidden_states.device ) # [E] @@ -180,7 +178,34 @@ def apply_monolithic( ) # [E+1] sorted_tokens_idx = torch.zeros( M * top_k, dtype=torch.int32, device=hidden_states.device + ) # [M * top_k] + + y_numel = M * top_k * layer.w13_weight.shape[1] + out_numel = M * top_k * layer.w2_weight.shape[1] + out1_numel = M * top_k * hidden_dim + moe_expand_numel = M * top_k * N + + # Live ranges: + # M < 1024: + # workspace_a: moe_expand -> out1 + # workspace_b: y -> out + # M >= 1024: + # workspace_a: moe_expand -> out + # workspace_b: y + workspace_a_numel = max(moe_expand_numel, out_numel) + workspace_b_numel = y_numel + + if M < 1024: + workspace_a_numel = max(moe_expand_numel, out1_numel) + workspace_b_numel = max(y_numel, out_numel) + + workspace_a, workspace_b = current_workspace_manager().get_simultaneous( + ((workspace_a_numel,), hidden_states.dtype), + ((workspace_b_numel,), hidden_states.dtype), ) + moe_expand = workspace_a[:moe_expand_numel].view( + M * top_k, N + ) # [M * top_k, N], float torch.ops._C.gen_block_statistic(topk_ids, block_statistic) @@ -194,13 +219,7 @@ def apply_monolithic( sorted_tokens_num_lod=sorted_tokens_num_lod, ) - y = torch.empty( - M, - top_k, - layer.w13_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + y = workspace_b[:y_numel].view(M, top_k, layer.w13_weight.shape[1]) moe_expand = moe_expand.view(M * top_k, hidden_dim) @@ -226,8 +245,9 @@ def apply_monolithic( ) d = y.shape[-1] // 2 - output_shape = y.shape[:-1] + (d,) - out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) + # Reuse `workspace_a` for `out1` after `moe_expand` is no longer + # needed. + out1 = workspace_a[:out1_numel].view(M, top_k, d) torch.ops._C.silu_and_mul(out1, y) del y @@ -240,13 +260,14 @@ def apply_monolithic( ) torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, moe_expand - out = torch.empty( - M, - top_k, - layer.w2_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + if M < 1024: + # Reuse `workspace_b` for `out` after `y` has been consumed by + # the first FC. + out = workspace_b[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) + else: + # Reuse `workspace_a` for `out` after `moe_expand` is no longer + # needed. + out = workspace_a[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) torch.ops._C.moe_fc( x=x_q, From aaf3cbf5ff0624b53677e61e0d6a345edbc1c4e0 Mon Sep 17 00:00:00 2001 From: Marshall-Ge <1004083966@qq.com> Date: Fri, 10 Apr 2026 12:01:11 +0800 Subject: [PATCH 2/5] temp save Signed-off-by: Marshall-Ge <1004083966@qq.com> --- .../compressed_tensors_moe.py | 65 ++++++++----------- 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py index a7019b94..6f9b06eb 100644 --- a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py @@ -170,42 +170,31 @@ def apply_monolithic( scale=routed_scaling_factor, ) - expert_m = torch.zeros( - global_num_experts, dtype=torch.int32, device=hidden_states.device - ) # [E] - sorted_tokens_num_lod = torch.zeros( - global_num_experts + 1, dtype=torch.int32, device=hidden_states.device - ) # [E+1] - sorted_tokens_idx = torch.zeros( - M * top_k, dtype=torch.int32, device=hidden_states.device - ) # [M * top_k] - + moe_expand_numel = M * top_k * N y_numel = M * top_k * layer.w13_weight.shape[1] + out1_numel = M * top_k * (layer.w13_weight.shape[1] // 2) out_numel = M * top_k * layer.w2_weight.shape[1] - out1_numel = M * top_k * hidden_dim - moe_expand_numel = M * top_k * N - - # Live ranges: - # M < 1024: - # workspace_a: moe_expand -> out1 - # workspace_b: y -> out - # M >= 1024: - # workspace_a: moe_expand -> out - # workspace_b: y - workspace_a_numel = max(moe_expand_numel, out_numel) - workspace_b_numel = y_numel - if M < 1024: - workspace_a_numel = max(moe_expand_numel, out1_numel) - workspace_b_numel = max(y_numel, out_numel) + workspace_a_numel = max(moe_expand_numel, out1_numel) + workspace_b_numel = max(y_numel, out_numel) workspace_a, workspace_b = current_workspace_manager().get_simultaneous( ((workspace_a_numel,), hidden_states.dtype), ((workspace_b_numel,), hidden_states.dtype), ) + moe_expand = workspace_a[:moe_expand_numel].view( M * top_k, N - ) # [M * top_k, N], float + ) # [M, top_k, N], float + expert_m = torch.zeros( + global_num_experts, dtype=torch.int32, device=hidden_states.device + ) # [E] + sorted_tokens_num_lod = torch.zeros( + global_num_experts + 1, dtype=torch.int32, device=hidden_states.device + ) # [E+1] + sorted_tokens_idx = torch.zeros( + M * top_k, dtype=torch.int32, device=hidden_states.device + ) torch.ops._C.gen_block_statistic(topk_ids, block_statistic) @@ -219,7 +208,11 @@ def apply_monolithic( sorted_tokens_num_lod=sorted_tokens_num_lod, ) - y = workspace_b[:y_numel].view(M, top_k, layer.w13_weight.shape[1]) + y = workspace_b[:y_numel].view( + M, + top_k, + layer.w13_weight.shape[1], + ) moe_expand = moe_expand.view(M * top_k, hidden_dim) @@ -245,9 +238,8 @@ def apply_monolithic( ) d = y.shape[-1] // 2 - # Reuse `workspace_a` for `out1` after `moe_expand` is no longer - # needed. - out1 = workspace_a[:out1_numel].view(M, top_k, d) + output_shape = y.shape[:-1] + (d,) + out1 = workspace_a[:out1_numel].view(output_shape) torch.ops._C.silu_and_mul(out1, y) del y @@ -260,14 +252,11 @@ def apply_monolithic( ) torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, moe_expand - if M < 1024: - # Reuse `workspace_b` for `out` after `y` has been consumed by - # the first FC. - out = workspace_b[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) - else: - # Reuse `workspace_a` for `out` after `moe_expand` is no longer - # needed. - out = workspace_a[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) + out = workspace_b[:out_numel].view( + M, + top_k, + layer.w2_weight.shape[1], + ) torch.ops._C.moe_fc( x=x_q, From fc247cadddc7f449817b63c2dec25869cd081c31 Mon Sep 17 00:00:00 2001 From: Marshall-Ge <1004083966@qq.com> Date: Fri, 10 Apr 2026 16:05:33 +0800 Subject: [PATCH 3/5] optim Signed-off-by: Marshall-Ge <1004083966@qq.com> --- .../compressed_tensors_moe.py | 135 +++++++++++------- 1 file changed, 86 insertions(+), 49 deletions(-) diff --git a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py index 6f9b06eb..4e4c8447 100644 --- a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py @@ -170,13 +170,34 @@ def apply_monolithic( scale=routed_scaling_factor, ) + expert_m = torch.zeros( + global_num_experts, dtype=torch.int32, device=hidden_states.device + ) # [E] + sorted_tokens_num_lod = torch.zeros( + global_num_experts + 1, dtype=torch.int32, device=hidden_states.device + ) # [E+1] + sorted_tokens_idx = torch.zeros( + M * top_k, dtype=torch.int32, device=hidden_states.device + ) # [M * top_k] + moe_expand_numel = M * top_k * N y_numel = M * top_k * layer.w13_weight.shape[1] out1_numel = M * top_k * (layer.w13_weight.shape[1] // 2) out_numel = M * top_k * layer.w2_weight.shape[1] - workspace_a_numel = max(moe_expand_numel, out1_numel) - workspace_b_numel = max(y_numel, out_numel) + # Reuse the workspace according to live ranges: + # M < 1024: + # workspace_a: moe_expand -> out1 + # workspace_b: y -> out + # M >= 1024: + # workspace_a: moe_expand -> out + # workspace_b: y + workspace_a_numel = max(moe_expand_numel, out_numel) + workspace_b_numel = y_numel + + if M < 1024: + workspace_a_numel = max(moe_expand_numel, out1_numel) + workspace_b_numel = max(y_numel, out_numel) workspace_a, workspace_b = current_workspace_manager().get_simultaneous( ((workspace_a_numel,), hidden_states.dtype), @@ -185,16 +206,7 @@ def apply_monolithic( moe_expand = workspace_a[:moe_expand_numel].view( M * top_k, N - ) # [M, top_k, N], float - expert_m = torch.zeros( - global_num_experts, dtype=torch.int32, device=hidden_states.device - ) # [E] - sorted_tokens_num_lod = torch.zeros( - global_num_experts + 1, dtype=torch.int32, device=hidden_states.device - ) # [E+1] - sorted_tokens_idx = torch.zeros( - M * top_k, dtype=torch.int32, device=hidden_states.device - ) + ) # [M * top_k, N], float torch.ops._C.gen_block_statistic(topk_ids, block_statistic) @@ -208,12 +220,6 @@ def apply_monolithic( sorted_tokens_num_lod=sorted_tokens_num_lod, ) - y = workspace_b[:y_numel].view( - M, - top_k, - layer.w13_weight.shape[1], - ) - moe_expand = moe_expand.view(M * top_k, hidden_dim) x_shape = moe_expand.shape @@ -223,40 +229,71 @@ def apply_monolithic( ) torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) - torch.ops._C.moe_fc( - x=x_q, - x_perchannel_max=x_scale, - weight=layer.w13_weight, - w_perchannel_max=layer.w13_weight_scale, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=top_k, - y=y, - topk_ids=topk_ids, - # sort_mode=False, - act=None, - ) + y = workspace_b[:y_numel].view(M, top_k, layer.w13_weight.shape[1]) + + if M < 1024: + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w13_weight, + w_perchannel_max=layer.w13_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=y, + topk_ids=topk_ids, + # sort_mode=False, + act=None, + ) - d = y.shape[-1] // 2 - output_shape = y.shape[:-1] + (d,) - out1 = workspace_a[:out1_numel].view(output_shape) - torch.ops._C.silu_and_mul(out1, y) + d = y.shape[-1] // 2 + # Reuse `workspace_a` for `out1` after `moe_expand` is no longer + # needed. + out1 = workspace_a[:out1_numel].view(M, top_k, d) + torch.ops._C.silu_and_mul(out1, y) + del y + + out1 = out1.reshape(-1, out1.shape[-1]) + x_shape = out1.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=out1.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=out1.device + ) + torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) + del out1, moe_expand - del y + # Reuse `workspace_b` for `out` after `y` has been consumed by + # the first FC. + out = workspace_b[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) + else: + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w13_weight, + w_perchannel_max=layer.w13_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=y, + topk_ids=topk_ids, + # sort_mode=False, + act="SWISH_GLU", + ) + del x_q, x_scale, moe_expand + + y = y[..., : y.shape[-1] // 2] + out1 = y.reshape(-1, y.shape[-1]) + x_shape = out1.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=out1.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=out1.device + ) + torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) + del out1, y - out1 = out1.reshape(-1, out1.shape[-1]) - x_shape = out1.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device - ) - torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) - del out1, moe_expand - out = workspace_b[:out_numel].view( - M, - top_k, - layer.w2_weight.shape[1], - ) + # Reuse `workspace_a` for `out` after `moe_expand` is no longer + # needed. + out = workspace_a[:out_numel].view(M, top_k, layer.w2_weight.shape[1]) torch.ops._C.moe_fc( x=x_q, From f6fab8a20f8ff1052bd848e64fe6370d0f5255d5 Mon Sep 17 00:00:00 2001 From: Marshall-Ge <1004083966@qq.com> Date: Fri, 10 Apr 2026 17:43:29 +0800 Subject: [PATCH 4/5] optim Signed-off-by: Marshall-Ge <1004083966@qq.com> --- .../compressed_tensors_moe.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py index 4e4c8447..a2f48fc5 100644 --- a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py @@ -184,6 +184,8 @@ def apply_monolithic( y_numel = M * top_k * layer.w13_weight.shape[1] out1_numel = M * top_k * (layer.w13_weight.shape[1] // 2) out_numel = M * top_k * layer.w2_weight.shape[1] + x_q_numel = max(moe_expand_numel, out1_numel) + x_scale_shape = (M * top_k, 1) # Reuse the workspace according to live ranges: # M < 1024: @@ -199,9 +201,15 @@ def apply_monolithic( workspace_a_numel = max(moe_expand_numel, out1_numel) workspace_b_numel = max(y_numel, out_numel) - workspace_a, workspace_b = current_workspace_manager().get_simultaneous( - ((workspace_a_numel,), hidden_states.dtype), - ((workspace_b_numel,), hidden_states.dtype), + workspace_c_numel = x_q_numel + + workspace_a, workspace_b, workspace_c, workspace_d = ( + current_workspace_manager().get_simultaneous( + ((workspace_a_numel,), hidden_states.dtype), + ((workspace_b_numel,), hidden_states.dtype), + ((workspace_c_numel,), torch.int8), + (x_scale_shape, torch.float32), + ) ) moe_expand = workspace_a[:moe_expand_numel].view( @@ -222,11 +230,8 @@ def apply_monolithic( moe_expand = moe_expand.view(M * top_k, hidden_dim) - x_shape = moe_expand.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device - ) + x_q = workspace_c[:moe_expand.numel()].view(moe_expand.shape) + x_scale = workspace_d torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) y = workspace_b[:y_numel].view(M, top_k, layer.w13_weight.shape[1]) @@ -254,11 +259,8 @@ def apply_monolithic( del y out1 = out1.reshape(-1, out1.shape[-1]) - x_shape = out1.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=out1.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=out1.device - ) + x_q = workspace_c[:out1.numel()].view(out1.shape) + x_scale = workspace_d torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, moe_expand @@ -283,11 +285,8 @@ def apply_monolithic( y = y[..., : y.shape[-1] // 2] out1 = y.reshape(-1, y.shape[-1]) - x_shape = out1.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=out1.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=out1.device - ) + x_q = workspace_c[:out1.numel()].view(out1.shape) + x_scale = workspace_d torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, y From a20048cf1a3023d22f43c8e918d9c0e6bbd05efb Mon Sep 17 00:00:00 2001 From: Marshall-Ge <1004083966@qq.com> Date: Fri, 10 Apr 2026 17:43:50 +0800 Subject: [PATCH 5/5] optim Signed-off-by: Marshall-Ge <1004083966@qq.com> --- .../compressed_tensors_moe.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py index a2f48fc5..40d1cc2b 100644 --- a/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py @@ -203,13 +203,16 @@ def apply_monolithic( workspace_c_numel = x_q_numel - workspace_a, workspace_b, workspace_c, workspace_d = ( - current_workspace_manager().get_simultaneous( - ((workspace_a_numel,), hidden_states.dtype), - ((workspace_b_numel,), hidden_states.dtype), - ((workspace_c_numel,), torch.int8), - (x_scale_shape, torch.float32), - ) + ( + workspace_a, + workspace_b, + workspace_c, + workspace_d, + ) = current_workspace_manager().get_simultaneous( + ((workspace_a_numel,), hidden_states.dtype), + ((workspace_b_numel,), hidden_states.dtype), + ((workspace_c_numel,), torch.int8), + (x_scale_shape, torch.float32), ) moe_expand = workspace_a[:moe_expand_numel].view( @@ -230,7 +233,7 @@ def apply_monolithic( moe_expand = moe_expand.view(M * top_k, hidden_dim) - x_q = workspace_c[:moe_expand.numel()].view(moe_expand.shape) + x_q = workspace_c[: moe_expand.numel()].view(moe_expand.shape) x_scale = workspace_d torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) @@ -259,7 +262,7 @@ def apply_monolithic( del y out1 = out1.reshape(-1, out1.shape[-1]) - x_q = workspace_c[:out1.numel()].view(out1.shape) + x_q = workspace_c[: out1.numel()].view(out1.shape) x_scale = workspace_d torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, moe_expand @@ -285,7 +288,7 @@ def apply_monolithic( y = y[..., : y.shape[-1] // 2] out1 = y.reshape(-1, y.shape[-1]) - x_q = workspace_c[:out1.numel()].view(out1.shape) + x_q = workspace_c[: out1.numel()].view(out1.shape) x_scale = workspace_d torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) del out1, y