Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 100 additions & 51 deletions vllm_kunlun/quantization/compressed_tensors/compressed_tensors_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
)
from vllm.v1.worker.workspace import current_workspace_manager

from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
from vllm_kunlun.quantization.kernels.quant_ops import dequant_int4_native
Expand Down Expand Up @@ -169,9 +170,6 @@ def apply_monolithic(
scale=routed_scaling_factor,
)

moe_expand = torch.empty(
(M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device
) # [M, top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
Expand All @@ -180,8 +178,47 @@ def apply_monolithic(
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * top_k, dtype=torch.int32, device=hidden_states.device
) # [M * top_k]

moe_expand_numel = M * top_k * N
y_numel = M * top_k * layer.w13_weight.shape[1]
out1_numel = M * top_k * (layer.w13_weight.shape[1] // 2)
out_numel = M * top_k * layer.w2_weight.shape[1]
x_q_numel = max(moe_expand_numel, out1_numel)
x_scale_shape = (M * top_k, 1)

# Reuse the workspace according to live ranges:
# M < 1024:
# workspace_a: moe_expand -> out1
# workspace_b: y -> out
# M >= 1024:
# workspace_a: moe_expand -> out
# workspace_b: y
workspace_a_numel = max(moe_expand_numel, out_numel)
workspace_b_numel = y_numel

if M < 1024:
workspace_a_numel = max(moe_expand_numel, out1_numel)
workspace_b_numel = max(y_numel, out_numel)

workspace_c_numel = x_q_numel

(
workspace_a,
workspace_b,
workspace_c,
workspace_d,
) = current_workspace_manager().get_simultaneous(
((workspace_a_numel,), hidden_states.dtype),
((workspace_b_numel,), hidden_states.dtype),
((workspace_c_numel,), torch.int8),
(x_scale_shape, torch.float32),
)

moe_expand = workspace_a[:moe_expand_numel].view(
M * top_k, N
) # [M * top_k, N], float

torch.ops._C.gen_block_statistic(topk_ids, block_statistic)

torch.ops._C.moe_pre_sorted(
Expand All @@ -194,59 +231,71 @@ def apply_monolithic(
sorted_tokens_num_lod=sorted_tokens_num_lod,
)

y = torch.empty(
M,
top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)

moe_expand = moe_expand.view(M * top_k, hidden_dim)

x_shape = moe_expand.shape
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_scale = torch.empty(
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
)
x_q = workspace_c[: moe_expand.numel()].view(moe_expand.shape)
x_scale = workspace_d
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)

torch.ops._C.moe_fc(
x=x_q,
x_perchannel_max=x_scale,
weight=layer.w13_weight,
w_perchannel_max=layer.w13_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)

d = y.shape[-1] // 2
output_shape = y.shape[:-1] + (d,)
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
torch.ops._C.silu_and_mul(out1, y)

del y
y = workspace_b[:y_numel].view(M, top_k, layer.w13_weight.shape[1])

if M < 1024:
torch.ops._C.moe_fc(
x=x_q,
x_perchannel_max=x_scale,
weight=layer.w13_weight,
w_perchannel_max=layer.w13_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)

out1 = out1.reshape(-1, out1.shape[-1])
x_shape = out1.shape
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_scale = torch.empty(
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
)
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
del out1, moe_expand
out = torch.empty(
M,
top_k,
layer.w2_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
d = y.shape[-1] // 2
# Reuse `workspace_a` for `out1` after `moe_expand` is no longer
# needed.
out1 = workspace_a[:out1_numel].view(M, top_k, d)
torch.ops._C.silu_and_mul(out1, y)
del y

out1 = out1.reshape(-1, out1.shape[-1])
x_q = workspace_c[: out1.numel()].view(out1.shape)
x_scale = workspace_d
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
del out1, moe_expand

# Reuse `workspace_b` for `out` after `y` has been consumed by
# the first FC.
out = workspace_b[:out_numel].view(M, top_k, layer.w2_weight.shape[1])
else:
torch.ops._C.moe_fc(
x=x_q,
x_perchannel_max=x_scale,
weight=layer.w13_weight,
w_perchannel_max=layer.w13_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act="SWISH_GLU",
)
del x_q, x_scale, moe_expand

y = y[..., : y.shape[-1] // 2]
out1 = y.reshape(-1, y.shape[-1])
x_q = workspace_c[: out1.numel()].view(out1.shape)
x_scale = workspace_d
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
del out1, y

# Reuse `workspace_a` for `out` after `moe_expand` is no longer
# needed.
out = workspace_a[:out_numel].view(M, top_k, layer.w2_weight.shape[1])

torch.ops._C.moe_fc(
x=x_q,
Expand Down
Loading