torch.prims.convert_element_type to linalg lowering fails for bf16 to f16 #3962
Description
'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible

This error comes from the llama3_8b_fp8 model. arith.extf requires the result type to be strictly wider than the operand type, but bf16 and f16 are both 16-bit types, so the generated op fails verification. Small reproducer input IR, convert.torch.mlir:
func.func @convert(%652: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
%int5 = torch.constant.int 5
%0 = torch.prims.convert_element_type %652, %int5 : !torch.vtensor<[1,?,32,128],bf16>, !torch.int -> !torch.vtensor<[1,?,32,128],f16>
return %0 : !torch.vtensor<[1,?,32,128],f16>
}
torch-mlir-opt --torch-decompose-complex-ops --cse --canonicalize convert.torch.mlir > todtype.torch.mlir
module {
func.func @convert(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
%int5 = torch.constant.int 5
%false = torch.constant.bool false
%none = torch.constant.none
%0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
return %0 : !torch.vtensor<[1,?,32,128],f16>
}
}
torch-mlir-opt --convert-torch-to-linalg todtype.torch.mlir then fails with:
'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (!torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16>, sym_name = "convert"}> ({
^bb0(%arg0: !torch.vtensor<[1,?,32,128],bf16>):
%0 = "builtin.unrealized_conversion_cast"(%arg0) : (!torch.vtensor<[1,?,32,128],bf16>) -> tensor<1x?x32x128xbf16>
%1 = "torch.constant.int"() <{value = 5 : i64}> : () -> !torch.int
%2 = "builtin.unrealized_conversion_cast"(%1) : (!torch.int) -> i64
%3 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
%4 = "builtin.unrealized_conversion_cast"(%3) : (!torch.bool) -> i1
%5 = "torch.constant.none"() : () -> !torch.none
%6 = "arith.constant"() <{value = 1 : index}> : () -> index
%7 = "arith.constant"() <{value = 1 : index}> : () -> index
%8 = "tensor.dim"(%0, %7) : (tensor<1x?x32x128xbf16>, index) -> index
%9 = "arith.constant"() <{value = 2 : index}> : () -> index
%10 = "arith.constant"() <{value = 32 : index}> : () -> index
%11 = "arith.constant"() <{value = 3 : index}> : () -> index
%12 = "arith.constant"() <{value = 128 : index}> : () -> index
%13 = "tensor.empty"(%8) : (index) -> tensor<1x?x32x128xf16>
%14 = "linalg.generic"(%0, %13) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
^bb0(%arg1: bf16, %arg2: f16):
%17 = "arith.extf"(%arg1) : (bf16) -> f16
"linalg.yield"(%17) : (f16) -> ()
}) : (tensor<1x?x32x128xbf16>, tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
%15 = "tensor.cast"(%14) : (tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
%16 = "torch.aten.to.dtype"(%arg0, %1, %3, %3, %5) : (!torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none) -> !torch.vtensor<[1,?,32,128],f16>
"func.return"(%16) : (!torch.vtensor<[1,?,32,128],f16>) -> ()
}) : () -> ()
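
For reference, a minimal sketch of what a valid element-wise bf16 -> f16 lowering could look like, assuming the cast is routed through f32 (arith.extf to f32, then arith.truncf to f16). The function name and the static shapes below are placeholders for illustration, not the actual torch-mlir lowering:

func.func @convert_via_f32(%arg0: tensor<1x4x32x128xbf16>) -> tensor<1x4x32x128xf16> {
  %init = tensor.empty() : tensor<1x4x32x128xf16>
  %0 = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
                       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<1x4x32x128xbf16>)
      outs(%init : tensor<1x4x32x128xf16>) {
  ^bb0(%in: bf16, %out: f16):
    // bf16 and f16 have the same bitwidth, so neither arith.extf nor
    // arith.truncf can convert between them directly; go through f32.
    %wide = arith.extf %in : bf16 to f32
    %narrow = arith.truncf %wide : f32 to f16
    linalg.yield %narrow : f16
  } -> tensor<1x4x32x128xf16>
  return %0 : tensor<1x4x32x128xf16>
}

Since arith.extf only accepts a strictly wider result and arith.truncf only a strictly narrower one, some intermediate step like the f32 round-trip above (or a dedicated bf16/f16 conversion) is needed; the current lowering emits arith.extf bf16 -> f16 directly, which is what trips the verifier.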