1919 mega_fp4_pack ,
2020 mega_fp4_quantize_kernel ,
2121 mega_fp4_unpack ,
22+ nvfp4_quantize_stacked ,
2223 triton_quantize_mx4_unpack ,
2324 triton_quantize_nvfp4 ,
2425)
@@ -1237,11 +1238,6 @@ class FP8RowwiseGrouped(GemmOpBase):
12371238 FP8 grouped matmul with rowwise scaling.
12381239 """
12391240
1240- @property
1241- def name (self ) -> str :
1242- prefix = "Cutlass" if torch .version .cuda else "CK"
1243- return f"{ prefix } { self .__class__ .__name__ } "
1244-
12451241 def preprocess (self , x , w ):
12461242 m_values = [i .shape [0 ] for i in x ]
12471243 m_sizes = torch .tensor (m_values ).to (dtype = torch .int64 , device = x [0 ].device )
@@ -2429,50 +2425,28 @@ class CutlassNVFP4GroupwiseGrouped(GemmOpBase):
24292425 def preprocess (self , x , w ):
24302426 m_values = [i .shape [0 ] for i in x ]
24312427 m_sizes = torch .tensor (m_values ).to (dtype = torch .int64 , device = x [0 ].device )
2432- x = torch .concat (x , dim = 0 ).contiguous ()
2428+ x_cat = torch .concat (x , dim = 0 ).contiguous ()
24332429
2434- def get_global_scale (x , w , m_sizes ):
2435- G = len (w )
2436- w_global_scale = []
2437- global_scale = []
2430+ G = m_sizes .numel ()
24382431
2439- cumulative_sum = torch .zeros (
2440- m_sizes .shape [0 ] + 1 , dtype = torch .int64 , device = m_sizes .device
2432+ # w_global_scale is static (weights don't change)
2433+ w_global_scale = []
2434+ for i in range (G ):
2435+ w_gs = (448.0 * 6.0 ) / torch .amax (torch .abs (w [i ].flatten ()), dim = - 1 ).to (
2436+ torch .float32
24412437 )
2442- cumulative_sum [1 :] = torch .cumsum (m_sizes , dim = 0 )
2443-
2444- x_global_scale , tensor_idx = calculate_group_max (x , m_sizes = m_sizes )
2445-
2446- for i in range (G ):
2447- w_global_scale_ = (448.0 * 6.0 ) / torch .amax (
2448- torch .abs (w [i ].flatten ()), dim = - 1
2449- ).to (torch .float32 )
2450-
2451- global_scale_ = 1 / (x_global_scale [i ] * w_global_scale_ )
2452-
2453- w_global_scale .append (w_global_scale_ )
2454- global_scale .append (global_scale_ )
2455-
2456- return x_global_scale , w_global_scale , global_scale , tensor_idx
2457-
2458- # Compute global scale for each group
2459- G = m_sizes .numel ()
2460- x_global_scale , w_global_scale , global_scale , tensor_idx = get_global_scale (
2461- x , w , m_sizes
2462- )
2463- global_scale = torch .stack (global_scale , dim = 0 ).contiguous ()
2438+ w_global_scale .append (w_gs )
2439+ w_global_scale = torch .stack (w_global_scale , dim = 0 ).contiguous ()
24642440
24652441 wq , w_scale = zip (
24662442 * [triton_quantize_nvfp4 (w [i ], w_global_scale [i ]) for i in range (G )]
24672443 )
24682444 wq = torch .stack (wq , dim = 0 ).contiguous ()
24692445 w_scale = torch .stack (w_scale , dim = 0 ).contiguous ()
24702446
2471- return x , wq , w_scale , x_global_scale , global_scale , m_sizes , tensor_idx
2447+ return x_cat , wq , w_scale , w_global_scale , m_sizes
24722448
2473- def quantize (
2474- self , x , wq , w_scale , x_global_scale , global_scale , m_sizes , tensor_idx
2475- ):
2449+ def quantize (self , x , wq , w_scale , w_global_scale , m_sizes ):
24762450 # alternative methods, may be useful in some scenarios
24772451 """
24782452 starting_row_after_padding, belong_indices, row_within_tensor = (
@@ -2489,6 +2463,10 @@ def quantize(
24892463 )
24902464 """
24912465
2466+ x_global_scale , tensor_idx = calculate_group_max (x , m_sizes = m_sizes )
2467+
2468+ global_scale = 1.0 / (x_global_scale * w_global_scale )
2469+
24922470 # we can optionally set optional_tensor_idx to None to run the alternative method
24932471 xq , x_scale , starting_row_after_padding = mega_fp4_quantize_kernel (
24942472 m_sizes , x , x_global_scale , optional_tensor_idx = tensor_idx
@@ -2527,9 +2505,7 @@ def compute(
25272505 )
25282506 return gemm_result
25292507
2530- def quantize_and_compute (
2531- self , x , wq , w_scale , x_global_scale , global_scale , m_sizes , tensor_idx
2532- ):
2508+ def quantize_and_compute (self , x , wq , w_scale , w_global_scale , m_sizes ):
25332509 (
25342510 xq ,
25352511 wq ,
@@ -2538,9 +2514,7 @@ def quantize_and_compute(
25382514 m_sizes ,
25392515 global_scale ,
25402516 starting_row_after_padding ,
2541- ) = self .quantize (
2542- x , wq , w_scale , x_global_scale , global_scale , m_sizes , tensor_idx
2543- )
2517+ ) = self .quantize (x , wq , w_scale , w_global_scale , m_sizes )
25442518 return self .compute (
25452519 xq ,
25462520 wq ,
@@ -2564,6 +2538,105 @@ def compute_dtype(self) -> ComputeDtype:
25642538 return ComputeDtype .FP4
25652539
25662540
@register_gemm_op
class CutlassNVFP4TorchGrouped(GemmOpBase):
    """
    NVFP4 grouped matmul driven through the torch offsets API.

    Activations are quantized with per-expert global scales via the
    stacked NVFP4 kernel (stacked_nvfp4_quantize); the combined per-expert
    alpha scales are applied after the GEMM.
    """

    def preprocess(self, x, w):
        # Per-group row counts and one flat activation tensor.
        row_counts = [t.shape[0] for t in x]
        m_sizes = torch.tensor(row_counts).to(dtype=torch.int64, device=x[0].device)
        x_cat = torch.concat(x, dim=0).contiguous()

        G = m_sizes.numel()
        # NOTE(review): assumes every expert weight shares w[0]'s shape.
        N_per_expert, K = w[0].shape[0], w[0].shape[1]

        # Quantize every expert weight in a single stacked-kernel call.
        w_cat = torch.cat(w, dim=0).contiguous()  # [G*N, K]
        w_m_sizes = torch.full(
            (G,), N_per_expert, dtype=torch.int64, device=w_cat.device
        )
        w_global_scale, _ = calculate_group_max(w_cat, w_m_sizes)
        wq, w_scale_2d = nvfp4_quantize_stacked(w_m_sizes, w_cat, w_global_scale)

        # Reshape to the [G, N, ...] layout the grouped GEMM expects.
        wq = wq.view(G, N_per_expert, K // 2)
        padded_N = (N_per_expert + 127) // 128 * 128
        w_scale = w_scale_2d[: G * padded_N].view(G, padded_N, -1)

        # The torch offsets API takes cumulative end indices as int32.
        offsets = torch.cumsum(m_sizes, dim=0).to(torch.int32)

        return x_cat, wq, w_scale, w_global_scale, m_sizes, offsets

    def quantize(self, x, wq, w_scale, w_global_scale, m_sizes, offsets):
        # Per-expert activation global scales; combined alpha is
        # 1 / (x_global_scale * w_global_scale) per expert.
        x_global_scale, _ = calculate_group_max(x, m_sizes=m_sizes)
        global_scale = 1.0 / (x_global_scale * w_global_scale)

        xq, x_scale = nvfp4_quantize_stacked(m_sizes, x, x_global_scale)
        return xq, wq, x_scale, w_scale, global_scale, offsets

    def compute(self, xq, wq, x_scale, w_scale, global_scale, offsets):
        # Weights are stored [G, N, K/2]; the kernel wants them transposed.
        return torch.ops.mslk.f4f4bf16_grouped_mm(
            xq,
            wq.transpose(-2, -1),
            x_scale,
            w_scale,
            offsets,
            global_scale=global_scale,
        )

    def quantize_and_compute(self, x, wq, w_scale, w_global_scale, m_sizes, offsets):
        gemm_args = self.quantize(x, wq, w_scale, w_global_scale, m_sizes, offsets)
        return self.compute(*gemm_args)

    @property
    def supported_accelerators(self) -> set[Accelerator]:
        return {Accelerator.NVIDIA_SM100, Accelerator.NVIDIA_SM103}

    @property
    def supported_gemm_types(self) -> set[GemmType]:
        return {GemmType.GROUPED}

    @property
    def compute_dtype(self) -> ComputeDtype:
        return ComputeDtype.FP4
2639+
25672640# Broken with cuda graph
25682641# @register_gemm_op
25692642class CutlassNVFP4GroupwiseStackedGroupedPackUnpack (GemmOpBase ):
@@ -2761,7 +2834,7 @@ def compute(self, x, w, offs):
27612834 )
27622835
27632836 def quantize_and_compute (self , x , w , offs ):
2764- x , w , offs = self .quantize (x , w )
2837+ x , w , offs = self .quantize (x , w , offs )
27652838 return self .compute (x , w , offs )
27662839
27672840 @property
0 commit comments