@@ -591,6 +591,7 @@ std::vector<at::Tensor> moe_fused_gate(at::Tensor& input,
591591 {
592592 case 256 :
593593 if (num_expert_group == 8 )
594+ {
594595 // This is deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4,
595596 // ROWS_PER_CTA = 6 * 4 = 24.
596597 if (input.scalar_type () == at::kBFloat16 )
@@ -605,23 +606,27 @@ std::vector<at::Tensor> moe_fused_gate(at::Tensor& input,
605606 {
606607 LAUNCH_MOE_GATE_CONFIG (float32_t , 256 , 8 );
607608 }
608- else if (num_expert_group == 16 )
609- // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
610- if (input.scalar_type () == at::kBFloat16 )
611- {
612- LAUNCH_MOE_GATE_CONFIG (bfloat16_t , 256 , 16 );
613- }
614- else if (input.scalar_type () == at::kHalf )
615- {
616- LAUNCH_MOE_GATE_CONFIG (float16_t , 256 , 16 );
617- }
618- else if (input.scalar_type () == at::kFloat )
619- {
620- LAUNCH_MOE_GATE_CONFIG (float32_t , 256 , 16 );
621- }
609+ }
610+ else if (num_expert_group == 16 )
611+ {
612+ // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
613+ if (input.scalar_type () == at::kBFloat16 )
614+ {
615+ LAUNCH_MOE_GATE_CONFIG (bfloat16_t , 256 , 16 );
616+ }
617+ else if (input.scalar_type () == at::kHalf )
618+ {
619+ LAUNCH_MOE_GATE_CONFIG (float16_t , 256 , 16 );
620+ }
621+ else if (input.scalar_type () == at::kFloat )
622+ {
623+ LAUNCH_MOE_GATE_CONFIG (float32_t , 256 , 16 );
624+ }
625+ }
622626 break ;
623627 case 128 :
624628 if (num_expert_group == 4 )
629+ {
625630 // VPT = 128/4 = 32, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
626631 if (input.scalar_type () == at::kBFloat16 )
627632 {
@@ -635,20 +640,23 @@ std::vector<at::Tensor> moe_fused_gate(at::Tensor& input,
635640 {
636641 LAUNCH_MOE_GATE_CONFIG (float32_t , 128 , 4 );
637642 }
638- else if (num_expert_group == 8 )
639- // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
640- if (input.scalar_type () == at::kBFloat16 )
641- {
642- LAUNCH_MOE_GATE_CONFIG (bfloat16_t , 128 , 8 );
643- }
644- else if (input.scalar_type () == at::kHalf )
645- {
646- LAUNCH_MOE_GATE_CONFIG (float16_t , 128 , 8 );
647- }
648- else if (input.scalar_type () == at::kFloat )
649- {
650- LAUNCH_MOE_GATE_CONFIG (float32_t , 128 , 8 );
651- }
643+ }
644+ else if (num_expert_group == 8 )
645+ {
646+ // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
647+ if (input.scalar_type () == at::kBFloat16 )
648+ {
649+ LAUNCH_MOE_GATE_CONFIG (bfloat16_t , 128 , 8 );
650+ }
651+ else if (input.scalar_type () == at::kHalf )
652+ {
653+ LAUNCH_MOE_GATE_CONFIG (float16_t , 128 , 8 );
654+ }
655+ else if (input.scalar_type () == at::kFloat )
656+ {
657+ LAUNCH_MOE_GATE_CONFIG (float32_t , 128 , 8 );
658+ }
659+ }
652660 break ;
653661 default : break ;
654662 }
0 commit comments