(thanks to Claude Sonnet 4.5 for summarizing the issue)
### Title

torch.aten.mm fails to legalize with mixed precision operands (f32 + bf16), while torch.aten.matmul (batched) works correctly
### Description

There is an inconsistency in type-promotion handling between batched and non-batched matrix multiplication in torch-mlir. The Torch-to-Linalg lowering correctly handles mixed-precision operands for torch.aten.matmul (batch matmul) by automatically inserting type-extension operations, but it fails to legalize torch.aten.mm (regular matmul) in the same mixed-precision scenario.
### Expected Behavior
Both torch.aten.matmul and torch.aten.mm should support automatic type promotion when operands have different precisions (e.g., f32 and bf16), inserting appropriate arith.extf operations to cast operands to the larger type before performing the operation.
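For illustration, an expected lowering of the failing 16x64 by 64x32 case, modeled directly on the linalg.generic + arith.extf pattern the batched path already emits (the value names here are hypothetical), might look like:

```mlir
// Sketch of the expected mixed-precision mm lowering: extend the bf16
// operand to f32 first, then call linalg.matmul with matching element
// types, mirroring what the batch_matmul path does.
%b_f32 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%b : tensor<64x32xbf16>) outs(%b_init : tensor<64x32xf32>) {
^bb0(%in: bf16, %out: f32):
  %e = arith.extf %in : bf16 to f32
  linalg.yield %e : f32
} -> tensor<64x32xf32>
%result = linalg.matmul ins(%a, %b_f32 : tensor<16x64xf32>, tensor<64x32xf32>) outs(%out_init : tensor<16x32xf32>) -> tensor<16x32xf32>
```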
### Actual Behavior

#### ✅ Works: Batched matmul with mixed precision
The following operation successfully legalizes:
```mlir
%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[10,64,16],f32>, !torch.vtensor<[10,16,32],bf16> -> !torch.vtensor<[10,64,32],f32>
```

The lowering to Linalg correctly inserts type promotion:
```mlir
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%transposed_0 : tensor<10x16x32xbf16>) outs(%5 : tensor<10x16x32xf32>) {
^bb0(%in: bf16, %out: f32):
  %13 = arith.extf %in : bf16 to f32
  linalg.yield %13 : f32
} -> tensor<10x16x32xf32>
...
%9 = linalg.batch_matmul ins(%transposed, %6 : tensor<10x64x16xf32>, tensor<10x16x32xf32>) outs(%8 : tensor<10x64x32xf32>) -> tensor<10x64x32xf32>
```

Note that linalg.batch_matmul receives operands of the same type (both f32).
#### ❌ Fails: Regular matmul with mixed precision
The following operation fails to legalize:
```mlir
%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16> -> !torch.vtensor<[16,32],f32>
```

Error:
```
/tmp/.cache/fusilli/benchmark_matmul_b1_m16_n32_k64_transAfalse_transBfalse_biasfalse_atypef32_btypebf16_outtypef32/iree-compile-input.mlir:17:22: error: failed to legalize operation 'torch.aten.mm' that was explicitly marked illegal
%matmul_C_perm = torch.aten.matmul %matrix_a_perm, %matrix_b_perm : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16> -> !torch.vtensor<[16,32],f32>
^
/tmp/.cache/fusilli/benchmark_matmul_b1_m16_n32_k64_transAfalse_transBfalse_biasfalse_atypef32_btypebf16_outtypef32/iree-compile-input.mlir:17:22: note: see current operation: %22 = "torch.aten.mm"(%12, %21) : (!torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16>) -> !torch.vtensor<[16,32],f32>
```
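A minimal standalone reproducer, sketched from the operation shown in the diagnostic (the function and value names are hypothetical), would be something like:

```mlir
// Hypothetical minimal reproducer: lowering this through Torch-to-Linalg
// should hit the same legalization failure, since the diagnostic above
// shows the 2-D torch.aten.matmul being decomposed into torch.aten.mm.
func.func @mm_mixed_precision(%a: !torch.vtensor<[16,64],f32>,
                              %b: !torch.vtensor<[64,32],bf16>)
    -> !torch.vtensor<[16,32],f32> {
  %0 = torch.aten.matmul %a, %b : !torch.vtensor<[16,64],f32>, !torch.vtensor<[64,32],bf16> -> !torch.vtensor<[16,32],f32>
  return %0 : !torch.vtensor<[16,32],f32>
}
```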