Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18798,6 +18798,31 @@ def Torch_PrimAbsScalarOp : Torch_Op<"prim.abs.Scalar", [
}];
}

// Op registration for `aten::_foreach_lerp.List`: applies lerp element-wise
// across three equal-length tensor lists and yields a tensor list.
// NOTE(review): GeneratedTorchOps.td is produced by torch_ods_gen.py — this
// def should stay in sync with the corresponding `emit(...)` entry there
// rather than being edited by hand.
def Torch_Aten_ForeachLerpListOp : Torch_Op<"aten._foreach_lerp.List", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::_foreach_lerp.List : (Tensor[], Tensor[], Tensor[]) -> (Tensor[])`";
  let arguments = (ins
    AnyTorchListOfTensorType:$self,
    AnyTorchListOfTensorType:$tensors1,
    AnyTorchListOfTensorType:$weights
  );
  let results = (outs
    AnyTorchListOfTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult Aten_ForeachLerpListOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void Aten_ForeachLerpListOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_PrimsConvertElementTypeOp : Torch_Op<"prims.convert_element_type", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11836,6 +11836,43 @@ class DecomposePrimsSumOp : public OpRewritePattern<PrimsSumOp> {
};
} // namespace

namespace {
// Decompose aten._foreach_lerp.List into element-wise aten.lerp.Tensor ops.
//
// All three list operands (self, tensors1, weights) must be produced by a
// visible torch.prim.ListConstruct so their elements can be recovered;
// otherwise the pattern bails out via notifyMatchFailure and the op is left
// intact.  Result element i is lerp(self[i], tensors1[i], weights[i]); the
// results are re-packed into a list carrying the op's original result type.
class DecomposeAten_ForeachLerpListOp
    : public OpRewritePattern<Aten_ForeachLerpListOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_ForeachLerpListOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    SmallVector<Value> selfElems, tensors1Elems, weightsElems;
    if (!getListConstructElements(op.getSelf(), selfElems))
      return rewriter.notifyMatchFailure(
          op, "self must come from a PrimListConstructOp");
    if (!getListConstructElements(op.getTensors1(), tensors1Elems))
      return rewriter.notifyMatchFailure(
          op, "tensors1 must come from a PrimListConstructOp");
    if (!getListConstructElements(op.getWeights(), weightsElems))
      return rewriter.notifyMatchFailure(
          op, "weights must come from a PrimListConstructOp");

    // The foreach semantics require equal-length lists; a mismatch would
    // otherwise misalign or drop elements silently.
    if (selfElems.size() != tensors1Elems.size() ||
        selfElems.size() != weightsElems.size())
      return rewriter.notifyMatchFailure(op, "list sizes must match");

    SmallVector<Value> results;
    results.reserve(selfElems.size()); // one lerp result per input element
    for (size_t i = 0, e = selfElems.size(); i < e; ++i) {
      // NOTE(review): the result type is taken from self[i], which assumes
      // no broadcasting between corresponding self/tensors1/weights
      // elements — confirm this matches the intended foreach contract.
      Value lerped = AtenLerpTensorOp::create(
          rewriter, loc, selfElems[i].getType(), selfElems[i],
          tensors1Elems[i], weightsElems[i]);
      results.push_back(lerped);
    }

    rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, op.getType(),
                                                     results);
    return success();
  }
};
} // namespace

namespace {
// Decompose `aten.sgn` op into comparisons and aten.where.
class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
Expand Down Expand Up @@ -13558,6 +13595,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsSumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ForeachLerpListOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::sym_constrain_range : (Scalar, int?, int?) -> ()")
emit("aten::sym_constrain_range_for_size : (Scalar, int?, int?) -> ()")
emit("aten::_assert_scalar : (Scalar, str) -> ()")
emit("aten::_foreach_lerp.List : (Tensor[], Tensor[], Tensor[]) -> (Tensor[])")

# ==========================================================================
# `prim::` namespace.
Expand Down
26 changes: 26 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,29 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
return %0 : !torch.vtensor<[1,8,4,4],f32>
}


// -----

// CHECK-LABEL: func.func @foreach_lerp_list_decompose
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3],f32>, %[[ARG2:.*]]: !torch.vtensor<[3],f32>, %[[ARG3:.*]]: !torch.vtensor<[3],f32>, %[[ARG4:.*]]: !torch.vtensor<[3],f32>, %[[ARG5:.*]]: !torch.vtensor<[3],f32>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[SUB0:.*]] = torch.aten.sub.Tensor %[[ARG2]], %[[ARG0]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[MUL0:.*]] = torch.aten.mul.Tensor %[[SUB0]], %[[ARG4]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[ADD0:.*]] = torch.aten.add.Tensor %[[ARG0]], %[[MUL0]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[ARG3]], %[[ARG1]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[MUL1:.*]] = torch.aten.mul.Tensor %[[SUB1]], %[[ARG5]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
// CHECK: %[[ADD1:.*]] = torch.aten.add.Tensor %[[ARG1]], %[[MUL1]], %[[INT1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32>
// CHECK: %[[RESULT:.*]] = torch.prim.ListConstruct %[[ADD0]], %[[ADD1]] : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
// CHECK: return %[[RESULT]] : !torch.list<vtensor<[3],f32>>
// Exercises the full decomposition chain for aten._foreach_lerp.List: the
// pass first expands it into per-element aten.lerp.Tensor ops, which the
// same pass then lowers to sub/mul/add, so no lerp or foreach op remains.
// Two 2-element lists of [3]xf32 tensors keep the expected IR small.
func.func @foreach_lerp_list_decompose(
    %a0: !torch.vtensor<[3],f32>, %a1: !torch.vtensor<[3],f32>,
    %b0: !torch.vtensor<[3],f32>, %b1: !torch.vtensor<[3],f32>,
    %w0: !torch.vtensor<[3],f32>, %w1: !torch.vtensor<[3],f32>
) -> !torch.list<vtensor<[3],f32>> {
  %self = torch.prim.ListConstruct %a0, %a1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
  %tensors1 = torch.prim.ListConstruct %b0, %b1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
  %weights = torch.prim.ListConstruct %w0, %w1 : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.list<vtensor<[3],f32>>
  %0 = torch.aten._foreach_lerp.List %self, %tensors1, %weights : !torch.list<vtensor<[3],f32>>, !torch.list<vtensor<[3],f32>>, !torch.list<vtensor<[3],f32>> -> !torch.list<vtensor<[3],f32>>
  return %0 : !torch.list<vtensor<[3],f32>>
}