Skip to content

Commit cf99cb2

Browse files
committed
change api
1 parent 7ecb1dc commit cf99cb2

2 files changed

Lines changed: 19 additions & 25 deletions

File tree

paddlenlp/transformers/fp8_utils.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def swiglu(x, y=None):
2727

2828
try:
2929
import deep_gemm
30-
import FusedQuantOps as FQO
3130
import kitchen
3231
import kitchen.quantization_subchannel_block_hybrid
3332
from kitchen.quantization import QParams, ScalingType
@@ -343,7 +342,7 @@ def backward(ctx, do3):
343342
o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype)
344343
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_sacle), o1)
345344

346-
x_dequant_fp16 = FQO.fused_act_dequant(x_fp8, x_scale.T.contiguous())
345+
x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous())
347346
x_dequant_fp16 = padding(x_dequant_fp16, 0)
348347

349348
_, _, x_t_fp8, x_t_scale = kitchen_quant(
@@ -468,7 +467,7 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
468467
self.tokens_per_expert = tokens_per_expert
469468
self.m_indices = gen_m_indices(tokens_per_expert)
470469
# concat w1, shape is [num_groups, n, k]
471-
w1_t_quant, w1_t_scale = FQO.fused_stack_transpose_quant(expert_w1)
470+
w1_t_quant, w1_t_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w1, transpose=True)
472471
w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]])
473472
w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]])
474473

@@ -504,12 +503,14 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert):
504503
[m_sum, k] = [m_sum, n] * [num_groups, n, k]
505504
"""
506505
# concat and transpose w2
507-
w2_quant, w2_sacle = FQO.fused_stack_transpose_quant(expert_w2)
506+
w2_quant, w2_sacle = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_w2, transpose=True)
508507
w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]])
509508
w2_sacle = w2_sacle.reshape([num_expert, -1, w2_sacle.shape[-1]])
510509

511510
# quant o2
512-
o2_fp8, o2_scale = FQO.fused_spaq(o1, unzipped_probs, using_pow2_scaling=True)
511+
o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant(
512+
o1, unzipped_probs, using_pow2_scaling=True
513+
)
513514
o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0])
514515
self.unzipped_probs = unzipped_probs
515516

@@ -527,7 +528,9 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1):
527528
[m_sum, n] = [m_sum, k] * [num_groups, k, n]
528529
"""
529530
# recompute concated_w2_2d
530-
bw_w2_quant, bw_w2_scale = FQO.fused_stack_quant(expert_w2)
531+
bw_w2_quant, bw_w2_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
532+
expert_w2, transpose=False
533+
)
531534
bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]])
532535
bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]])
533536

@@ -541,7 +544,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1):
541544
(unzipped_grad_fp8, unzipped_grad_scale), (bw_w2_quant, bw_w2_scale), do2_s, m_indices=self.m_indices
542545
)
543546

544-
do1, probs_grad, o2_s = FQO.fused_swiglu_probs_bwd(o1, do2_s, self.unzipped_probs)
547+
do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1, do2_s, self.unzipped_probs)
545548

546549
return do1, o2_s, probs_grad
547550

@@ -555,7 +558,9 @@ def bwd_gate_up_input(self, do1, expert_w1):
555558
[m_sum, k] = [m_sum, n] * [num_groups, n, k]
556559
"""
557560
# recompute concated_w1_t
558-
bw_w1_quant, bw_w1_scale = FQO.fused_stack_quant(expert_w1)
561+
bw_w1_quant, bw_w1_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(
562+
expert_w1, transpose=False
563+
)
559564
bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]])
560565
bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]])
561566

@@ -573,11 +578,7 @@ def bwd_gate_up_input(self, do1, expert_w1):
573578
return dx
574579

575580
def fused_transpose_split_quant(self, x, tokens_per_expert, pow_2_scales):
576-
out, scale = [], []
577-
for tokens in tokens_per_expert:
578-
out.append(paddle.empty([x.shape[1], tokens], dtype="float8_e4m3fn"))
579-
scale.append(paddle.empty([tokens // 128, x.shape[1]], dtype="float32"))
580-
FQO.fused_transpose_split_quant(x, out, scale, pow_2_scales)
581+
out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant(x, tokens_per_expert, pow_2_scales)
581582
return out, scale
582583

583584
def bwd_down_weight(self, do3, o2, expert_w2):
@@ -681,7 +682,7 @@ def backward(self, out_grad):
681682
expert_w2 = [x.w2 for x in self.custom_map.experts if x is not None]
682683

683684
if self.mem_efficient:
684-
input = FQO.fused_act_dequant(self.input_fp8, self.input_scale)
685+
input = paddle.incubate.nn.functional.fused_act_dequant(self.input_fp8, self.input_scale)
685686
else:
686687
input = self.input
687688

paddlenlp/transformers/moe_utils.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@
1818

1919
from .fp8_utils import dequantize_fp8_to_fp32
2020

21-
try:
22-
import TokenDispatcherUtils as TDU
23-
except:
24-
pass
25-
2621

2722
def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk):
2823
x = paddle.flatten(x)
@@ -120,12 +115,11 @@ def forward(
120115
num_experts,
121116
tokens_per_expert,
122117
):
123-
(unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, _,) = TDU.tokens_unzip_stable(
118+
(unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, _,) = paddle.nn.functional.moe_permute(
124119
hs_2d_dispatched,
125120
None,
126121
dispatched_indices,
127122
dispatched_probs,
128-
topk=topk,
129123
num_experts=num_experts,
130124
tokens_per_expert=tokens_per_expert,
131125
padding_multiplex=128,
@@ -140,7 +134,7 @@ def forward(
140134

141135
@paddle.no_grad()
142136
def backward(self, dx, hidden_states_out_grad, probs_grad, dispatched_indices, num_experts):
143-
weighted_zipped_tokens, probs_grad_zipped = TDU.tokens_zip(
137+
weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute(
144138
dx,
145139
self.zipped_expertwise_rowmap,
146140
dispatched_indices,
@@ -161,7 +155,7 @@ def __init__(self, token_dispatcher, name="zip"):
161155
def forward(
162156
self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
163157
):
164-
expert_out_zipped, zipped_probs_topk = TDU.tokens_zip(
158+
expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute(
165159
expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts
166160
)
167161
return expert_out_zipped
@@ -176,12 +170,11 @@ def backward(
176170
num_experts,
177171
tokens_per_expert,
178172
):
179-
(unzipped_grad, zipped_expertwise_rowmap_grad, unzipped_probs_grad, _,) = TDU.tokens_unzip_stable(
173+
(unzipped_grad, zipped_expertwise_rowmap_grad, unzipped_probs_grad, _,) = paddle.nn.functional.moe_permute(
180174
grad_output,
181175
None,
182176
dispatched_indices,
183177
dispatched_probs,
184-
top_k,
185178
num_experts,
186179
tokens_per_expert,
187180
padding_multiplex=128,

0 commit comments

Comments
 (0)