diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 82d11ec0737a..61bb391036a8 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -18798,6 +18798,31 @@ def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
   }];
 }
 
+def Torch_Aten_ForeachLerpListOp : Torch_Op<"aten._foreach_lerp.List", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_foreach_lerp.List : (Tensor[], Tensor[], Tensor[]) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchListOfTensorType:$self,
+    AnyTorchListOfTensorType:$tensors1,
+    AnyTorchListOfTensorType:$weights
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ForeachLerpListOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void Aten_ForeachLerpListOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 9a386ec35f30..536f603792cd 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11836,6 +11836,43 @@ class DecomposePrimsSumOp : public OpRewritePattern<PrimsSumOp> {
 };
 } // namespace
 
+namespace {
+// Decompose aten._foreach_lerp.List into element-wise aten.lerp.Tensor
+class DecomposeAten_ForeachLerpListOp
+    : public OpRewritePattern<Aten_ForeachLerpListOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_ForeachLerpListOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value> selfElems, tensors1Elems, weightsElems;
+    if (!getListConstructElements(op.getSelf(), selfElems))
+      return rewriter.notifyMatchFailure(
+          op, "self must come from a PrimListConstructOp");
+    if (!getListConstructElements(op.getTensors1(), tensors1Elems))
+      return rewriter.notifyMatchFailure(
+          op, "tensors1 must come from a PrimListConstructOp");
+    if (!getListConstructElements(op.getWeights(), weightsElems))
+      return rewriter.notifyMatchFailure(
+          op, "weights must come from a PrimListConstructOp");
+
+    if (selfElems.size() != tensors1Elems.size() ||
+        selfElems.size() != weightsElems.size())
+      return rewriter.notifyMatchFailure(op, "list sizes must match");
+
+    SmallVector<Value> results;
+    for (size_t i = 0; i < selfElems.size(); i++) {
+      Value lerped = AtenLerpTensorOp::create(
+          rewriter, loc, selfElems[i].getType(), selfElems[i], tensors1Elems[i],
+          weightsElems[i]);
+      results.push_back(lerped);
+    }
+
+    rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, op.getType(), results);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.sgn` op into comparisons and aten.where.
 class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
@@ -13558,6 +13595,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsViewOfOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposePrimsSumOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_ForeachLerpListOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 7d5e65c21cef..03e406be3e27 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1282,6 +1282,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
     emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
     emit("aten::_assert_scalar : (Scalar, str) -> ()")
+    emit("aten::_foreach_lerp.List : (Tensor[], Tensor[], Tensor[]) -> (Tensor[])")
 
     # ==========================================================================
     # `prim::` namespace.
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index d8dc39375bb6..dbb87f931eed 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -1013,3 +1013,29 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
   %0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
   return %0 : !torch.vtensor<[1,8,4,4],f32>
 }
+
+
+// -----
+
+// CHECK-LABEL: func.func @foreach_lerp_list_decompose
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3],f32>, %[[ARG2:.*]]: !torch.vtensor<[3],f32>, %[[ARG3:.*]]: !torch.vtensor<[3],f32>, %[[ARG4:.*]]: !torch.vtensor<[3],f32>, %[[ARG5:.*]]: !torch.vtensor<[3],f32>
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[SUB0:.*]] = torch.aten.sub.Tensor %[[ARG2]], %[[ARG0]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+// CHECK: %[[MUL0:.*]] = torch.aten.mul.Tensor %[[SUB0]], %[[ARG4]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+// CHECK: %[[ADD0:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[MUL0]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[ARG3]], %[[ARG1]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+// CHECK: %[[MUL1:.*]] = torch.aten.mul.Tensor %[[SUB1]], %[[ARG5]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
+// CHECK: %[[ADD1:.*]] = torch.aten.add.Tensor %[[ARG1]], %[[MUL1]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
+// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ADD0]], %[[ADD1]] : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
+// CHECK: return %[[RESULT]] : !torch.list<vtensor<[3],f32>>
+func.func @foreach_lerp_list_decompose(
+    %a0: !torch.vtensor<[3],f32>, %a1: !torch.vtensor<[3],f32>,
+    %b0: !torch.vtensor<[3],f32>, %b1: !torch.vtensor<[3],f32>,
+    %w0: !torch.vtensor<[3],f32>, %w1: !torch.vtensor<[3],f32>
+) -> !torch.list<vtensor<[3],f32>> {
+  %self = torch.prim.ListConstruct %a0, %a1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
+  %tensors1 = torch.prim.ListConstruct %b0, %b1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
+  %weights = torch.prim.ListConstruct %w0, %w1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
+  %0 = torch.aten._foreach_lerp.List %self, %tensors1, %weights : !torch.list<vtensor<[3],f32>>, !torch.list<vtensor<[3],f32>>, !torch.list<vtensor<[3],f32>> -> !torch.list<vtensor<[3],f32>>
+  return %0 : !torch.list<vtensor<[3],f32>>
+}