@@ -958,73 +958,14 @@ class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
958958 }
959959};
960960
961- class RemoveUnusedArgs final : public OpRewritePattern<AutoDiffRegionOp> {
962-
963- public:
964- using OpRewritePattern<AutoDiffRegionOp>::OpRewritePattern;
965-
966- LogicalResult matchAndRewrite (AutoDiffRegionOp uop,
967- PatternRewriter &rewriter) const override {
968- SmallVector<Value> newInArgs;
969- SmallVector<size_t > argIdxToErase;
970- SmallVector<ActivityAttr> newInActivityArgs;
971- llvm::SmallVector<Value> blockArg (uop.getBody ().getArguments ());
972- auto in_idx = 0 ;
973- for (auto [idx, act] :
974- llvm::enumerate (uop.getActivity ().getAsRange <ActivityAttr>())) {
975- auto act_val = act.getValue ();
976- Value res = uop.getInputs ()[in_idx++];
977-
978- if (blockArg[idx].use_empty ()) {
979- argIdxToErase.push_back (idx);
980- if (act_val == Activity::enzyme_dup ||
981- act_val == Activity::enzyme_dupnoneed) {
982- in_idx++;
983- }
984- } else {
985- newInActivityArgs.push_back (act);
986- newInArgs.push_back (res);
987- if (act_val == Activity::enzyme_dup ||
988- act_val == Activity::enzyme_dupnoneed) {
989- res = uop.getInputs ()[in_idx++];
990- newInArgs.push_back (res);
991- }
992- }
993- }
994-
995- if (argIdxToErase.empty ())
996- return failure ();
997-
998- newInArgs.append (uop.getDifferentialReturns ());
999- ArrayAttr newInActivity =
1000- ArrayAttr::get (rewriter.getContext (),
1001- llvm::ArrayRef<Attribute>(newInActivityArgs.begin (),
1002- newInActivityArgs.end ()));
1003-
1004- auto newOp = AutoDiffRegionOp::create (
1005- rewriter, uop.getLoc (), uop.getResultTypes (), newInArgs, newInActivity,
1006- uop.getRetActivity (), uop.getWidthAttr (), uop.getStrongZeroAttr (),
1007- uop.getFnAttr ());
1008- rewriter.inlineRegionBefore (uop.getBody (), newOp.getBody (),
1009- newOp.getBody ().begin ());
1010-
1011- for (auto idx : llvm::reverse (argIdxToErase)) {
1012- newOp.getBody ().eraseArgument (idx);
1013- }
1014-
1015- rewriter.replaceOp (uop, newOp);
1016- return success ();
1017- }
1018- };
1019-
1020961void AutoDiffOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
1021962 MLIRContext *context) {
1022963 patterns.add <ReverseRetOpt<AutoDiffOp>>(context);
1023964}
1024965
1025966void AutoDiffRegionOp::getCanonicalizationPatterns (RewritePatternSet &patterns,
1026967 MLIRContext *context) {
1027- patterns.add <ReverseRetOpt<AutoDiffRegionOp>, RemoveUnusedArgs >(context);
968+ patterns.add <ReverseRetOpt<AutoDiffRegionOp>>(context);
1028969}
1029970// ===----------------------------------------------------------------------===//
1030971// SampleOp
0 commit comments