Skip to content

Commit 0139f27

Browse files
authored
Revert "Remove unused primal arguments for AutoDiffRegionOp (#2640)" (#2677)
This reverts commit 15a3d98.
1 parent 15a3d98 commit 0139f27

File tree

2 files changed

+1
-84
lines changed

2 files changed

+1
-84
lines changed

enzyme/Enzyme/MLIR/Dialect/Ops.cpp

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -958,73 +958,14 @@ class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
958958
}
959959
};
960960

961-
class RemoveUnusedArgs final : public OpRewritePattern<AutoDiffRegionOp> {
962-
963-
public:
964-
using OpRewritePattern<AutoDiffRegionOp>::OpRewritePattern;
965-
966-
LogicalResult matchAndRewrite(AutoDiffRegionOp uop,
967-
PatternRewriter &rewriter) const override {
968-
SmallVector<Value> newInArgs;
969-
SmallVector<size_t> argIdxToErase;
970-
SmallVector<ActivityAttr> newInActivityArgs;
971-
llvm::SmallVector<Value> blockArg(uop.getBody().getArguments());
972-
auto in_idx = 0;
973-
for (auto [idx, act] :
974-
llvm::enumerate(uop.getActivity().getAsRange<ActivityAttr>())) {
975-
auto act_val = act.getValue();
976-
Value res = uop.getInputs()[in_idx++];
977-
978-
if (blockArg[idx].use_empty()) {
979-
argIdxToErase.push_back(idx);
980-
if (act_val == Activity::enzyme_dup ||
981-
act_val == Activity::enzyme_dupnoneed) {
982-
in_idx++;
983-
}
984-
} else {
985-
newInActivityArgs.push_back(act);
986-
newInArgs.push_back(res);
987-
if (act_val == Activity::enzyme_dup ||
988-
act_val == Activity::enzyme_dupnoneed) {
989-
res = uop.getInputs()[in_idx++];
990-
newInArgs.push_back(res);
991-
}
992-
}
993-
}
994-
995-
if (argIdxToErase.empty())
996-
return failure();
997-
998-
newInArgs.append(uop.getDifferentialReturns());
999-
ArrayAttr newInActivity =
1000-
ArrayAttr::get(rewriter.getContext(),
1001-
llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
1002-
newInActivityArgs.end()));
1003-
1004-
auto newOp = AutoDiffRegionOp::create(
1005-
rewriter, uop.getLoc(), uop.getResultTypes(), newInArgs, newInActivity,
1006-
uop.getRetActivity(), uop.getWidthAttr(), uop.getStrongZeroAttr(),
1007-
uop.getFnAttr());
1008-
rewriter.inlineRegionBefore(uop.getBody(), newOp.getBody(),
1009-
newOp.getBody().begin());
1010-
1011-
for (auto idx : llvm::reverse(argIdxToErase)) {
1012-
newOp.getBody().eraseArgument(idx);
1013-
}
1014-
1015-
rewriter.replaceOp(uop, newOp);
1016-
return success();
1017-
}
1018-
};
1019-
1020961
void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021962
MLIRContext *context) {
1022963
patterns.add<ReverseRetOpt<AutoDiffOp>>(context);
1023964
}
1024965

1025966
void AutoDiffRegionOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1026967
MLIRContext *context) {
1027-
patterns.add<ReverseRetOpt<AutoDiffRegionOp>, RemoveUnusedArgs>(context);
968+
patterns.add<ReverseRetOpt<AutoDiffRegionOp>>(context);
1028969
}
1029970
//===----------------------------------------------------------------------===//
1030971
// SampleOp

enzyme/test/MLIR/ReverseMode/region_canonicalize.mlir

Lines changed: 0 additions & 24 deletions
This file was deleted.

0 commit comments

Comments
 (0)