Description
We are hitting a runtime failure when using the newly added CUDA kernel mx_block_rearrange_2d_M_groups_cuda introduced in PR #3546.
Our workload requires a specific per-group size (44), which appears to be unsupported by the kernel. The failure manifests as a CUDA driver error at runtime.
Input configuration:
scales_tensor: shape = torch.Size([98048, 44]), dtype = torch.float8_e8m0fnu
input_group_end_offsets: shape = torch.Size([8]), dtype = torch.int32
chunk_width = 64
chunks_per_tb = 8
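To help with triage, here is a small sketch of the arithmetic for the configuration above. It only relates the 44-column width to chunk_width and to a 16-byte row-stride alignment. The 16-byte figure is an assumption on our side: it is the global-stride alignment that cuTensorMapEncodeTiled requires, and we have not confirmed that the failing call at line 431 is a tensor-map encode. The helper name describe_scales_layout is hypothetical and not part of torchao.

```python
# Values copied from the reported configuration.
ROWS, COLS = 98048, 44
CHUNK_WIDTH = 64
BYTES_PER_SCALE = 1  # float8_e8m0fnu stores one byte per element

# ASSUMPTION: row strides must be a multiple of 16 bytes (the requirement
# for TMA tensor maps). Not confirmed as the cause of this failure.
ASSUMED_STRIDE_ALIGN_BYTES = 16


def describe_scales_layout(cols, chunk_width, elem_bytes, align_bytes):
    """Hypothetical helper: summarize alignment facts for a contiguous 2D scales tensor."""
    row_stride_bytes = cols * elem_bytes
    align_elems = align_bytes // elem_bytes
    # Round cols up to the next multiple of align_elems (ceiling division).
    next_aligned_cols = -(-cols // align_elems) * align_elems
    return {
        "row_stride_bytes": row_stride_bytes,
        "stride_is_aligned": row_stride_bytes % align_bytes == 0,
        "cols_fill_whole_chunks": cols % chunk_width == 0,
        "next_aligned_cols": next_aligned_cols,
    }


report = describe_scales_layout(
    COLS, CHUNK_WIDTH, BYTES_PER_SCALE, ASSUMED_STRIDE_ALIGN_BYTES
)
for key, value in report.items():
    print(f"{key}: {value}")
```

Under that assumption, the 44-byte row stride is not 16-byte aligned, and 44 columns do not fill a whole 64-wide chunk. Whether either property is what the kernel actually rejects still needs to be confirmed against the code at line 431.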
Error:
CUDA workload error: CUDA Driver Error at /workspace/ao/torchao/csrc/cuda/mx_kernels/mx_block_rearrange_2d_M_groups.cu:431
invalid argument
cc @danielvegamyhre @drisspg