Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,17 @@ void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset,
}
assert(!isConstantValue(val));

bool isMutable = false;
if (auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType()))
isMutable = iface.isMutable();

if (mode == DerivativeMode::ForwardMode ||
mode == DerivativeMode::ForwardModeSplit) {
mode == DerivativeMode::ForwardModeSplit || isMutable) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func.func @square(%x : memref<f64>){
     %y = memref.load %x[] : f64
     return %y
}

%out = enzyme.fwddiff dsquare(%x : memref<f64>, %dx : memref<f64>) {act = [enzyme_dup] ....}

In this case, the user provides the inverted Pointer right? We should ensure that we arent creating a shadow

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have add a guard in setDiffe when val is mutable or intertedPointers[val] is not a placeholderop, we skip instead of rewriting it.

if (isMutable) {
auto existing = invertedPointers.lookupOrNull(val);
if (existing && !existing.getDefiningOp<enzyme::PlaceholderOp>())
return;
}
setInvertedPointer(val, toset);
}
/*
Expand Down
62 changes: 62 additions & 0 deletions enzyme/test/MLIR/ReverseMode/alloca.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: %eopt --enzyme %s | FileCheck %s

func.func @foo_flat(%x : f64) -> f64 {
%buf = memref.alloca() : memref<f64>
memref.store %x, %buf[] : memref<f64>
%y = memref.load %buf[] : memref<f64>
return %y : f64
}

func.func @dfoo_flat(%x: f64, %dout : f64) -> f64 {
%dx = enzyme.autodiff @foo_flat(%x, %dout) {
activity = [#enzyme<activity enzyme_active>],
ret_activity = [#enzyme<activity enzyme_activenoneed>]
} : (f64, f64) -> (f64)
return %dx : f64
}

// CHECK-LABEL: func.func private @diffefoo_flat(
// CHECK-SAME: %[[X:[^,]+]]: f64,
// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 {

// CHECK: %[[GX:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK: %[[CS:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
// CHECK: %[[CL:.+]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
// CHECK: %[[GY:.+]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()

// CHECK: %[[DBUF:.+]] = memref.alloca() : memref<f64>
// CHECK: %[[ZERO_INIT:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: linalg.fill ins(%[[ZERO_INIT]] : f64) outs(%[[DBUF]] : memref<f64>)

// CHECK: %[[BUF:.+]] = memref.alloca() : memref<f64>
// CHECK: "enzyme.push"(%[[CS]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
// CHECK: memref.store %[[X]], %[[BUF]][] : memref<f64>
// CHECK: "enzyme.push"(%[[CL]], %[[DBUF]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
// CHECK: memref.load %[[BUF]][] : memref<f64>
// CHECK: cf.br ^bb1
// CHECK: ^bb1:

// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK: arith.addf %{{.+}}, %[[DOUT]] : f64
// CHECK: "enzyme.set"(%[[GY]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()

// CHECK: %{{.+}} = "enzyme.get"(%[[GY]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK: %[[POPL:.+]] = "enzyme.pop"(%[[CL]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
// CHECK: memref.load %[[POPL]][] : memref<f64>
// CHECK: arith.addf
// CHECK: memref.store %{{.+}}, %[[POPL]][] : memref<f64>

// CHECK: %[[POPS:.+]] = "enzyme.pop"(%[[CS]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
// CHECK: memref.load %[[POPS]][] : memref<f64>
// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK: arith.addf
// CHECK: "enzyme.set"(%[[GX]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
// CHECK: %[[ZERO_CLEAR:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: memref.store %[[ZERO_CLEAR]], %[[POPS]][] : memref<f64>

// CHECK: %{{.+}} = "enzyme.get"(%[[GX]]) : (!enzyme.Gradient<f64>) -> f64
// CHECK: return %{{.+}} : f64

// CHECK-NOT: enzyme.placeholder
Loading