Which component has the problem?
CuTe DSL
Bug Report
Describe the bug
When running grouped_blockscaled_gemm.py with use_cold_l2=True, many configurations hit an illegal memory access (IMA) once warmup_iterations + iterations exceeds 10, including but not limited to:
- MXFP4 and NVFP4: all test cases
- MXFP8: 4 or more groups
None of these IMAs occur with use_cold_l2=False, or when warmup_iterations + iterations is 10 or fewer.
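For context on what use_cold_l2 changes: cold-L2 benchmarking is commonly implemented by allocating several copies of a kernel's buffers (workspaces) and rotating through them, so each launch reads data that is not already resident in L2. The plain-Python sketch below illustrates that idea under stated assumptions. The helper names `workspace_count` and `run_cold_l2` are hypothetical, and it is an assumption, not confirmed from the script, that grouped_blockscaled_gemm.py works exactly this way. It shows why the total iteration count is a relevant variable: it caps how many workspaces are created and determines which ones each launch touches.

```python
from typing import Callable, List


def workspace_count(one_workspace_bytes: int, l2_bytes: int, total_iterations: int) -> int:
    """Hypothetical helper: how many workspace copies to allocate.

    Enough copies that their combined size exceeds the L2 capacity, but never
    more copies than there are iterations to use them.
    """
    if one_workspace_bytes <= 0:
        raise ValueError("workspace size must be positive")
    # (l2 // ws + 1) * ws > l2, so this many copies cannot all fit in L2 at once.
    needed = l2_bytes // one_workspace_bytes + 1
    return max(1, min(needed, total_iterations))


def run_cold_l2(
    kernel: Callable[[int], None],
    num_workspaces: int,
    warmup_iterations: int,
    iterations: int,
) -> List[int]:
    """Hypothetical benchmark loop: launch `kernel` once per iteration,
    rotating through workspace indices. Returns the indices used, in order."""
    if num_workspaces < 1:
        raise ValueError("need at least one workspace")
    used: List[int] = []
    for i in range(warmup_iterations + iterations):
        # Wrap around so the index always stays in [0, num_workspaces).
        ws = i % num_workspaces
        kernel(ws)
        used.append(ws)
    return used


if __name__ == "__main__":
    # Illustrative sizes only; 126 MiB is not a measured B200 L2 capacity.
    n = workspace_count(one_workspace_bytes=64 << 20, l2_bytes=126 << 20, total_iterations=15)
    print(n, run_cold_l2(lambda ws: None, n, warmup_iterations=5, iterations=10))
```

In a real harness the kernel callback would launch the compiled GEMM on the tensors belonging to workspace `ws`; here it is a no-op so the rotation logic can be checked on its own.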
Steps/Code to reproduce bug
```python
import cutlass  # provides the Float8E4M3FN / Float8E8M0FNU / Float32 dtypes

# Run from the directory containing grouped_blockscaled_gemm.py;
# import the module by name, not by its .py filename.
from grouped_blockscaled_gemm import run

run(
    num_groups=4,
    problem_sizes_mnkl=[(256, 256, 256, 1), (512, 512, 512, 1), (1024, 1024, 1024, 1), (2048, 2048, 2048, 1)],
    ab_dtype=cutlass.Float8E4M3FN,
    sf_dtype=cutlass.Float8E8M0FNU,
    sf_vec_size=32,
    c_dtype=cutlass.Float32,
    a_major="m",
    b_major="n",
    c_major="m",
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    warmup_iterations=5,
    iterations=10,  # any warmup_iterations + iterations > 10 triggers the IMA
    skip_ref_check=True,
    use_cold_l2=True,
)
```

Environment details
- B200