diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 82d11ec0737a..09ac0549266d 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -16819,6 +16819,31 @@ def Torch_AtenMulFloatIntOp : Torch_Op<"aten.mul.float_int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenAddFloatOp : Torch_Op<"aten.add.float", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::add.float : (float, float) -> (float)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_FloatType:$b
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAddFloatOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenAddFloatOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index 3b095e37bdea..b92c400c582a 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -534,7 +534,9 @@ class ConvertTorchToArith
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulFloatIntOp, arith::MulFOp>>(
         typeConverter, context);
-    target.addIllegalOp<AtenSubFloatOp, AtenMulFloatOp>();
+    target.addIllegalOp<AtenAddFloatOp, AtenSubFloatOp, AtenMulFloatOp>();
+    patterns.add<ConvertAtenBinaryOp<AtenAddFloatOp, arith::AddFOp>>(
+        typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
         typeConverter, context);
     patterns.add<ConvertAtenBinaryOp<AtenMulFloatOp, arith::MulFOp>>(
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bbbd8fec5427..677ade0c54a1 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -4500,6 +4500,15 @@ OpFoldResult AtenSubFloatOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](double a, double b) { return a - b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenAddFloatOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenAddFloatOp::fold(FoldAdaptor adaptor) {
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a + b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddOp
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7d5e65c21cef..d0b8ab502744 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1184,6 +1184,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
     emit("aten::mul.float_int : (float, int) -> (float)", has_folder=True)
+    emit("aten::add.float : (float, float) -> (float)", has_folder=True)
     emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
     emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
     emit("aten::div.float : (float, float) -> (float)", has_folder=True)
diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index aa97bdefb123..592224857f26 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -210,6 +210,19 @@ func.func @torch.aten.sub.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
   return %0 : !torch.int
 }
 
+// CHECK-LABEL:   func.func @torch.aten.add.float(
+// CHECK-SAME:                                    %[[LHS:.*]]: !torch.float,
+// CHECK-SAME:                                    %[[RHS:.*]]: !torch.float) -> !torch.float {
+// CHECK-DAG:       %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG:       %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]]
+// CHECK:           %[[ADD:.*]] = arith.addf %[[LHS_F64:.*]], %[[RHS_F64:.*]] : f64
+// CHECK:           %[[OUT:.*]] = torch_c.from_f64 %[[ADD:.*]]
+// CHECK:           return %[[OUT:.*]] : !torch.float
+func.func @torch.aten.add.float(%arg0: !torch.float, %arg1: !torch.float) -> !torch.float {
+  %0 = torch.aten.add.float %arg0, %arg1 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.sub.float(
 // CHECK-SAME:                                    %[[LHS:.*]]: !torch.float,
 // CHECK-SAME:                                    %[[RHS:.*]]: !torch.float) -> !torch.float {
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index de05a619f1bd..9f4ea8d4b2c7 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -2412,6 +2412,16 @@ func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],
   return %2 : !torch.vtensor<[],si64>
 }
 
+// CHECK-LABEL:   func.func @torch.aten.add.float$fold() -> !torch.float {
+// CHECK:           %[[FLOAT_3:.*]] = torch.constant.float 3.000000e+00
+// CHECK:           return %[[FLOAT_3]] : !torch.float
+func.func @torch.aten.add.float$fold() -> !torch.float {
+  %float1 = torch.constant.float 1.0
+  %float2 = torch.constant.float 2.0
+  %0 = torch.aten.add.float %float1, %float2 : !torch.float, !torch.float -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL:   func.func @torch.aten.sub.float$fold() -> !torch.float {
 // CHECK:           %[[FLOAT_1:.*]] = torch.constant.float -1.000000e+00
 // CHECK:           return %[[FLOAT_1]] : !torch.float