Skip to content

Commit 8d66b16

Browse files
committed
mlir: Bugfixes with scf.for derivative
- register ub dialect with enzymemlir-opt
1 parent 8520c15 commit 8d66b16

3 files changed

Lines changed: 110 additions & 9 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,18 @@ struct ForOpEnzymeOpsRemover
102102
if (!map.contains(canIdx)) {
103103
assert(Equivalent(forOp.getLowerBound(), otherForOp.getLowerBound()));
104104
assert(Equivalent(forOp.getStep(), otherForOp.getStep()));
105-
map.map(forOp.getBody()->getArgument(0),
106-
otherForOp.getBody()->getArgument(0));
105+
106+
Location loc = forOp.getLoc();
107+
// The reverse IV can be computed as (lb + ub - 1 - iv)
108+
Value revIV =
109+
arith::AddIOp::create(rewriter, loc, otherForOp.getLowerBound(),
110+
otherForOp.getUpperBound());
111+
Value c1 = arith::ConstantOp::create(
112+
rewriter, loc, IntegerAttr::get(revIV.getType(), 1));
113+
revIV = arith::SubIOp::create(rewriter, loc, revIV, c1);
114+
revIV = arith::SubIOp::create(rewriter, loc, revIV,
115+
otherForOp.getBody()->getArgument(0));
116+
map.map(forOp.getBody()->getArgument(0), revIV);
107117
}
108118
return map;
109119
}
@@ -167,11 +177,13 @@ struct ForOpInterfaceReverse
167177
// variable).
168178

169179
auto forOp = cast<scf::ForOp>(op);
180+
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
170181

171182
SmallVector<bool> operandsActive(forOp.getNumOperands() - 3, false);
172183
for (int i = 0, e = operandsActive.size(); i < e; ++i) {
173184
operandsActive[i] = !gutils->isConstantValue(op->getOperand(i + 3)) ||
174-
!gutils->isConstantValue(op->getResult(i));
185+
!gutils->isConstantValue(op->getResult(i)) ||
186+
!gutils->isConstantValue(yieldOp.getOperand(i));
175187
}
176188

177189
SmallVector<Value> incomingGradients;
@@ -412,14 +424,23 @@ struct ForOpInterfaceReverse
412424

413425
auto term = oBB.getTerminator();
414426

427+
for (auto &&[active, operand] :
428+
llvm::zip_equal(operandsActive, term->getOperands())) {
429+
if (active) {
430+
// Zero the diffe at the start of each iteration because it should
431+
// not accumulate across iterations. The new gradient is passed as
432+
// an iter_arg in the reverse for.
433+
gutils->zeroDiffe(operand, bodyBuilder);
434+
}
435+
}
436+
415437
unsigned argIdx = 1; // Skip over the reversed IV
416438
for (auto &&[active, operand] :
417439
llvm::zip_equal(operandsActive, term->getOperands())) {
418440
if (active) {
419-
// Set diffe here, not add because it should not accumulate across
420-
// iterations. Instead the new gradient for this operand is passed
421-
// in the return of the reverse for body.
422-
gutils->setDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
441+
// If the same value is yielded multiple times in the original, the
442+
// gradients must be accumulated.
443+
gutils->addToDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
423444
argIdx++;
424445
}
425446
}

enzyme/Enzyme/MLIR/enzymemlir-opt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ int main(int argc, char **argv) {
7373
registry.insert<mlir::math::MathDialect>();
7474
registry.insert<mlir::linalg::LinalgDialect>();
7575
registry.insert<mlir::tensor::TensorDialect>();
76+
registry.insert<mlir::ub::UBDialect>();
7677
registry.insert<DLTIDialect>();
7778

7879
registry.insert<mlir::enzyme::EnzymeDialect>();

enzyme/test/MLIR/ReverseMode/scf_for_memref.mlir

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %eopt %s --enzyme-wrap="infn=reduce outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --canonicalize | FileCheck %s
1+
// RUN: %eopt %s --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --canonicalize --split-input-file | FileCheck %s
22

33
func.func @reduce(%x: f32, %ub: index) -> (f32) {
44
%lb = arith.constant 0 : index
@@ -22,7 +22,15 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
2222
return %sum2 : f32
2323
}
2424

25-
// CHECK: func.func @reduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
25+
func.func @dreduce(%x: f32, %ub: index, %dres: f32) -> (f32) {
26+
%res = enzyme.autodiff @reduce(%x, %ub, %dres) {
27+
activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>],
28+
ret_activity = [#enzyme<activity enzyme_activenoneed>]
29+
} : (f32, index, f32) -> f32
30+
return %res : f32
31+
}
32+
33+
// CHECK: func.func private @differeduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
2634
// CHECK-NEXT: %c3 = arith.constant 3 : index
2735
// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f32
2836
// CHECK-NEXT: %c4 = arith.constant 4 : index
@@ -61,3 +69,74 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
6169
// CHECK-NEXT: memref.dealloc %alloc : memref<?x4xf32>
6270
// CHECK-NEXT: return %[[v1]]#1 : f32
6371
// CHECK-NEXT: }
72+
73+
// -----
74+
75+
func.func private @reverse_index(%lb: index, %ub: index, %x: memref<?xf32>) -> f32 {
76+
%c1 = arith.constant 1 : index
77+
%zero = arith.constant 0.0 : f32
78+
%poison = ub.poison : f32
79+
// Verify correct loop indices. e.g. if the loop iterates [3, 4, 5, 6],
80+
// the reversed loop indices are [6, 5, 4, 3] while the *canonical*
81+
// reversed indices are [3, 2, 1, 0]. Memref accesses should use the
82+
// reversed indices while caches should use the canonical reversed indices.
83+
%res:2 = scf.for %iv = %lb to %ub step %c1 iter_args(%acc = %zero, %acc2 = %poison) -> (f32, f32) {
84+
%ld = memref.load %x[%iv] : memref<?xf32>
85+
%mulf = arith.mulf %ld, %ld : f32
86+
memref.store %mulf, %x[%iv] : memref<?xf32>
87+
%addf = arith.addf %mulf, %acc : f32
88+
scf.yield %addf, %addf : f32, f32
89+
}
90+
return %res#1 : f32
91+
}
92+
93+
func.func @dreverse_index(%lb: index, %ub: index, %x: memref<?xf32>, %dx: memref<?xf32>, %dr: f32) {
94+
enzyme.autodiff @reverse_index(%lb, %ub, %x, %dx, %dr) {
95+
activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_dup>],
96+
ret_activity = [#enzyme<activity enzyme_activenoneed>]
97+
} : (index, index, memref<?xf32>, memref<?xf32>, f32) -> ()
98+
return
99+
}
100+
101+
// CHECK-LABEL: func.func private @differeverse_index(
102+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index,
103+
// CHECK-SAME: %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<?xf32>,
104+
// CHECK-SAME: %[[ARG4:.*]]: f32) {
105+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
106+
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0.000000e+00 : f32
107+
// CHECK: %[[SUBI_0:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
108+
// CHECK: %[[ALLOC_0:.*]] = memref.alloc(%[[SUBI_0]]) : memref<?xf32>
109+
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_1]]) -> (f32) {
110+
// CHECK: %[[SUBI_1:.*]] = arith.subi %[[VAL_0]], %[[ARG0]] : index
111+
// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
112+
// CHECK: memref.store %[[LOAD_0]], %[[ALLOC_0]]{{\[}}%[[SUBI_1]]] : memref<?xf32>
113+
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
114+
// CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
115+
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_1]] : f32
116+
// CHECK: scf.yield %[[ADDF_0]] : f32
117+
// CHECK: }
118+
// CHECK: %[[SUBI_2:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
119+
// CHECK: %[[FOR_1:.*]]:2 = scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_3:.*]] = %[[CONSTANT_1]], %[[VAL_4:.*]] = %[[ARG4]]) -> (f32, f32) {
120+
// CHECK: %[[SUBI_3:.*]] = arith.subi %[[VAL_2]], %[[ARG0]] : index
121+
// CHECK: %[[SUBI_4:.*]] = arith.subi %[[SUBI_2]], %[[CONSTANT_0]] : index
122+
// CHECK: %[[RCANON_IV:.*]] = arith.subi %[[SUBI_4]], %[[SUBI_3]] : index
123+
// CHECK: %[[ADDI_0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : index
124+
// CHECK: %[[SUBI_6:.*]] = arith.subi %[[ADDI_0]], %[[CONSTANT_0]] : index
125+
// CHECK: %[[R_IV:.*]] = arith.subi %[[SUBI_6]], %[[VAL_2]] : index
126+
// CHECK: %[[LOAD_1:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[RCANON_IV]]] : memref<?xf32>
127+
// The gradient signal should be added together
128+
// CHECK: %[[ADDF_1:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
129+
// CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
130+
// CHECK: %[[ADDF_2:.*]] = arith.addf %[[ADDF_1]], %[[LOAD_2]] : f32
131+
// CHECK: memref.store %[[CONSTANT_1]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
132+
// CHECK: %[[MULF_1:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
133+
// CHECK: %[[MULF_2:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
134+
// CHECK: %[[ADDF_3:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
135+
// CHECK: %[[LOAD_3:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
136+
// CHECK: %[[ADDF_4:.*]] = arith.addf %[[LOAD_3]], %[[ADDF_3]] : f32
137+
// CHECK: memref.store %[[ADDF_4]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
138+
// CHECK: scf.yield %[[ADDF_1]], %[[CONSTANT_1]] : f32, f32
139+
// CHECK: }
140+
// CHECK: memref.dealloc %[[ALLOC_0]] : memref<?xf32>
141+
// CHECK: return
142+
// CHECK: }

0 commit comments

Comments
 (0)