Skip to content

Commit 0405624

Browse files
committed
unified alloca_scope
1 parent 4ecf1c9 commit 0405624

2 files changed

Lines changed: 132 additions & 31 deletions

File tree

enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp

Lines changed: 79 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -223,19 +223,41 @@ struct AllocaScopeOpInterfaceReverse
223223
}
224224
}
225225

226-
auto revScope =
227-
memref::AllocaScopeOp::create(builder, op->getLoc(), TypeRange());
228-
Block *revBody = builder.createBlock(&revScope.getBodyRegion());
229-
OpBuilder bodyBuilder(revBody, revBody->end());
230-
memref::AllocaScopeReturnOp::create(bodyBuilder, op->getLoc(),
231-
ValueRange());
232-
bodyBuilder.setInsertionPoint(revBody->getTerminator());
226+
Region &scopeRegion = scopeOp.getBodyRegion();
227+
SmallVector<Value> capturedInputs;
228+
scopeRegion.walk([&](Operation *inner) {
229+
for (Value operand : inner->getOperands()) {
230+
Region *defRegion = operand.getParentRegion();
231+
if (!defRegion || scopeRegion.isAncestor(defRegion))
232+
continue;
233+
if (gutils->isConstantValue(operand))
234+
continue;
235+
auto iface = dyn_cast<AutoDiffTypeInterface>(operand.getType());
236+
if (!iface || iface.isMutable())
237+
continue;
238+
if (!llvm::is_contained(capturedInputs, operand))
239+
capturedInputs.push_back(operand);
240+
}
241+
});
242+
243+
SmallVector<Value> capturedPre;
244+
capturedPre.reserve(capturedInputs.size());
245+
for (Value v : capturedInputs) {
246+
capturedPre.push_back(gutils->diffe(v, builder));
247+
gutils->zeroDiffe(v, builder);
248+
}
249+
250+
auto newScope =
251+
cast<memref::AllocaScopeOp>(gutils->getNewFromOriginal(op));
252+
newScope->moveBefore(builder.getInsertionBlock(),
253+
builder.getInsertionPoint());
254+
255+
Block &newBody = newScope.getBodyRegion().front();
256+
OpBuilder bodyBuilder(newBody.getTerminator());
233257

234258
Block &oldBody = scopeOp.getBodyRegion().front();
235259
bool valid = true;
236260

237-
// Values defined in the scoped region cannot be used outside it. Reset
238-
// their adjoints before propagating gradients through the scoped body.
239261
for (Operation &innerOp : oldBody.getOperations()) {
240262
for (Value result : innerOp.getResults()) {
241263
if (!gutils->isConstantValue(result)) {
@@ -256,16 +278,61 @@ struct AllocaScopeOpInterfaceReverse
256278
if (!gutils->isConstantValue(operand))
257279
gutils->addToDiffe(operand, incomingGradients[incomingIdx],
258280
bodyBuilder);
259-
++incomingIdx;
281+
incomingIdx++;
260282
}
261283

262284
auto first = oldBody.rbegin();
263-
++first;
285+
first++;
264286

265-
for (auto it = first; it != oldBody.rend(); ++it) {
287+
for (auto it = first; it != oldBody.rend(); it++) {
266288
valid &= gutils->Logic.visitChild(&*it, bodyBuilder, gutils).succeeded();
267289
}
268290

291+
if (capturedInputs.empty())
292+
return success(valid);
293+
294+
SmallVector<Value> contributions;
295+
contributions.reserve(capturedInputs.size());
296+
for (Value v : capturedInputs)
297+
contributions.push_back(gutils->diffe(v, bodyBuilder));
298+
299+
unsigned numPrimal = newScope->getNumResults();
300+
Operation *bodyTerm = newBody.getTerminator();
301+
SmallVector<Value> retVals(bodyTerm->getOperands().begin(),
302+
bodyTerm->getOperands().end());
303+
retVals.append(contributions.begin(), contributions.end());
304+
305+
SmallVector<Type> newResultTypes(newScope->getResultTypes().begin(),
306+
newScope->getResultTypes().end());
307+
for (Value v : capturedInputs)
308+
newResultTypes.push_back(gutils->getShadowType(v.getType()));
309+
310+
OpBuilder scopeBuilder(newScope);
311+
auto extScope = memref::AllocaScopeOp::create(scopeBuilder, op->getLoc(),
312+
newResultTypes);
313+
extScope.getBodyRegion().takeBody(newScope.getBodyRegion());
314+
315+
Block &extBody = extScope.getBodyRegion().front();
316+
Operation *movedTerm = extBody.getTerminator();
317+
OpBuilder termBuilder(movedTerm);
318+
memref::AllocaScopeReturnOp::create(termBuilder, movedTerm->getLoc(),
319+
retVals);
320+
gutils->erase(movedTerm);
321+
322+
for (unsigned i = 0; i < numPrimal; ++i) {
323+
newScope->getResult(i).replaceAllUsesWith(extScope.getResult(i));
324+
gutils->originalToNewFn.map(scopeOp->getResult(i), extScope.getResult(i));
325+
}
326+
gutils->erase(newScope);
327+
328+
builder.setInsertionPointAfter(extScope);
329+
for (auto indexed : llvm::enumerate(capturedInputs)) {
330+
Value v = indexed.value();
331+
gutils->setDiffe(v, capturedPre[indexed.index()], builder);
332+
gutils->addToDiffe(v, extScope.getResult(numPrimal + indexed.index()),
333+
builder);
334+
}
335+
269336
return success(valid);
270337
}
271338

enzyme/test/MLIR/ReverseMode/alloca_scope2.mlir

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,63 @@ func.func @foo(%x : f64) -> f64{
1010
return %out : f64
1111
}
1212

13-
1413
func.func @dfoo(%x: f64, %dout : f64) -> f64 {
1514
%dx = enzyme.autodiff @foo(%x, %dout) {
16-
activity = [#enzyme<activity enzyme_active>],
17-
ret_activity = [#enzyme<acitivity enzyme_activenoneed>]
18-
} : (f64, f64) -> (f64)
19-
return %dx : f64
20-
}
21-
22-
func.func @foo2(%x : f64) -> f64{
23-
%cst = arith.constant 0.0000e+00 : f64
24-
%buf = memref.alloca() : memref<f64>
25-
memref.store %x, %buf[] : memref<f64>
26-
%y = memref.load %buf[] : memref<f64>
27-
%out = arith.addf %cst, %y : f64
28-
return %out : f64
29-
}
30-
31-
32-
func.func @dfoo2(%x: f64, %dout : f64) -> f64 {
33-
%dx = enzyme.autodiff @foo2(%x, %dout) {
3415
activity = [#enzyme<activity enzyme_active>],
3516
ret_activity = [#enzyme<activity enzyme_activenoneed>]
3617
} : (f64, f64) -> (f64)
3718
return %dx : f64
3819
}
20+
21+
// CHECK-LABEL: func.func private @diffefoo(
22+
// CHECK-SAME: %[[X:[^,]+]]: f64,
23+
// CHECK-SAME: %[[DOUT:[^)]+]]: f64) -> f64 {
24+
// CHECK: %[[C0:.*]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
25+
// CHECK: %[[C1:.*]] = "enzyme.init"() : () -> !enzyme.Cache<memref<f64>>
26+
// CHECK: %[[G0:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
27+
// CHECK: "enzyme.set"(%[[G0]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
28+
// CHECK: %[[G1:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
29+
// CHECK: "enzyme.set"(%[[G1]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
30+
// CHECK: %[[G2:.*]] = "enzyme.init"() : () -> !enzyme.Gradient<f64>
31+
// CHECK: "enzyme.set"(%[[G2]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
32+
// CHECK: cf.br ^bb1
33+
// CHECK: ^bb1:
34+
// CHECK: %[[T5:.*]] = "enzyme.get"(%[[G2]]) : (!enzyme.Gradient<f64>) -> f64
35+
// CHECK: %[[T6:.*]] = arith.addf %[[T5]], %[[DOUT]] : f64
36+
// CHECK: "enzyme.set"(%[[G2]], %[[T6]]) : (!enzyme.Gradient<f64>, f64) -> ()
37+
// CHECK: %[[T7:.*]] = "enzyme.get"(%[[G2]]) : (!enzyme.Gradient<f64>) -> f64
38+
// CHECK: "enzyme.set"(%[[G2]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
39+
// CHECK: %[[T8:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
40+
// CHECK: "enzyme.set"(%[[G1]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
41+
// CHECK: %[[SCOPE:.*]]:2 = memref.alloca_scope -> (f64, f64) {
42+
// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<f64>
43+
// CHECK: memref.store %{{.*}}, %[[ALLOCA]][] : memref<f64>
44+
// CHECK: %[[ALLOCA5:.*]] = memref.alloca() : memref<f64>
45+
// CHECK: "enzyme.push"(%[[C0]], %[[ALLOCA]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
46+
// CHECK: memref.store %[[X]], %[[ALLOCA5]][] : memref<f64>
47+
// CHECK: "enzyme.push"(%[[C1]], %[[ALLOCA]]) : (!enzyme.Cache<memref<f64>>, memref<f64>) -> ()
48+
// CHECK: %[[LD:.*]] = memref.load %[[ALLOCA5]][] : memref<f64>
49+
// CHECK: "enzyme.set"(%[[G0]], %{{.*}}) : (!enzyme.Gradient<f64>, f64) -> ()
50+
// CHECK: %[[T14:.*]] = "enzyme.get"(%[[G0]]) : (!enzyme.Gradient<f64>) -> f64
51+
// CHECK: %[[T15:.*]] = arith.addf %[[T14]], %[[T7]] : f64
52+
// CHECK: "enzyme.set"(%[[G0]], %[[T15]]) : (!enzyme.Gradient<f64>, f64) -> ()
53+
// CHECK: %[[T16:.*]] = "enzyme.get"(%[[G0]]) : (!enzyme.Gradient<f64>) -> f64
54+
// CHECK: %[[POP1:.*]] = "enzyme.pop"(%[[C1]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
55+
// CHECK: %[[LD18:.*]] = memref.load %[[POP1]][] : memref<f64>
56+
// CHECK: %[[T19:.*]] = arith.addf %[[LD18]], %[[T16]] : f64
57+
// CHECK: memref.store %[[T19]], %[[POP1]][] : memref<f64>
58+
// CHECK: %[[POP0:.*]] = "enzyme.pop"(%[[C0]]) : (!enzyme.Cache<memref<f64>>) -> memref<f64>
59+
// CHECK: %[[LD21:.*]] = memref.load %[[POP0]][] : memref<f64>
60+
// CHECK: %[[T22:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
61+
// CHECK: %[[T23:.*]] = arith.addf %[[T22]], %[[LD21]] : f64
62+
// CHECK: "enzyme.set"(%[[G1]], %[[T23]]) : (!enzyme.Gradient<f64>, f64) -> ()
63+
// CHECK: memref.store %{{.*}}, %[[POP0]][] : memref<f64>
64+
// CHECK: %[[T24:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
65+
// CHECK: memref.alloca_scope.return %[[LD]], %[[T24]] : f64, f64
66+
// CHECK: }
67+
// CHECK: "enzyme.set"(%[[G1]], %[[T8]]) : (!enzyme.Gradient<f64>, f64) -> ()
68+
// CHECK: %[[T10:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
69+
// CHECK: %[[T11:.*]] = arith.addf %[[T10]], %[[SCOPE]]#1 : f64
70+
// CHECK: "enzyme.set"(%[[G1]], %[[T11]]) : (!enzyme.Gradient<f64>, f64) -> ()
71+
// CHECK: %[[T12:.*]] = "enzyme.get"(%[[G1]]) : (!enzyme.Gradient<f64>) -> f64
72+
// CHECK: return %[[T12]] : f64

0 commit comments

Comments
 (0)