@@ -51,6 +51,8 @@ def run_norm_test(args):
5151 return testRmsnormFp4quant (args )
5252 elif args .routine == "add_rmsnorm_fp4quant" :
5353 return testAddRmsnormFp4quant (args )
54+ elif args .routine == "fused_rmsnorm_silu" :
55+ return testFusedRmsnormSilu (args )
5456 else :
5557 raise ValueError (f"Unsupported routine: { args .routine } " )
5658
@@ -1078,3 +1080,122 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
10781080 cur_res ["case_tag" ] = args .case_tag
10791081 res .append (cur_res )
10801082 return res
1083+
1084+
def testFusedRmsnormSilu(args):
    """
    Test the fused_rmsnorm_silu API (RMSNorm followed by SiLU activation).

    This test:
    1. Generates random bf16 input tensors
    2. Runs fused_rmsnorm_silu with bf16 output
    3. Optionally runs a float32 reference check
    4. Measures performance metrics (latency, TFLOPS, memory bandwidth)

    Args:
        args: Parsed command line arguments containing test configuration.

    Returns:
        list: List of dictionaries containing performance results
            (populated only when ``args.output_path`` is set).

    Raises:
        ValueError: If ``args.input_dtype`` is not bfloat16.
        AssertionError: If the reference check fails and
            ``args.allow_output_mismatch`` is not set.
    """
    if args.verbose >= 1:
        print("[INFO] Running testFusedRmsnormSilu")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    batch_size = args.batch_size
    hidden_size = args.hidden_size
    eps = args.eps
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    res = []

    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
    if input_dtype != torch.bfloat16:
        raise ValueError(
            f"fused_rmsnorm_silu requires bfloat16 input, got {args.input_dtype}"
        )

    input_shape = (batch_size, hidden_size)
    input_tensor = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
    # Weights drawn from [0.5, 2.0) so normalized outputs stay well-scaled.
    weight = torch.rand(hidden_size, dtype=torch.bfloat16, device=device) * 1.5 + 0.5
    out = torch.empty(input_shape, dtype=torch.bfloat16, device=device)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {input_tensor.shape = }")
        print(f"[VVERBOSE] {input_tensor.dtype = }")
        print(f"[VVERBOSE] {weight.shape = }")

    def run_fn(input_tensor, weight, out):
        return flashinfer.fused_rmsnorm_silu(input_tensor, weight, eps=eps, out=out)

    if run_refcheck:
        # Reference in float32 for accuracy: silu(x / rms(x) * w), cast to bf16.
        rms = torch.sqrt(
            torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps
        )
        x_norm = input_tensor.float() / rms * weight.float()
        reference_output = torch.nn.functional.silu(x_norm).to(torch.bfloat16)

        test_out = run_fn(input_tensor, weight, out)
        (
            num_different_elements,
            num_elements,
            num_different_elements_percentage,
        ) = is_close_stats(reference_output, test_out, rtol=2e-2, atol=2e-2)
        if num_different_elements > 0:
            print(
                f"[ERROR] Output tensor mismatch: "
                f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ"
            )
            if not args.allow_output_mismatch:
                raise AssertionError(
                    f"[ERROR] Output mismatch with {num_different_elements} elements"
                )

    times = bench_gpu_time(
        fn=run_fn,
        dry_run_iters=args.dry_run_iters,
        repeat_iters=args.num_iters,
        enable_cupti=args.use_cupti,
        use_cuda_graph=is_cuda_graph_compatible,
        input_args=(input_tensor, weight, out),
    )

    if len(times) > 0:
        median_time = np.median(times)
        std_time = np.std(times)

        num_elements = np.prod(input_shape)
        # Memory-traffic model: read input and weight once, write output once.
        problem_bytes = (
            num_elements * input_dtype.itemsize  # input read
            + hidden_size * input_dtype.itemsize  # weight read
            + num_elements * input_dtype.itemsize  # output write
        )
        problem_flops = num_elements * 7  # rmsnorm (5) + silu (2: exp + div)
        # NOTE(review): scaling assumes bench_gpu_time reports milliseconds,
        # so flops / (1e9 * ms) yields TFLOP/s — confirm against bench helper.
        tflops = problem_flops / (10**9 * median_time)
        tb_per_sec = problem_bytes / (10**9 * median_time)

        print_perf_metrics("cuda", median_time, std_time, tflops, tb_per_sec)

        if args.output_path is not None:
            cur_res = defaultdict(str)
            cur_res["routine"] = args.routine
            cur_res["median_time"] = median_time
            cur_res["std_time"] = std_time
            cur_res["tflops"] = tflops
            cur_res["tb_per_sec"] = tb_per_sec
            cur_res["input_dtype"] = str(input_dtype)
            cur_res["eps"] = eps
            cur_res["backend"] = "cuda"
            cur_res["case_tag"] = args.case_tag
            res.append(cur_res)
    return res
0 commit comments