Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 271 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,275 @@ struct ForOpInterfaceReverse
}
};

/// Data-flow model of `scf.while` for the AD analyses: reports which values may
/// flow into each loop result / region argument, and which values an operand
/// of a region terminator may flow into.
struct WhileOpADDataFlow
    : public ADDataFlowOpInterface::ExternalModel<WhileOpADDataFlow,
                                                  scf::WhileOp> {
  /// A loop result is fed by the matching forwarded operand of the "before"
  /// terminator. The `+ 1` skips that terminator's leading operand.
  SmallVector<Value> getPotentialIncomingValuesRes(Operation *op,
                                                   OpResult res) const {
    auto loop = cast<scf::WhileOp>(op);
    Operation *beforeTerm = loop.getBeforeBody()->getTerminator();

    SmallVector<Value> sources;
    sources.push_back(beforeTerm->getOperand(res.getResultNumber() + 1));
    return sources;
  }

  /// Values that may flow into a region block argument:
  ///  - "before" argument #i: loop operand #i (first entry) or operand #i of
  ///    the "after" terminator (back edge);
  ///  - any other argument #i: operand #(i + 1) of the "before" terminator.
  SmallVector<Value> getPotentialIncomingValuesArg(Operation *op,
                                                   BlockArgument arg) const {
    auto loop = cast<scf::WhileOp>(op);
    const unsigned idx = arg.getArgNumber();

    SmallVector<Value> sources;
    if (arg.getOwner() != loop.getBeforeBody()) {
      Operation *beforeTerm = loop.getBeforeBody()->getTerminator();
      sources.push_back(beforeTerm->getOperand(idx + 1));
      return sources;
    }

    sources.push_back(loop->getOperand(idx));
    sources.push_back(loop.getAfterBody()->getTerminator()->getOperand(idx));
    return sources;
  }

  /// Values that `val` may flow into when it is used as an operand of the
  /// region terminator `term`. A value that appears in several operand slots
  /// contributes one set of users per slot.
  SmallVector<Value> getPotentialTerminatorUsers(Operation *op, Operation *term,
                                                 Value val) const {
    auto loop = cast<scf::WhileOp>(op);
    Block *parent = term->getBlock();
    SmallVector<Value> users;

    if (parent == loop.getBeforeBody()) {
      // Forwarded operands (everything but the leading one) reach both the
      // loop result and the "after" block argument in the same slot.
      for (auto &&[result, forwarded, afterArg] : llvm::zip_equal(
               loop->getResults(), term->getOperands().drop_front(),
               loop.getAfterBody()->getArguments())) {
        if (forwarded != val)
          continue;
        users.push_back(result);
        users.push_back(afterArg);
      }
      return users;
    }

    if (parent == loop.getAfterBody()) {
      // Operands of the "after" terminator feed the "before" block arguments.
      for (auto &&[yielded, beforeArg] : llvm::zip_equal(
               term->getOperands(), loop.getBeforeBody()->getArguments())) {
        if (yielded == val)
          users.push_back(beforeArg);
      }
    }

    return users;
  }
};

/// Reverse-mode differentiation of `scf.while`.
///
/// Forward pass (see `cacheValues`): the loop is rebuilt with one extra
/// `index` value threaded through both regions that counts how many times the
/// "before" region ran; that count is cached.
///
/// Reverse pass (see `createReverseModeAdjoint`): an `scf.for` running that
/// many iterations replays the adjoints of the two regions in reverse
/// execution order.
///
/// NOTE(review): the implementation zips loop operands, results, "before"
/// arguments and "after" arguments with `zip_equal`, i.e. it assumes all four
/// lists have the same length (and slot-wise compatible gradient types).
struct WhileOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<WhileOpInterfaceReverse,
                                                       scf::WhileOp> {

  /// Emits the adjoint of `op` at `builder`'s insertion point.
  ///
  /// With N executions of the "before" region (B) and N-1 executions of the
  /// "after" region (A), the forward trace is
  ///     B1 A1 B2 A2 ... B(N-1) A(N-1) BN
  /// so the adjoints must run as
  ///     rBN rA(N-1) rB(N-1) ... rA1 rB1.
  /// Each reverse iteration therefore emits "rA then rB", with rA skipped on
  /// the very first reverse iteration (BN has no matching A in the forward).
  ///
  /// \param caches  `caches[0]` is the cache created by `cacheValues`.
  /// \returns success iff every nested operation could be reversed.
  LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                         MGradientUtilsReverse *gutils,
                                         SmallVector<Value> caches) const {
    auto whileOp = cast<scf::WhileOp>(op);
    OpBuilder::InsertionGuard guard(builder);

    bool valid = true;

    // Trip count of the reverse loop, as recorded by cacheValues().
    Value numIters = gutils->popCache(caches[0], builder);

    // One flag per loop-carried slot; a slot is differentiated as soon as any
    // of the four values tied to it is non-constant.
    SmallVector<bool> operandsActive;
    SmallVector<Value> incomingGradients;

    for (auto [operand, result, argBefore, argAfter] :
         llvm::zip_equal(op->getOperands(), op->getResults(),
                         whileOp.getBeforeBody()->getArguments(),
                         whileOp.getAfterBody()->getArguments())) {
      bool active = !gutils->isConstantValue(operand) ||
                    !gutils->isConstantValue(result) ||
                    !gutils->isConstantValue(argBefore) ||
                    !gutils->isConstantValue(argAfter);

      operandsActive.push_back(active);
      if (active) {
        incomingGradients.push_back(gutils->diffe(result, builder));
        if (!gutils->isConstantValue(result))
          gutils->zeroDiffe(result, builder);
      }
    }

    // for (i = 0; i < numIters; i += 1) iter_args(<gradients of active slots>)
    auto forOp = builder.create<scf::ForOp>(
        op->getLoc(),
        builder.create<arith::ConstantOp>(
            op->getLoc(), IntegerAttr::get(numIters.getType(), 0)),
        numIters,
        builder.create<arith::ConstantOp>(
            op->getLoc(), IntegerAttr::get(numIters.getType(), 1)),
        incomingGradients);

    auto zeroAllDiffes = [&](Block *oBB, OpBuilder &builder) {
      // All values defined in the body should have no use outside this block
      // therefore we can set their diffe to zero upon entering the reverse
      // block to simplify the work of the remove-unnecessary-enzyme-ops pass.
      for (auto operand : oBB->getArguments()) {
        if (!gutils->isConstantValue(operand)) {
          gutils->zeroDiffe(operand, builder);
        }
      }

      for (auto &it : oBB->getOperations()) {
        for (auto res : it.getResults()) {
          if (!gutils->isConstantValue(res)) {
            gutils->zeroDiffe(res, builder);
          }
        }
      }
    };

    // Visits every operation of `oBB` except its terminator, last to first,
    // and emits its adjoint. Returns false if any visit failed.
    auto makeReverse = [&](Block *oBB, OpBuilder &builder) {
      bool valid = true;
      auto first = oBB->rbegin();
      first++; // skip terminator
      auto last = oBB->rend();
      for (auto it = first; it != last; ++it) {
        Operation *op = &*it;
        valid &= gutils->Logic.visitChild(op, builder, gutils).succeeded();
      }
      return valid;
    };

    // Reads the gradient of every active block argument of `oBB` (resetting
    // the non-constant ones) and returns them in slot order.
    auto takeArgGradients = [&](Block *oBB, OpBuilder &builder) {
      SmallVector<Value> grads;
      for (auto &&[active, arg] :
           llvm::zip(operandsActive, oBB->getArguments())) {
        if (active) {
          grads.push_back(gutils->diffe(arg, builder));
          if (!gutils->isConstantValue(arg))
            gutils->zeroDiffe(arg, builder);
        }
      }
      return grads;
    };

    // Gradients carried into the current reverse iteration. Block argument #0
    // of the scf.for body is the induction variable, hence the `+ 1`.
    Block *revBody = forOp.getBody();
    SmallVector<Value> carriedGradients;
    for (unsigned i = 0, e = incomingGradients.size(); i != e; ++i)
      carriedGradients.push_back(revBody->getArgument(i + 1));

    builder.setInsertionPointToEnd(revBody);

    // In the forward, the last execution of the "before" region is not
    // followed by the "after" region. Hence, in the reverse, the adjoint of
    // the "after" region is skipped on the first iteration and the carried
    // gradients are handed unchanged to the adjoint of the "before" region.
    Location condLoc = whileOp.getBeforeBody()->getTerminator()->getLoc();
    Value zeroIdx = builder.create<arith::ConstantIndexOp>(condLoc, 0);
    Value isFirstReverseIter = builder.create<arith::CmpIOp>(
        condLoc, arith::CmpIPredicate::eq, forOp.getInductionVar(), zeroIdx);
    auto ifOp = builder.create<scf::IfOp>(
        op->getLoc(), ValueRange(incomingGradients).getTypes(),
        isFirstReverseIter,
        /*withElseRegion*/ true);

    {
      builder.setInsertionPointToEnd(ifOp.thenBlock());
      builder.create<scf::YieldOp>(op->getLoc(), carriedGradients);
    }

    {
      // Adjoint of the "after" region: the carried gradients are those of the
      // "before" arguments, which the "after" terminator produced.
      Block *oBB = whileOp.getAfterBody();
      Operation *term = oBB->getTerminator();
      builder.setInsertionPointToEnd(ifOp.elseBlock());

      zeroAllDiffes(oBB, builder);

      unsigned revIdx = 0;
      for (auto [active, operand] :
           llvm::zip_equal(operandsActive, term->getOperands())) {
        if (active) {
          gutils->addToDiffe(operand, carriedGradients[revIdx], builder);
          revIdx++;
        }
      }

      valid &= makeReverse(oBB, builder);

      SmallVector<Value> afterArgGradients = takeArgGradients(oBB, builder);
      builder.create<scf::YieldOp>(op->getLoc(), afterArgGradients);
    }

    // Back in the scf.for body, right after the scf.if.
    builder.setInsertionPointToEnd(revBody);

    SmallVector<Value> outgoingGradients;
    {
      // Adjoint of the "before" region: the scf.if results are the gradients
      // of the values forwarded by its terminator (all operands but the
      // leading one).
      Block *oBB = whileOp.getBeforeBody();
      Operation *term = oBB->getTerminator();

      zeroAllDiffes(oBB, builder);

      unsigned revIdx = 0;
      for (auto [active, operand] :
           llvm::zip_equal(operandsActive, term->getOperands().drop_front())) {
        if (active) {
          gutils->addToDiffe(operand, ifOp->getResult(revIdx), builder);
          revIdx++;
        }
      }

      valid &= makeReverse(oBB, builder);

      outgoingGradients = takeArgGradients(oBB, builder);
    }

    builder.create<scf::YieldOp>(op->getLoc(), outgoingGradients);

    builder.setInsertionPointAfter(forOp);

    // After the last reverse iteration (rB1) the loop results are the
    // gradients of the "before" arguments on first entry, i.e. of the loop
    // operands.
    unsigned revIdx = 0;
    for (auto &&[active, arg] : llvm::zip(operandsActive, op->getOperands())) {
      if (active) {
        if (!gutils->isConstantValue(arg))
          gutils->addToDiffe(arg, forOp->getResult(revIdx), builder);
        revIdx++;
      }
    }

    return success(valid);
  }

  /// Rewrites the forward-pass clone of `op` so that it also counts the
  /// executions of its "before" region, and caches that count.
  ///
  /// \returns a single-element vector holding the cache of the count; it is
  ///          read back as `caches[0]` in `createReverseModeAdjoint`.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    // Cache the number of iterations of the *before* block.
    auto newOp = cast<scf::WhileOp>(gutils->getNewFromOriginal(op));
    OpBuilder builder(newOp);

    Block *before = newOp.getBeforeBody(), *after = newOp.getAfterBody();

    auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0);

    // "before": take the counter as a trailing argument, increment it and
    // forward the incremented value through the terminator.
    Value inBefore = before->addArgument(zero.getType(), zero.getLoc());
    Operation *beforeTerm = before->getTerminator();
    builder.setInsertionPoint(beforeTerm);
    inBefore = builder.create<arith::AddIOp>(
        op->getLoc(), inBefore,
        builder.create<arith::ConstantIndexOp>(op->getLoc(), 1));
    beforeTerm->insertOperands(beforeTerm->getNumOperands(), inBefore);

    // "after": pass the counter through untouched.
    Value inAfter = after->addArgument(zero.getType(), zero.getLoc());
    Operation *afterTerm = after->getTerminator();
    afterTerm->insertOperands(afterTerm->getNumOperands(), inAfter);

    SmallVector<Value> initArgs(newOp->getOperands().begin(),
                                newOp->getOperands().end());
    initArgs.push_back(zero);

    // The operation's operand/result lists cannot grow in place: create a
    // wider scf.while right before the old one and move both regions into it.
    builder.setInsertionPoint(newOp);
    auto newWhile = builder.create<scf::WhileOp>(
        newOp->getLoc(), ValueRange(initArgs).getTypes(), initArgs);

    newWhile.getBefore().takeBody(newOp.getBefore());
    newWhile.getAfter().takeBody(newOp.getAfter());

    // The trailing result is the final counter value.
    Value numItersCache = gutils->initAndPushCache(
        newWhile->getResult(newWhile->getNumResults() - 1), builder);

    // Users of the old results only see the original (non-counter) results.
    gutils->replaceOrigOpWith(op, newWhile->getResults().drop_back());
    gutils->erase(newOp);
    gutils->originalToNewFnOps[op] = newWhile;

    return {numItersCache};
  }

  /// Intentionally empty: this implementation creates no shadow values.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};

} // namespace

void mlir::enzyme::registerSCFDialectAutoDiffInterface(
Expand All @@ -572,5 +841,7 @@ void mlir::enzyme::registerSCFDialectAutoDiffInterface(
registerInterfaces(context);
scf::ForOp::attachInterface<ForOpInterfaceReverse>(*context);
scf::ForOp::attachInterface<ForOpEnzymeOpsRemover>(*context);
scf::WhileOp::attachInterface<WhileOpInterfaceReverse>(*context);
scf::WhileOp::attachInterface<WhileOpADDataFlow>(*context);
});
}
65 changes: 65 additions & 0 deletions enzyme/test/MLIR/ReverseMode/scf_while.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// RUN:%eopt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_active retTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math | FileCheck %s

module {
// Test input: an `scf.while` carrying an `index` counter and an `f32` value
// that is squared by every execution of the loop body. `@main` returns the
// final `f32` value and is differentiated w.r.t. `%init1` by the RUN line.
func.func @main(%init1: f32) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index

%res:2 = scf.while (%iter = %c0, %arg1 = %init1) : (index, f32) -> (index, f32) {
// "Before" region.
// In a "while" loop, this region computes the condition.
// NOTE(review): %iter starts at %c0, so `eq %iter, %c10` is false the first
// time it is evaluated and the body would never run when executed. The
// FileCheck expectations only inspect the generated IR, so the test is
// unaffected, but `ne` or `slt` may have been intended -- confirm.
%condition = arith.cmpi eq, %iter, %c10 : index

// Forward the argument (as result or "after" region argument).
scf.condition(%condition) %iter, %arg1 : index, f32

} do {
^bb0(%iterAfter: index, %arg2: f32):
// "After" region.
// In a "while" loop, this region is the loop body.
%next = arith.mulf %arg2, %arg2 : f32
%nextIter = arith.addi %iterAfter, %c1 : index

// Forward the new value to the "before" region.
// The operand types must match the types of the `scf.while` operands.
scf.yield %nextIter, %next : index, f32
}
return %res#1 : f32
}
}

// CHECK: func.func @main(%arg0: f32, %arg1: f32) -> f32 {
// CHECK-NEXT: %c10 = arith.constant 10 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %[[v0:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f32>
// CHECK-NEXT: %[[v1:.+]] = "enzyme.init"() : () -> !enzyme.Cache<f32>
// CHECK-NEXT: %[[v2:.+]]:3 = scf.while (%arg2 = %c0, %arg3 = %arg0, %arg4 = %c0) : (index, f32, index) -> (index, f32, index) {
// CHECK-NEXT: %[[v4:.+]] = arith.cmpi eq, %arg2, %c10 : index
// CHECK-NEXT: %[[v5:.+]] = arith.addi %arg4, %c1 : index
// CHECK-NEXT: scf.condition(%[[v4]]) %arg2, %arg3, %[[v5]] : index, f32, index
// CHECK-NEXT: } do {
// CHECK-NEXT: ^bb0(%arg2: index, %arg3: f32, %arg4: index):
// CHECK-NEXT: "enzyme.push"(%[[v1]], %arg3) : (!enzyme.Cache<f32>, f32) -> ()
// CHECK-NEXT: "enzyme.push"(%[[v0]], %arg3) : (!enzyme.Cache<f32>, f32) -> ()
// CHECK-NEXT: %[[v4:.+]] = arith.mulf %arg3, %arg3 : f32
// CHECK-NEXT: %[[v5:.+]] = arith.addi %arg2, %c1 : index
// CHECK-NEXT: scf.yield %[[v5]], %[[v4]], %arg4 : index, f32, index
// CHECK-NEXT: }
// CHECK-NEXT: %[[v3:.+]] = scf.for %arg2 = %c0 to %[[v2]]#2 step %c1 iter_args(%arg3 = %arg1) -> (f32) {
// CHECK-NEXT: %[[v4:.+]] = arith.cmpi eq, %arg2, %c0 : index
// CHECK-NEXT: %[[v5:.+]] = scf.if %[[v4]] -> (f32) {
// CHECK-NEXT: scf.yield %arg3 : f32
// CHECK-NEXT: } else {
// CHECK-NEXT: %[[v6:.+]] = "enzyme.pop"(%[[v1]]) : (!enzyme.Cache<f32>) -> f32
// CHECK-NEXT: %[[v7:.+]] = "enzyme.pop"(%[[v0]]) : (!enzyme.Cache<f32>) -> f32
// CHECK-NEXT: %[[v8:.+]] = arith.mulf %arg3, %[[v7]] : f32
// CHECK-NEXT: %[[v9:.+]] = arith.mulf %arg3, %[[v6]] : f32
// CHECK-NEXT: %[[v10:.+]] = arith.addf %[[v8]], %[[v9]] : f32
// CHECK-NEXT: scf.yield %[[v10]] : f32
// CHECK-NEXT: }
// CHECK-NEXT: scf.yield %[[v5]] : f32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[v3]] : f32
// CHECK-NEXT: }
Loading