
torch.aten.mm fails to legalize with mixed precision operands #4422

@sjain-stanford

Description


(thanks to Claude Sonnet 4.5 for summarizing the issue)

Title

torch.aten.mm fails to legalize with mixed precision operands (f32 + bf16), while torch.aten.matmul (batched) works correctly

Description

There is an inconsistency in type promotion handling between batched and non-batched matrix multiplication in torch-mlir. The Torch-to-Linalg lowering correctly handles mixed precision operands for torch.aten.matmul in the batched case by automatically inserting type extension operations, but it fails to legalize torch.aten.mm (the op a 2-D torch.aten.matmul is decomposed into, as seen in the error output below) under the same mixed precision scenario.

Expected Behavior

Both torch.aten.matmul and torch.aten.mm should support automatic type promotion when operands have different precisions (e.g., f32 and bf16), inserting appropriate arith.extf operations to cast operands to the larger type before performing the operation.
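For reference, a hypothetical sketch of what the legalized IR for the 2-D case could look like, by analogy with the batched lowering shown below (the value names %a, %b, %init, and %out are illustrative, not taken from actual compiler output):

// Promote the bf16 operand to f32 elementwise, mirroring the batched case.
%ext = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%b : tensor<64x32xbf16>) outs(%init : tensor<64x32xf32>) {
^bb0(%in: bf16, %out_elem: f32):
  %e = arith.extf %in : bf16 to f32
  linalg.yield %e : f32
} -> tensor<64x32xf32>

// With both operands at f32, the plain matmul would legalize cleanly.
%mm = linalg.matmul ins(%a, %ext : tensor<16x64xf32>, tensor<64x32xf32>) outs(%out : tensor<16x32xf32>) -> tensor<16x32xf32>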

Actual Behavior

✅ Works: Batched matmul with mixed precision

The following operation successfully legalizes:

%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[10,64,16],f32>, !torch.vtensor<[10,16,32],bf16> -> !torch.vtensor<[10,64,32],f32>

Lowering to Linalg correctly inserts type promotion:

%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%transposed_0 : tensor<10x16x32xbf16>) outs(%5 : tensor<10x16x32xf32>) {
^bb0(%in: bf16, %out: f32):
  %13 = arith.extf %in : bf16 to f32
  linalg.yield %13 : f32
} -> tensor<10x16x32xf32>

...

%9 = linalg.batch_matmul ins(%transposed, %6 : tensor<10x64x16xf32>, tensor<10x16x32xf32>) outs(%8 : tensor<10x64x32xf32>) -> tensor<10x64x32xf32>

Note that linalg.batch_matmul receives operands of the same type (both f32).

❌ Fails: Regular matmul with mixed precision

The following operation fails to legalize:

%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16> -> !torch.vtensor<[16,32],f32>

Error:

/tmp/.cache/fusilli/benchmark_matmul_b1_m16_n32_k64_transAfalse_transBfalse_biasfalse_atypef32_btypebf16_outtypef32/iree-compile-input.mlir:17:22: error: failed to legalize operation 'torch.aten.mm' that was explicitly marked illegal
    %matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16> -> !torch.vtensor<[16,32],f32>
                     ^
/tmp/.cache/fusilli/benchmark_matmul_b1_m16_n32_k64_transAfalse_transBfalse_biasfalse_atypef32_btypebf16_outtypef32/iree-compile-input.mlir:17:22: note: see current operation: %22 = "torch.aten.mm"(%12, %21) : (!torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16>) -> !torch.vtensor<[16,32],f32>
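Until the mm lowering handles promotion itself, one possible workaround (a sketch, assuming the input can be modified at the torch-dialect level) is to cast the bf16 operand explicitly before the matmul with torch.aten.to.dtype; the constant 6 below is PyTorch's ScalarType code for float32:

%int6 = torch.constant.int 6
%false = torch.constant.bool false
%none = torch.constant.none
// Cast the bf16 operand to f32 explicitly (dtype 6 == torch.float32).
%matrix_b_f32 = torch.aten.to.dtype %matrix_b_perm, %int6, %false, %false, %none : !torch.vtensor<[64,32],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[64,32],f32>
// Both operands are now f32, so the decomposition to torch.aten.mm legalizes.
%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_f32 : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],f32> -> !torch.vtensor<[16,32],f32>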
