From 3bfcc96fa4e8e7a85021c803b7961682853ef233 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 13 May 2026 15:05:24 -0500 Subject: [PATCH 1/8] using scf.parallel+memref.store to zero out a memref --- .../MemRefAutoDiffOpInterfaceImpl.cpp | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 1a47acf65136..a2837a6ba48a 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -18,13 +18,10 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" -// TODO: We need a way to zero out a memref (which linalg.fill does), but -// ideally we wouldn't depend on the linalg dialect. -#include "mlir/Dialect/Linalg/IR/Linalg.h" - using namespace mlir; using namespace mlir::enzyme; @@ -273,14 +270,32 @@ class MemRefAutoDiffTypeInterface LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { auto MT = cast(self); - if (auto iface = dyn_cast(MT.getElementType())) { - if (!iface.isMutable()) { - Value zero = iface.createNullValue(builder, loc); - linalg::FillOp::create(builder, loc, zero, val); - } - } else { + auto eltIface = dyn_cast(MT.getElementType()); + if (!eltIface || eltIface.isMutable()) return failure(); + Value zero = eltIface.createNullValue(builder, loc); + + if (MT.getRank() == 0) { + memref::StoreOp::create(builder, loc, zero, val, ValueRange{}); + return success(); } + + Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); + Value c1 = arith::ConstantIndexOp::create(builder, loc, 1); + + SmallVector lbs(MT.getRank(), c0); + SmallVector steps(MT.getRank(), c1); + SmallVector ubs; + for (auto [i, d] : llvm::enumerate(MT.getShape())) { + ubs.push_back(d == ShapedType::kDynamic + ? memref::DimOp::create(builder, loc, val, i).getResult() + : arith::ConstantIndexOp::create(builder, loc, d).getResult()); + } + + scf::ParallelOp::create(builder, loc, lbs, ubs, steps, + [&](OpBuilder &b, Location l, ValueRange ivs) { + memref::StoreOp::create(b, l, zero, val, ivs); + }); return success(); } From d69009f22474a100c2f00502be36352b1163d6ab Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 13 May 2026 15:10:22 -0500 Subject: [PATCH 2/8] add test for memref.alloca --- .../MemRefAutoDiffOpInterfaceImpl.cpp | 15 +++++----- enzyme/test/MLIR/ReverseMode/alloca.mlir | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 enzyme/test/MLIR/ReverseMode/alloca.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index a2837a6ba48a..62ee51ba125d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -274,7 +274,7 @@ class MemRefAutoDiffTypeInterface if (!eltIface || eltIface.isMutable()) return failure(); Value zero = eltIface.createNullValue(builder, loc); - + if (MT.getRank() == 0) { memref::StoreOp::create(builder, loc, zero, val, ValueRange{}); return success(); @@ -287,15 +287,16 @@ class MemRefAutoDiffTypeInterface SmallVector steps(MT.getRank(), c1); SmallVector ubs; for (auto [i, d] : llvm::enumerate(MT.getShape())) { - ubs.push_back(d == ShapedType::kDynamic - ? memref::DimOp::create(builder, loc, val, i).getResult() - : arith::ConstantIndexOp::create(builder, loc, d).getResult()); + ubs.push_back( + d == ShapedType::kDynamic + ? memref::DimOp::create(builder, loc, val, i).getResult() + : arith::ConstantIndexOp::create(builder, loc, d).getResult()); } scf::ParallelOp::create(builder, loc, lbs, ubs, steps, - [&](OpBuilder &b, Location l, ValueRange ivs) { - memref::StoreOp::create(b, l, zero, val, ivs); - }); + [&](OpBuilder &b, Location l, ValueRange ivs) { + memref::StoreOp::create(b, l, zero, val, ivs); + }); return success(); } diff --git a/enzyme/test/MLIR/ReverseMode/alloca.mlir b/enzyme/test/MLIR/ReverseMode/alloca.mlir new file mode 100644 index 000000000000..a61fb3eaea0d --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/alloca.mlir @@ -0,0 +1,29 @@ +// RUN: %eopt --split-input-file --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s + +func.func @foo_flat(%x : f64) -> f64 { + %buf = memref.alloca() : memref + memref.store %x, %buf[] : memref + %y = memref.load %buf[] : memref + return %y : f64 +} + +func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 { + %dx = enzyme.autodiff @foo_flat(%x, %dout) { + activity = [#enzyme], + ret_activity = [#enzyme] + } : (f64, f64) -> (f64) + return %dx : f64 +} + +// CHECK-LABEL: func.func private @diffefoo_flat( +// CHECK-SAME: %[[X:[^,]+]]: f64, +// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 { +// A shadow memref.alloca must be created and zero-initialized so the +// reverse-mode adjoint can accumulate gradients into it. +// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[DBUF:.*]] = memref.alloca() : memref +// CHECK: memref.store %[[ZERO]], %[[DBUF]][] : memref +// No leftover placeholders should survive the differentiation. +// CHECK-NOT: enzyme.placeholder +// The function ultimately returns the gradient w.r.t. x (= dout). +// CHECK: return %{{.*}} : f64 \ No newline at end of file From ef3590a5c9ed1b2143ba8da7394bce266ccede37 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 14 May 2026 00:12:53 -0500 Subject: [PATCH 3/8] revert --- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 18 ++++--- enzyme/test/MLIR/ReverseMode/alloca.mlir | 51 +++++++++++++++---- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index f0c064293856..5e58e6e0b6c9 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -223,8 +223,12 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, } assert(!isConstantValue(val)); + bool isMutable = false; + if (auto iface = dyn_cast(val.getType())) + isMutable = iface.isMutable(); + if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit) { + mode == DerivativeMode::ForwardModeSplit || isMutable) { setInvertedPointer(val, toset); } /* @@ -240,11 +244,13 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, void mlir::enzyme::MGradientUtils::setInvertedPointer(Value val, Value toset) { assert(getShadowType(val.getType()) == toset.getType()); - auto found = invertedPointers.lookupOrNull(val); - assert(found != nullptr); - auto placeholder = found.getDefiningOp(); - placeholder.replaceAllUsesWith(toset); - erase(placeholder); + + if (auto found = invertedPointers.lookupOrNull(val)) { + if (auto placeholder = found.getDefiningOp()) { + placeholder.replaceAllUsesWith(toset); + erase(placeholder); + } + } invertedPointers.map(val, toset); } diff --git a/enzyme/test/MLIR/ReverseMode/alloca.mlir b/enzyme/test/MLIR/ReverseMode/alloca.mlir index a61fb3eaea0d..76b47c8c4da3 100644 --- a/enzyme/test/MLIR/ReverseMode/alloca.mlir +++ b/enzyme/test/MLIR/ReverseMode/alloca.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --split-input-file --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math %s | FileCheck %s +// RUN: %eopt --enzyme %s | FileCheck %s func.func @foo_flat(%x : f64) -> f64 { %buf = memref.alloca() : memref @@ -18,12 +18,45 @@ func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 { // CHECK-LABEL: func.func private @diffefoo_flat( // CHECK-SAME: %[[X:[^,]+]]: f64, // CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 { -// A shadow memref.alloca must be created and zero-initialized so the -// reverse-mode adjoint can accumulate gradients into it. -// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f64 -// CHECK-DAG: %[[DBUF:.*]] = memref.alloca() : memref -// CHECK: memref.store %[[ZERO]], %[[DBUF]][] : memref -// No leftover placeholders should survive the differentiation. + +// CHECK: %[[GX:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient, f64) -> () +// CHECK: %[[CS:.+]] = "enzyme.init"() : () -> !enzyme.Cache> +// CHECK: %[[CL:.+]] = "enzyme.init"() : () -> !enzyme.Cache> +// CHECK: %[[GY:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient, f64) -> () + +// CHECK: %[[DBUF:.+]] = memref.alloca() : memref +// CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: memref.store %[[ZERO_INIT]], %[[DBUF]][] : memref + +// CHECK: %[[BUF:.+]] = memref.alloca() : memref +// CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache>, memref) -> () +// CHECK: memref.store %[[X]], %[[BUF]][] : memref +// CHECK: "enzyme.push"(%[[CL]], %[[DBUF]]) : (!enzyme.Cache>, memref) -> () +// CHECK: memref.load %[[BUF]][] : memref +// CHECK: cf.br ^bb1 +// CHECK: ^bb1: + +// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient) -> f64 +// CHECK: arith.addf %{{.+}}, %[[DOUT]] : f64 +// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient, f64) -> () + +// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient) -> f64 +// CHECK: %[[POPL:.+]] = "enzyme.pop"(%[[CL]]) : (!enzyme.Cache>) -> memref +// CHECK: memref.load %[[POPL]][] : memref +// CHECK: arith.addf +// CHECK: memref.store %{{.+}}, %[[POPL]][] : memref + +// CHECK: %[[POPS:.+]] = "enzyme.pop"(%[[CS]]) : (!enzyme.Cache>) -> memref +// CHECK: memref.load %[[POPS]][] : memref +// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient) -> f64 +// CHECK: arith.addf +// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient, f64) -> () +// CHECK: %[[ZERO_CLEAR:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: memref.store %[[ZERO_CLEAR]], %[[POPS]][] : memref + +// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient) -> f64 +// CHECK: return %{{.+}} : f64 + // CHECK-NOT: enzyme.placeholder -// The function ultimately returns the gradient w.r.t. x (= dout). -// CHECK: return %{{.*}} : f64 \ No newline at end of file From 435ee27934c9f18ed26f264788f10de8918bbe00 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 14 May 2026 00:16:56 -0500 Subject: [PATCH 4/8] fmt --- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 5e58e6e0b6c9..a3f9cb8099da 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -244,7 +244,7 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, void mlir::enzyme::MGradientUtils::setInvertedPointer(Value val, Value toset) { assert(getShadowType(val.getType()) == toset.getType()); - + if (auto found = invertedPointers.lookupOrNull(val)) { if (auto placeholder = found.getDefiningOp()) { placeholder.replaceAllUsesWith(toset); From 759bfb57accb4918213063052abc536017b6334f Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Wed, 20 May 2026 02:30:05 -0500 Subject: [PATCH 5/8] revert --- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index a3f9cb8099da..776841167486 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -244,13 +244,11 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, void mlir::enzyme::MGradientUtils::setInvertedPointer(Value val, Value toset) { assert(getShadowType(val.getType()) == toset.getType()); - - if (auto found = invertedPointers.lookupOrNull(val)) { - if (auto placeholder = found.getDefiningOp()) { - placeholder.replaceAllUsesWith(toset); - erase(placeholder); - } - } + auto found = invertedPointers.lookupOrNull(val); + assert(found != nullptr); + auto placeholder = found.getDefiningOp(); + placeholder.replaceAllUsesWith(toset); + erase(placeholder); invertedPointers.map(val, toset); } From 7f88905bde07c25820b7e6115fb675f237209cfb Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Thu, 21 May 2026 18:28:53 -0500 Subject: [PATCH 6/8] Revert "using scf.parallel+memref.store to zero out a memref" This reverts commit 5ddbd10a3c90e1c4bc46331d746d64ee6e918c0a. --- .../MemRefAutoDiffOpInterfaceImpl.cpp | 36 ++++++------------- 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 62ee51ba125d..1a47acf65136 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -18,10 +18,13 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" +// TODO: We need a way to zero out a memref (which linalg.fill does), but +// ideally we wouldn't depend on the linalg dialect. +#include "mlir/Dialect/Linalg/IR/Linalg.h" + using namespace mlir; using namespace mlir::enzyme; @@ -270,33 +273,14 @@ class MemRefAutoDiffTypeInterface LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { auto MT = cast(self); - auto eltIface = dyn_cast(MT.getElementType()); - if (!eltIface || eltIface.isMutable()) + if (auto iface = dyn_cast(MT.getElementType())) { + if (!iface.isMutable()) { + Value zero = iface.createNullValue(builder, loc); + linalg::FillOp::create(builder, loc, zero, val); + } + } else { return failure(); - Value zero = eltIface.createNullValue(builder, loc); - - if (MT.getRank() == 0) { - memref::StoreOp::create(builder, loc, zero, val, ValueRange{}); - return success(); } - - Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); - Value c1 = arith::ConstantIndexOp::create(builder, loc, 1); - - SmallVector lbs(MT.getRank(), c0); - SmallVector steps(MT.getRank(), c1); - SmallVector ubs; - for (auto [i, d] : llvm::enumerate(MT.getShape())) { - ubs.push_back( - d == ShapedType::kDynamic - ? memref::DimOp::create(builder, loc, val, i).getResult() - : arith::ConstantIndexOp::create(builder, loc, d).getResult()); - } - - scf::ParallelOp::create(builder, loc, lbs, ubs, steps, - [&](OpBuilder &b, Location l, ValueRange ivs) { - memref::StoreOp::create(b, l, zero, val, ivs); - }); return success(); } From 0222508494c94d3f2a9ef3cbae2880c3143a9fe5 Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Fri, 22 May 2026 23:47:49 -0500 Subject: [PATCH 7/8] add guard to setInvertedPointer --- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 776841167486..3995c354cbbf 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -229,6 +229,11 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, if (mode == DerivativeMode::ForwardMode || mode == DerivativeMode::ForwardModeSplit || isMutable) { + if (isMutable) { + auto existing = invertedPointers.lookupOrNull(val); + if (existing && !existing.getDefiningOp()) + return; + } setInvertedPointer(val, toset); } /* From 4cab90b95f660f12bb99f2a82351c7cba26853af Mon Sep 17 00:00:00 2001 From: Yuansui Xu Date: Sat, 23 May 2026 00:02:45 -0500 Subject: [PATCH 8/8] match linalg.fill --- enzyme/test/MLIR/ReverseMode/alloca.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/test/MLIR/ReverseMode/alloca.mlir b/enzyme/test/MLIR/ReverseMode/alloca.mlir index 76b47c8c4da3..84e6cc9c14fa 100644 --- a/enzyme/test/MLIR/ReverseMode/alloca.mlir +++ b/enzyme/test/MLIR/ReverseMode/alloca.mlir @@ -28,7 +28,7 @@ func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 { // CHECK: %[[DBUF:.+]] = memref.alloca() : memref // CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64 -// CHECK: memref.store %[[ZERO_INIT]], %[[DBUF]][] : memref +// CHECK: linalg.fill ins(%[[ZERO_INIT]] : f64) outs(%[[DBUF]] : memref) // CHECK: %[[BUF:.+]] = memref.alloca() : memref // CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache>, memref) -> ()