
[BUG] [CuTe DSL] Cold L2 leads to IMA in grouped_blockscaled_gemm #2737

@reubenconducts

Description


Which component has the problem?

CuTe DSL

Bug Report

Describe the bug
When running grouped_blockscaled_gemm.py with use_cold_l2=True, I hit an illegal memory access (IMA) whenever warmup_iterations + iterations > 10. This happens in many configurations, including but not limited to:

  • MXF4 and NVF4 with all test cases
  • MXFP8 with 4 or more groups

None of these IMAs occurs without cold L2, or with 10 or fewer total iterations.

Steps/Code to reproduce bug

import cutlass
from grouped_blockscaled_gemm import run

run(
    num_groups=4,
    problem_sizes_mnkl=[(256, 256, 256, 1), (512, 512, 512, 1), (1024, 1024, 1024, 1), (2048, 2048, 2048, 1)],
    ab_dtype=cutlass.Float8E4M3FN,
    sf_dtype=cutlass.Float8E8M0FNU,
    sf_vec_size=32,
    c_dtype=cutlass.Float32,
    a_major="m",
    b_major="n",
    c_major="m",
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    warmup_iterations=5,
    iterations=10,  # any split with warmup_iterations + iterations > 10 triggers the IMA
    skip_ref_check=True,
    use_cold_l2=True,
)
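An illegal memory access leaves the CUDA context unusable for the rest of the process, so confirming the 10-iteration threshold is easier when each case runs in a fresh interpreter. The sketch below is a hypothetical helper, not part of the CUTLASS examples. It assumes grouped_blockscaled_gemm.py is importable from the working directory, reuses the reproduction call unchanged except for `warmup_iterations`, `iterations`, and `use_cold_l2`, and treats any nonzero exit code as a failure.

```python
import subprocess
import sys
import textwrap

# Hypothetical template: the reproduction call from this issue, parameterized
# on the iteration split and the cold-L2 flag. Assumes grouped_blockscaled_gemm.py
# is importable from the current working directory.
REPRO_TEMPLATE = textwrap.dedent(
    """
    import cutlass
    from grouped_blockscaled_gemm import run

    run(
        num_groups=4,
        problem_sizes_mnkl=[(256, 256, 256, 1), (512, 512, 512, 1),
                            (1024, 1024, 1024, 1), (2048, 2048, 2048, 1)],
        ab_dtype=cutlass.Float8E4M3FN,
        sf_dtype=cutlass.Float8E8M0FNU,
        sf_vec_size=32,
        c_dtype=cutlass.Float32,
        a_major="m",
        b_major="n",
        c_major="m",
        mma_tiler_mn=(128, 128),
        cluster_shape_mn=(1, 1),
        warmup_iterations={warmup},
        iterations={iters},
        skip_ref_check=True,
        use_cold_l2={cold},
    )
    """
)


def build_repro(warmup: int, iters: int, cold: bool) -> str:
    """Return a standalone script for one (warmup, iterations, cold-L2) case."""
    return REPRO_TEMPLATE.format(warmup=warmup, iters=iters, cold=cold)


def run_isolated(script: str, timeout: float = 600.0) -> int:
    """Run a script in a fresh Python interpreter and return its exit code.

    Each case gets its own process, so a fault that poisons the CUDA
    context in one case cannot leak into the next one.
    """
    proc = subprocess.run(
        [sys.executable, "-c", script],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode


def sweep(splits=((2, 8), (5, 5), (5, 6), (5, 10))) -> dict:
    """Run each (warmup, iterations) split with and without cold L2.

    The default splits total 10, 10, 11, and 15 iterations, straddling
    the threshold reported in this issue.
    """
    results = {}
    for warmup, iters in splits:
        for cold in (False, True):
            code = run_isolated(build_repro(warmup, iters, cold))
            results[(warmup, iters, cold)] = code
            status = "ok" if code == 0 else f"FAILED (exit {code})"
            print(f"warmup={warmup} iters={iters} cold_l2={cold}: {status}")
    return results
```

Calling `sweep()` on the affected machine prints one line per case; the captured stderr from `subprocess.run` can be printed as well if the traceback is needed.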

Environment details

  • GPU: NVIDIA B200
