Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 28 additions & 7 deletions enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,18 @@ struct ForOpEnzymeOpsRemover
if (!map.contains(canIdx)) {
assert(Equivalent(forOp.getLowerBound(), otherForOp.getLowerBound()));
assert(Equivalent(forOp.getStep(), otherForOp.getStep()));
map.map(forOp.getBody()->getArgument(0),
otherForOp.getBody()->getArgument(0));

Location loc = forOp.getLoc();
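// The reverse IV can be computed as (lb + ub - 1 - iv).
// NOTE(review): this formula appears to assume a unit step (only equivalence
// of the two steps is asserted above) -- confirm for non-unit steps.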
Value revIV =
arith::AddIOp::create(rewriter, loc, otherForOp.getLowerBound(),
otherForOp.getUpperBound());
Value c1 = arith::ConstantOp::create(
rewriter, loc, IntegerAttr::get(revIV.getType(), 1));
revIV = arith::SubIOp::create(rewriter, loc, revIV, c1);
revIV = arith::SubIOp::create(rewriter, loc, revIV,
otherForOp.getBody()->getArgument(0));
map.map(forOp.getBody()->getArgument(0), revIV);
}
return map;
}
Expand Down Expand Up @@ -167,11 +177,13 @@ struct ForOpInterfaceReverse
// variable).

auto forOp = cast<scf::ForOp>(op);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());

SmallVector<bool> operandsActive(forOp.getNumOperands() - 3, false);
for (int i = 0, e = operandsActive.size(); i < e; ++i) {
operandsActive[i] = !gutils->isConstantValue(op->getOperand(i + 3)) ||
!gutils->isConstantValue(op->getResult(i));
!gutils->isConstantValue(op->getResult(i)) ||
!gutils->isConstantValue(yieldOp.getOperand(i));
}

SmallVector<Value> incomingGradients;
Expand Down Expand Up @@ -412,14 +424,23 @@ struct ForOpInterfaceReverse

auto term = oBB.getTerminator();

for (auto &&[active, operand] :
llvm::zip_equal(operandsActive, term->getOperands())) {
if (active) {
// Zero the diffe at the start of each iteration because it should
// not accumulate across iterations. The new gradient is passed as
// an iter_arg in the reverse for.
gutils->zeroDiffe(operand, bodyBuilder);
}
}

unsigned argIdx = 1; // Skip over the reversed IV
for (auto &&[active, operand] :
llvm::zip_equal(operandsActive, term->getOperands())) {
if (active) {
// Set diffe here, not add because it should not accumulate across
// iterations. Instead the new gradient for this operand is passed
// in the return of the reverse for body.
gutils->setDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
// If the same value is yielded multiple times in the original, the
// gradients must be accumulated.
gutils->addToDiffe(operand, revBB.getArgument(argIdx), bodyBuilder);
argIdx++;
}
}
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/enzymemlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ int main(int argc, char **argv) {
registry.insert<mlir::math::MathDialect>();
registry.insert<mlir::linalg::LinalgDialect>();
registry.insert<mlir::tensor::TensorDialect>();
registry.insert<mlir::ub::UBDialect>();
registry.insert<DLTIDialect>();

registry.insert<mlir::enzyme::EnzymeDialect>();
Expand Down
83 changes: 81 additions & 2 deletions enzyme/test/MLIR/ReverseMode/scf_for_memref.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt %s --enzyme-wrap="infn=reduce outfn= argTys=enzyme_active,enzyme_const retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --canonicalize | FileCheck %s
// RUN: %eopt %s --enzyme --canonicalize --remove-unnecessary-enzyme-ops --enzyme-simplify-math --canonicalize --split-input-file | FileCheck %s

func.func @reduce(%x: f32, %ub: index) -> (f32) {
%lb = arith.constant 0 : index
Expand All @@ -22,7 +22,15 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
return %sum2 : f32
}

// CHECK: func.func @reduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
// Test driver: differentiates @reduce through enzyme.autodiff and returns the
// single f32 value the op produces.
//   %x    - first argument of @reduce, marked enzyme_active.
//   %ub   - second argument of @reduce, marked enzyme_const.
//   %dres - extra f32 operand; presumably the seed gradient for the
//           enzyme_activenoneed return value (confirm against the Enzyme
//           MLIR calling convention).
func.func @dreduce(%x: f32, %ub: index, %dres: f32) -> (f32) {
// The activity list has one entry per argument of @reduce; ret_activity has
// one entry for its single result.
%res = enzyme.autodiff @reduce(%x, %ub, %dres) {
activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>],
ret_activity = [#enzyme<activity enzyme_activenoneed>]
} : (f32, index, f32) -> f32
return %res : f32
}

// CHECK: func.func private @differeduce(%arg0: f32, %arg1: index, %arg2: f32) -> f32 {
// CHECK-NEXT: %c3 = arith.constant 3 : index
// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: %c4 = arith.constant 4 : index
Expand Down Expand Up @@ -61,3 +69,74 @@ func.func @reduce(%x: f32, %ub: index) -> (f32) {
// CHECK-NEXT: memref.dealloc %alloc : memref<?x4xf32>
// CHECK-NEXT: return %[[v1]]#1 : f32
// CHECK-NEXT: }

// -----

// Primal function under test: for each index in [%lb, %ub) with step 1 it
// squares %x[%iv] in place and accumulates the squared values; the second
// loop result is returned.
//   %lb, %ub - loop bounds (the lower bound is an argument, so it is not
//              necessarily zero).
//   %x       - memref that is both read and overwritten at every index.
func.func private @reverse_index(%lb: index, %ub: index, %x: memref<?xf32>) -> f32 {
%c1 = arith.constant 1 : index
%zero = arith.constant 0.0 : f32
// Initial value of the second iter_arg; %acc2 is never read in the loop body.
%poison = ub.poison : f32
// Verify correct loop indices. For example, if the loop iterates over
// [3, 4, 5, 6], the reversed loop indices are [6, 5, 4, 3] while the
// *canonical* reversed indices are [3, 2, 1, 0]. Memref accesses should use
// the reversed indices while caches should use the canonical reversed indices.
%res:2 = scf.for %iv = %lb to %ub step %c1 iter_args(%acc = %zero, %acc2 = %poison) -> (f32, f32) {
%ld = memref.load %x[%iv] : memref<?xf32>
%mulf = arith.mulf %ld, %ld : f32
// The store overwrites the element that was just loaded.
memref.store %mulf, %x[%iv] : memref<?xf32>
%addf = arith.addf %mulf, %acc : f32
// The same value is yielded for both results.
scf.yield %addf, %addf : f32, f32
}
return %res#1 : f32
}

// Test driver: differentiates @reverse_index through enzyme.autodiff. The op
// produces no results here.
//   %lb, %ub - loop bounds, both marked enzyme_const.
//   %x       - primal memref, marked enzyme_dup.
//   %dx      - extra memref operand following %x; presumably the shadow
//              (gradient) buffer paired with the enzyme_dup argument (confirm).
//   %dr      - extra f32 operand; presumably the seed gradient for the
//              enzyme_activenoneed return value (confirm).
func.func @dreverse_index(%lb: index, %ub: index, %x: memref<?xf32>, %dx: memref<?xf32>, %dr: f32) {
enzyme.autodiff @reverse_index(%lb, %ub, %x, %dx, %dr) {
activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_dup>],
ret_activity = [#enzyme<activity enzyme_activenoneed>]
} : (index, index, memref<?xf32>, memref<?xf32>, f32) -> ()
return
}

// CHECK-LABEL: func.func private @differeverse_index(
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: memref<?xf32>, %[[ARG3:.*]]: memref<?xf32>,
// CHECK-SAME: %[[ARG4:.*]]: f32) {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 1 : index
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[SUBI_0:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
// CHECK: %[[ALLOC_0:.*]] = memref.alloc(%[[SUBI_0]]) : memref<?xf32>
// CHECK: %[[FOR_0:.*]] = scf.for %[[VAL_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_1:.*]] = %[[CONSTANT_1]]) -> (f32) {
// CHECK: %[[SUBI_1:.*]] = arith.subi %[[VAL_0]], %[[ARG0]] : index
// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
// CHECK: memref.store %[[LOAD_0]], %[[ALLOC_0]]{{\[}}%[[SUBI_1]]] : memref<?xf32>
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
// CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_0]], %[[VAL_1]] : f32
// CHECK: scf.yield %[[ADDF_0]] : f32
// CHECK: }
// CHECK: %[[SUBI_2:.*]] = arith.subi %[[ARG1]], %[[ARG0]] : index
// CHECK: %[[FOR_1:.*]]:2 = scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG1]] step %[[CONSTANT_0]] iter_args(%[[VAL_3:.*]] = %[[CONSTANT_1]], %[[VAL_4:.*]] = %[[ARG4]]) -> (f32, f32) {
// CHECK: %[[SUBI_3:.*]] = arith.subi %[[VAL_2]], %[[ARG0]] : index
// CHECK: %[[SUBI_4:.*]] = arith.subi %[[SUBI_2]], %[[CONSTANT_0]] : index
// CHECK: %[[RCANON_IV:.*]] = arith.subi %[[SUBI_4]], %[[SUBI_3]] : index
// CHECK: %[[ADDI_0:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : index
// CHECK: %[[SUBI_6:.*]] = arith.subi %[[ADDI_0]], %[[CONSTANT_0]] : index
// CHECK: %[[R_IV:.*]] = arith.subi %[[SUBI_6]], %[[VAL_2]] : index
// CHECK: %[[LOAD_1:.*]] = memref.load %[[ALLOC_0]]{{\[}}%[[RCANON_IV]]] : memref<?xf32>
// The original loop yields the same value for both results, so the gradients
// arriving through the two iter_args must be added together.
// CHECK: %[[ADDF_1:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32
// CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
// CHECK: %[[ADDF_2:.*]] = arith.addf %[[ADDF_1]], %[[LOAD_2]] : f32
// CHECK: memref.store %[[CONSTANT_1]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
// CHECK: %[[MULF_1:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
// CHECK: %[[MULF_2:.*]] = arith.mulf %[[ADDF_2]], %[[LOAD_1]] : f32
// CHECK: %[[ADDF_3:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
// CHECK: %[[LOAD_3:.*]] = memref.load %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
// CHECK: %[[ADDF_4:.*]] = arith.addf %[[LOAD_3]], %[[ADDF_3]] : f32
// CHECK: memref.store %[[ADDF_4]], %[[ARG3]]{{\[}}%[[R_IV]]] : memref<?xf32>
// CHECK: scf.yield %[[ADDF_1]], %[[CONSTANT_1]] : f32, f32
// CHECK: }
// CHECK: memref.dealloc %[[ALLOC_0]] : memref<?xf32>
// CHECK: return
// CHECK: }
Loading