
torch.prims.convert_element_type to linalg bf16 to f16 fail #3962

@AmosLewis

Description

'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible

This error comes from the llama3_8b_fp8 model.
Small reproducer input IR convert.torch.mlir:

func.func @convert(%652: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
  %int5 = torch.constant.int 5
  %0 = torch.prims.convert_element_type %652, %int5 : !torch.vtensor<[1,?,32,128],bf16>, !torch.int -> !torch.vtensor<[1,?,32,128],f16>
  return %0 : !torch.vtensor<[1,?,32,128],f16>
}

torch-mlir-opt --torch-decompose-complex-ops --cse --canonicalize convert.torch.mlir > todtype.torch.mlir

module {
  func.func @convert(%arg0: !torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16> {
    %int5 = torch.constant.int 5
    %false = torch.constant.bool false
    %none = torch.constant.none
    %0 = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,?,32,128],f16>
    return %0 : !torch.vtensor<[1,?,32,128],f16>
  }
}

torch-mlir-opt --convert-torch-to-linalg todtype.torch.mlir

'arith.extf' op operand type 'bf16' and result type 'f16' are cast incompatible
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() <{function_type = (!torch.vtensor<[1,?,32,128],bf16>) -> !torch.vtensor<[1,?,32,128],f16>, sym_name = "convert"}> ({
^bb0(%arg0: !torch.vtensor<[1,?,32,128],bf16>):
  %0 = "builtin.unrealized_conversion_cast"(%arg0) : (!torch.vtensor<[1,?,32,128],bf16>) -> tensor<1x?x32x128xbf16>
  %1 = "torch.constant.int"() <{value = 5 : i64}> : () -> !torch.int
  %2 = "builtin.unrealized_conversion_cast"(%1) : (!torch.int) -> i64
  %3 = "torch.constant.bool"() <{value = false}> : () -> !torch.bool
  %4 = "builtin.unrealized_conversion_cast"(%3) : (!torch.bool) -> i1
  %5 = "torch.constant.none"() : () -> !torch.none
  %6 = "arith.constant"() <{value = 1 : index}> : () -> index
  %7 = "arith.constant"() <{value = 1 : index}> : () -> index
  %8 = "tensor.dim"(%0, %7) : (tensor<1x?x32x128xbf16>, index) -> index
  %9 = "arith.constant"() <{value = 2 : index}> : () -> index
  %10 = "arith.constant"() <{value = 32 : index}> : () -> index
  %11 = "arith.constant"() <{value = 3 : index}> : () -> index
  %12 = "arith.constant"() <{value = 128 : index}> : () -> index
  %13 = "tensor.empty"(%8) : (index) -> tensor<1x?x32x128xf16>
  %14 = "linalg.generic"(%0, %13) <{indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg1: bf16, %arg2: f16):
    %17 = "arith.extf"(%arg1) : (bf16) -> f16
    "linalg.yield"(%17) : (f16) -> ()
  }) : (tensor<1x?x32x128xbf16>, tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
  %15 = "tensor.cast"(%14) : (tensor<1x?x32x128xf16>) -> tensor<1x?x32x128xf16>
  %16 = "torch.aten.to.dtype"(%arg0, %1, %3, %3, %5) : (!torch.vtensor<[1,?,32,128],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none) -> !torch.vtensor<[1,?,32,128],f16>
  "func.return"(%16) : (!torch.vtensor<[1,?,32,128],f16>) -> ()
}) : () -> ()
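
The verifier failure is expected for this cast pair: bf16 and f16 are both 16-bit floats but split their bits differently (8-bit vs 5-bit exponent), so neither arith.extf nor arith.truncf alone is a legal conversion between them. A minimal sketch of what the generated linalg.generic payload could emit instead, assuming the usual route of widening to f32 before narrowing (names are illustrative, not the actual lowering code):

// Hypothetical payload for a bf16 -> f16 element-type conversion:
// widen to f32 first, then truncate to f16.
^bb0(%in: bf16, %out: f16):
  %widened = arith.extf %in : bf16 to f32         // bf16 -> f32 is a legal extension
  %narrowed = arith.truncf %widened : f32 to f16  // f32 -> f16 is a legal truncation
  linalg.yield %narrowed : f16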
