Skip to content

Commit b8882c1

Browse files
committed
mlir: support active pointers in scf.if
1 parent 8520c15 commit b8882c1

2 files changed

Lines changed: 132 additions & 2 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp

Lines changed: 70 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -885,7 +885,10 @@ struct IfOpInterfaceReverse
885885

886886
SmallVector<bool> resultsActive(ifOp.getNumResults(), false);
887887
for (int i = 0, e = resultsActive.size(); i < e; ++i) {
888-
resultsActive[i] = !gutils->isConstantValue(ifOp.getResult(i));
888+
auto result = ifOp.getResult(i);
889+
auto iface = dyn_cast<AutoDiffTypeInterface>(result.getType());
890+
bool needsGrad = iface && !iface.isMutable();
891+
resultsActive[i] = needsGrad && !gutils->isConstantValue(result);
889892
}
890893

891894
SmallVector<Value> incomingGradients;
@@ -969,7 +972,72 @@ struct IfOpInterfaceReverse
969972
}
970973

971974
// Rebuilds the differentiated scf.if so that it also returns shadow values.
// For every result of the original op whose type implements
// AutoDiffTypeInterface, reports isMutable(), and is not constant according to
// gutils, an extra result of the same type is emitted directly after the primal
// result and registered through gutils->setInvertedPointer.
//
// op      - the original scf.if operation.
// builder - used to create the augmented scf.if and its new yields.
// gutils  - provides the original-to-new mapping, activity queries, and
//           shadow lookup (invertPointerM / setInvertedPointer).
void createShadowValues(Operation *op, OpBuilder &builder,
972-
MGradientUtilsReverse *gutils) const {}
975+
MGradientUtilsReverse *gutils) const {
976+
// TODO: consider making this generic for RegionBranchOpInterface
977+
auto ifOp = cast<scf::IfOp>(op);
978+
// An scf.if without results has nothing to shadow.
if (ifOp.getNumResults() == 0)
979+
return;
980+
981+
// newIf is the counterpart of ifOp as reported by gutils; it is replaced and
// erased at the end of this function.
auto newIf = cast<scf::IfOp>(gutils->getNewFromOriginal(ifOp));
982+
SmallVector<Type> newResultTypes;
983+
// needsShadow[i] records whether original result i gets a shadow result.
SmallVector<bool> needsShadow(op->getNumResults());
984+
for (auto result : op->getResults()) {
985+
// TODO: consider isActivePointer/isActiveData methods on gutils?
986+
// The primal result type is always kept.
newResultTypes.push_back(result.getType());
987+
auto iface = dyn_cast<AutoDiffTypeInterface>(result.getType());
988+
if (iface && iface.isMutable() && !gutils->isConstantValue(result)) {
989+
// The shadow has the same type and sits directly after its primal.
newResultTypes.push_back(result.getType());
990+
needsShadow[result.getResultNumber()] = true;
991+
} else {
992+
needsShadow[result.getResultNumber()] = false;
993+
}
994+
}
995+
996+
// Replace the new op with an augmented op
997+
auto augmentedOp =
998+
scf::IfOp::create(builder, op->getLoc(), newResultTypes,
999+
gutils->getNewFromOriginal(ifOp.getCondition()),
1000+
/*withElseRegion=*/true);
1001+
1002+
// Walk the regions of the original op, newIf, and the augmented op in
// lockstep.
for (auto &&[oldReg, newReg, augReg] :
1003+
llvm::zip(op->getRegions(), newIf->getRegions(),
1004+
augmentedOp->getRegions())) {
1005+
// Move the body out of newIf's region into the augmented op's region.
augReg.takeBody(newReg);
1006+
for (auto &&[oldBlk, augBlk] : llvm::zip(oldReg, augReg)) {
1007+
Operation *oldYield = oldBlk.getTerminator();
1008+
Operation *augYield = augBlk.getTerminator();
1009+
1010+
// The guard restores the builder's insertion point when this scope ends.
OpBuilder::InsertionGuard guard(builder);
1011+
builder.setInsertionPoint(augYield);
1012+
SmallVector<Value> newOperands;
1013+
for (auto &&[oldOperand, augOperand] :
1014+
llvm::zip(oldYield->getOpOperands(), augYield->getOpOperands())) {
1015+
newOperands.push_back(augOperand.get());
1016+
// needsShadow is indexed by result number, so this relies on yield
// operand i corresponding to result i of the scf.if.
if (needsShadow[oldOperand.getOperandNumber()]) {
1017+
// The shadow of the originally yielded value follows its primal.
newOperands.push_back(
1018+
gutils->invertPointerM(oldOperand.get(), builder));
1019+
}
1020+
}
1021+
1022+
// Replace the old terminator with a yield carrying the widened operand list.
scf::YieldOp::create(builder, oldYield->getLoc(), newOperands);
1023+
augYield->erase();
1024+
}
1025+
}
1026+
1027+
// Determine which returns correspond to the primal
1028+
SmallVector<Value> augmentedResults;
1029+
// resIdx walks the augmented op's results, which interleave primal and
// shadow values.
unsigned resIdx = 0;
1030+
for (auto res : ifOp.getResults()) {
1031+
augmentedResults.push_back(augmentedOp.getResult(resIdx));
1032+
resIdx++;
1033+
if (needsShadow[res.getResultNumber()]) {
1034+
// The slot after the primal is the shadow; register it for the original
// result.
gutils->setInvertedPointer(res, augmentedOp.getResult(resIdx));
1035+
resIdx++;
1036+
}
1037+
}
1038+
// Users of newIf are redirected to the primal results only; the shadows are
// reachable through gutils.
newIf.replaceAllUsesWith(augmentedResults);
1039+
newIf.erase();
1040+
}
9731041
};
9741042

9751043
struct ForOpADDataFlow

enzyme/test/MLIR/ReverseMode/scf_if.mlir

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -133,3 +133,65 @@ func.func @dif_overwrite(%cond: i1, %x: memref<f32>, %dx: memref<f32>) {
133133
// CHECK: }
134134
// CHECK: return
135135
// CHECK: }
136+
137+
// -----
138+
139+
// Test input: an scf.if that yields a pointer. The then-branch overwrites the
// f32 at %x with its cosine and yields %x offset by 4 f32 elements; the
// else-branch yields %x unchanged. The f32 loaded through the selected pointer
// is passed through math.sin and returned.
func.func private @active_ptr(%cond: i1, %x: !llvm.ptr) -> f32 {
140+
%c4 = llvm.mlir.constant (4 : i64) : i64
141+
%ptr = scf.if %cond -> !llvm.ptr {
142+
%gep = llvm.getelementptr %x[%c4] : (!llvm.ptr, i64) -> !llvm.ptr, f32
143+
%val = llvm.load %x : !llvm.ptr -> f32
144+
%cos = math.cos %val : f32
145+
llvm.store %cos, %x : f32, !llvm.ptr
146+
scf.yield %gep : !llvm.ptr
147+
} else {
148+
scf.yield %x : !llvm.ptr
149+
}
150+
151+
%ld = llvm.load %ptr : !llvm.ptr -> f32
152+
// Note: despite its name, %cos here holds the result of math.sin.
%cos = math.sin %ld : f32
153+
return %cos : f32
154+
}
155+
156+
// Test driver: invokes enzyme.autodiff on @active_ptr. The activity list marks
// the first argument enzyme_const and the second enzyme_dup; the return
// activity is enzyme_activenoneed.
// NOTE(review): presumably %dx is the shadow paired with the enzyme_dup
// argument %x and %dres is the incoming gradient of the f32 result; confirm
// against the enzyme dialect documentation.
func.func @dif_overwrite(%cond: i1, %x: !llvm.ptr, %dx: !llvm.ptr, %dres: f32) {
157+
enzyme.autodiff @active_ptr(%cond, %x, %dx, %dres) {
158+
activity=[#enzyme<activity enzyme_const>, #enzyme<activity enzyme_dup>],
159+
ret_activity=[#enzyme<activity enzyme_activenoneed>]
160+
} : (i1, !llvm.ptr, !llvm.ptr, f32) -> ()
161+
return
162+
}
163+
164+
// CHECK-LABEL: func.func private @diffeactive_ptr(
165+
// CHECK-SAME: %[[ARG0:.*]]: i1,
166+
// CHECK-SAME: %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr,
167+
// CHECK-SAME: %[[ARG3:.*]]: f32) {
168+
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 0.000000e+00 : f32
169+
// CHECK: %[[IF_0:.*]]:3 = scf.if %[[ARG0]] -> (!llvm.ptr, !llvm.ptr, f32) {
170+
// CHECK: %[[GETELEMENTPTR_0:.*]] = llvm.getelementptr %[[ARG2]][4] : (!llvm.ptr) -> !llvm.ptr, f32
171+
// CHECK: %[[GETELEMENTPTR_1:.*]] = llvm.getelementptr %[[ARG1]][4] : (!llvm.ptr) -> !llvm.ptr, f32
172+
// CHECK: %[[LOAD_0:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> f32
173+
// CHECK: %[[COS_0:.*]] = math.cos %[[LOAD_0]] : f32
174+
// CHECK: llvm.store %[[COS_0]], %[[ARG1]] : f32, !llvm.ptr
175+
// CHECK: scf.yield %[[GETELEMENTPTR_1]], %[[GETELEMENTPTR_0]], %[[LOAD_0]] : !llvm.ptr, !llvm.ptr, f32
176+
// CHECK: } else {
177+
// CHECK: scf.yield %[[ARG1]], %[[ARG2]], %[[CONSTANT_0]] : !llvm.ptr, !llvm.ptr, f32
178+
// CHECK: }
179+
// CHECK: %[[LOAD_1:.*]] = llvm.load %[[VAL_0:.*]]#0 : !llvm.ptr -> f32
180+
// CHECK: %[[COS_1:.*]] = math.cos %[[LOAD_1]] : f32
181+
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[ARG3]], %[[COS_1]] : f32
182+
// CHECK: %[[LOAD_2:.*]] = llvm.load %[[VAL_0]]#1 : !llvm.ptr -> f32
183+
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[LOAD_2]], %[[MULF_0]] : f32
184+
// CHECK: llvm.store %[[ADDF_0]], %[[VAL_0]]#1 : f32, !llvm.ptr
185+
// CHECK: scf.if %[[ARG0]] {
186+
// CHECK: %[[LOAD_3:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> f32
187+
// CHECK: llvm.store %[[CONSTANT_0]], %[[ARG2]] : f32, !llvm.ptr
188+
// CHECK: %[[SIN_0:.*]] = math.sin %[[VAL_0]]#2 : f32
189+
// CHECK: %[[NEGF_0:.*]] = arith.negf %[[SIN_0]] : f32
190+
// CHECK: %[[MULF_1:.*]] = arith.mulf %[[LOAD_3]], %[[NEGF_0]] : f32
191+
// CHECK: %[[LOAD_4:.*]] = llvm.load %[[ARG2]] : !llvm.ptr -> f32
192+
// CHECK: %[[ADDF_1:.*]] = arith.addf %[[LOAD_4]], %[[MULF_1]] : f32
193+
// CHECK: llvm.store %[[ADDF_1]], %[[ARG2]] : f32, !llvm.ptr
194+
// CHECK: } else {
195+
// CHECK: }
196+
// CHECK: return
197+
// CHECK: }

0 commit comments

Comments
 (0)