@@ -27,7 +27,6 @@ def swiglu(x, y=None):
2727
2828try :
2929 import deep_gemm
30- import FusedQuantOps as FQO
3130 import kitchen
3231 import kitchen .quantization_subchannel_block_hybrid
3332 from kitchen .quantization import QParams , ScalingType
@@ -343,7 +342,7 @@ def backward(ctx, do3):
343342 o1 = paddle .empty ([x_fp8 .shape [0 ], w1_fp8 .shape [0 ]], dtype = do3 .dtype )
344343 deep_gemm .gemm_fp8_fp8_bf16_nt ((x_fp8 , x_scale .T ), (w1_fp8 , w1_sacle ), o1 )
345344
346- x_dequant_fp16 = FQO .fused_act_dequant (x_fp8 , x_scale .T .contiguous ())
345+ x_dequant_fp16 = paddle . incubate . nn . functional .fused_act_dequant (x_fp8 , x_scale .T .contiguous ())
347346 x_dequant_fp16 = padding (x_dequant_fp16 , 0 )
348347
349348 _ , _ , x_t_fp8 , x_t_scale = kitchen_quant (
@@ -468,7 +467,7 @@ def fwd_gate_up(self, x_bf16, expert_w1, num_expert, tokens_per_expert):
468467 self .tokens_per_expert = tokens_per_expert
469468 self .m_indices = gen_m_indices (tokens_per_expert )
470469 # concat w1, shape is [num_groups, n, k]
471- w1_t_quant , w1_t_scale = FQO . fused_stack_transpose_quant (expert_w1 )
470+ w1_t_quant , w1_t_scale = paddle . incubate . nn . functional . fused_stack_transpose_quant (expert_w1 , transpose = True )
472471 w1_t_quant = w1_t_quant .reshape ([num_expert , - 1 , w1_t_quant .shape [- 1 ]])
473472 w1_t_scale = w1_t_scale .reshape ([num_expert , - 1 , w1_t_scale .shape [- 1 ]])
474473
@@ -504,12 +503,14 @@ def fwd_down(self, o1, unzipped_probs, expert_w2, num_expert):
504503 [m_sum, k] = [m_sum, n] * [num_groups, n, k]
505504 """
506505 # concat and transpose w2
507- w2_quant , w2_sacle = FQO . fused_stack_transpose_quant (expert_w2 )
506+ w2_quant , w2_sacle = paddle . incubate . nn . functional . fused_stack_transpose_quant (expert_w2 , transpose = True )
508507 w2_quant = w2_quant .reshape ([num_expert , - 1 , w2_quant .shape [- 1 ]])
509508 w2_sacle = w2_sacle .reshape ([num_expert , - 1 , w2_sacle .shape [- 1 ]])
510509
511510 # quant o2
512- o2_fp8 , o2_scale = FQO .fused_spaq (o1 , unzipped_probs , using_pow2_scaling = True )
511+ o2_fp8 , o2_scale = paddle .incubate .nn .functional .fused_weighted_swiglu_act_quant (
512+ o1 , unzipped_probs , using_pow2_scaling = True
513+ )
513514 o2_scale = paddle .transpose (paddle .transpose (o2_scale , [1 , 0 ]).contiguous (), [1 , 0 ])
514515 self .unzipped_probs = unzipped_probs
515516
@@ -527,7 +528,9 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1):
527528 [m_sum, n] = [m_sum, k] * [num_groups, k, n]
528529 """
529530 # recompute concated_w2_2d
530- bw_w2_quant , bw_w2_scale = FQO .fused_stack_quant (expert_w2 )
531+ bw_w2_quant , bw_w2_scale = paddle .incubate .nn .functional .fused_stack_transpose_quant (
532+ expert_w2 , transpose = False
533+ )
531534 bw_w2_quant = bw_w2_quant .reshape ([len (expert_w2 ), - 1 , bw_w2_quant .shape [- 1 ]])
532535 bw_w2_scale = bw_w2_scale .reshape ([len (expert_w2 ), - 1 , bw_w2_scale .shape [- 1 ]])
533536
@@ -541,7 +544,7 @@ def bwd_dowm_input(self, expert_w2, unzipped_grad, o1):
541544 (unzipped_grad_fp8 , unzipped_grad_scale ), (bw_w2_quant , bw_w2_scale ), do2_s , m_indices = self .m_indices
542545 )
543546
544- do1 , probs_grad , o2_s = FQO . fused_swiglu_probs_bwd (o1 , do2_s , self .unzipped_probs )
547+ do1 , probs_grad , o2_s = paddle . incubate . nn . functional . fused_swiglu_weighted_bwd (o1 , do2_s , self .unzipped_probs )
545548
546549 return do1 , o2_s , probs_grad
547550
@@ -555,7 +558,9 @@ def bwd_gate_up_input(self, do1, expert_w1):
555558 [m_sum, k] = [m_sum, n] * [num_groups, n, k]
556559 """
557560 # recompute concated_w1_t
558- bw_w1_quant , bw_w1_scale = FQO .fused_stack_quant (expert_w1 )
561+ bw_w1_quant , bw_w1_scale = paddle .incubate .nn .functional .fused_stack_transpose_quant (
562+ expert_w1 , transpose = False
563+ )
559564 bw_w1_quant = bw_w1_quant .reshape ([len (expert_w1 ), - 1 , bw_w1_quant .shape [- 1 ]])
560565 bw_w1_scale = bw_w1_scale .reshape ([len (expert_w1 ), - 1 , bw_w1_scale .shape [- 1 ]])
561566
@@ -573,11 +578,7 @@ def bwd_gate_up_input(self, do1, expert_w1):
573578 return dx
574579
575580 def fused_transpose_split_quant (self , x , tokens_per_expert , pow_2_scales ):
576- out , scale = [], []
577- for tokens in tokens_per_expert :
578- out .append (paddle .empty ([x .shape [1 ], tokens ], dtype = "float8_e4m3fn" ))
579- scale .append (paddle .empty ([tokens // 128 , x .shape [1 ]], dtype = "float32" ))
580- FQO .fused_transpose_split_quant (x , out , scale , pow_2_scales )
581+ out , scale = paddle .incubate .nn .functional .fused_transpose_split_quant (x , tokens_per_expert , pow_2_scales )
581582 return out , scale
582583
583584 def bwd_down_weight (self , do3 , o2 , expert_w2 ):
@@ -681,7 +682,7 @@ def backward(self, out_grad):
681682 expert_w2 = [x .w2 for x in self .custom_map .experts if x is not None ]
682683
683684 if self .mem_efficient :
684- input = FQO .fused_act_dequant (self .input_fp8 , self .input_scale )
685+ input = paddle . incubate . nn . functional .fused_act_dequant (self .input_fp8 , self .input_scale )
685686 else :
686687 input = self .input
687688
0 commit comments