diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index f0c06429385..3995c354cbb 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -223,8 +223,17 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, } assert(!isConstantValue(val)); + bool isMutable = false; + if (auto iface = dyn_cast(val.getType())) + isMutable = iface.isMutable(); + if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit) { + mode == DerivativeMode::ForwardModeSplit || isMutable) { + if (isMutable) { + auto existing = invertedPointers.lookupOrNull(val); + if (existing && !existing.getDefiningOp()) + return; + } setInvertedPointer(val, toset); } /* diff --git a/enzyme/test/MLIR/ReverseMode/alloca.mlir b/enzyme/test/MLIR/ReverseMode/alloca.mlir new file mode 100644 index 00000000000..84e6cc9c14f --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/alloca.mlir @@ -0,0 +1,62 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +func.func @foo_flat(%x : f64) -> f64 { + %buf = memref.alloca() : memref + memref.store %x, %buf[] : memref + %y = memref.load %buf[] : memref + return %y : f64 +} + +func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 { + %dx = enzyme.autodiff @foo_flat(%x, %dout) { + activity = [#enzyme], + ret_activity = [#enzyme] + } : (f64, f64) -> (f64) + return %dx : f64 +} + +// CHECK-LABEL: func.func private @diffefoo_flat( +// CHECK-SAME: %[[X:[^,]+]]: f64, +// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 { + +// CHECK: %[[GX:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient, f64) -> () +// CHECK: %[[CS:.+]] = "enzyme.init"() : () -> !enzyme.Cache> +// CHECK: %[[CL:.+]] = "enzyme.init"() : () -> !enzyme.Cache> +// CHECK: %[[GY:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient, f64) -> () + +// CHECK: %[[DBUF:.+]] = memref.alloca() : memref +// CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: linalg.fill ins(%[[ZERO_INIT]] : f64) outs(%[[DBUF]] : memref) + +// CHECK: %[[BUF:.+]] = memref.alloca() : memref +// CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache>, memref) -> () +// CHECK: memref.store %[[X]], %[[BUF]][] : memref +// CHECK: "enzyme.push"(%[[CL]], %[[DBUF]]) : (!enzyme.Cache>, memref) -> () +// CHECK: memref.load %[[BUF]][] : memref +// CHECK: cf.br ^bb1 +// CHECK: ^bb1: + +// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient) -> f64 +// CHECK: arith.addf %{{.+}}, %[[DOUT]] : f64 +// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient, f64) -> () + +// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient) -> f64 +// CHECK: %[[POPL:.+]] = "enzyme.pop"(%[[CL]]) : (!enzyme.Cache>) -> memref +// CHECK: memref.load %[[POPL]][] : memref +// CHECK: arith.addf +// CHECK: memref.store %{{.+}}, %[[POPL]][] : memref + +// CHECK: %[[POPS:.+]] = "enzyme.pop"(%[[CS]]) : (!enzyme.Cache>) -> memref +// CHECK: memref.load %[[POPS]][] : memref +// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient) -> f64 +// CHECK: arith.addf +// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient, f64) -> () +// CHECK: %[[ZERO_CLEAR:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: memref.store %[[ZERO_CLEAR]], %[[POPS]][] : memref + +// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient) -> f64 +// CHECK: return %{{.+}} : f64 + +// CHECK-NOT: enzyme.placeholder