Skip to content

Commit 08c8877

Browse files
avik-palwsmoses
andauthored
fix: update API to new llvm (#2690)
* fix: update API to new llvm * update mlir * fix --------- Co-authored-by: William S. Moses <gh@wsmoses.com>
1 parent 742e142 commit 08c8877

File tree

4 files changed

+36
-35
lines changed

4 files changed

+36
-35
lines changed

.github/workflows/enzyme-mlir.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ jobs:
4545
- uses: actions/checkout@v4
4646
with:
4747
repository: 'llvm/llvm-project'
48-
ref: '01e6245af481dac4604e8a25be6bec0dbe36f99d'
48+
ref: '909041e4802c4b9a2223ca04099f35bf1dbbd460'
4949
path: 'llvm-project'
5050

5151
- name: Set BASE_DIR

enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -852,7 +852,8 @@ static bool isValuePotentiallyUsedAsPointer(Value val) {
852852
continue;
853853
seen.insert(cur);
854854
for (Operation *user : cur.getUsers()) {
855-
if (isa<RegionBranchOpInterface>(user->getParentOp()))
855+
if (auto regionIface =
856+
dyn_cast<RegionBranchOpInterface>(user->getParentOp()))
856857
if (auto termIface =
857858
dyn_cast<RegionBranchTerminatorOpInterface>(user)) {
858859
SmallVector<RegionSuccessor> successors;
@@ -864,9 +865,10 @@ static bool isValuePotentiallyUsedAsPointer(Value val) {
864865
for (auto &successor : successors) {
865866
OperandRange operandRange =
866867
termIface.getSuccessorOperands(successor);
867-
ValueRange targetValues = successor.isParent()
868-
? parentOp->getResults()
869-
: successor.getSuccessorInputs();
868+
ValueRange targetValues =
869+
successor.isParent()
870+
? parentOp->getResults()
871+
: regionIface.getSuccessorInputs(successor);
870872
assert(operandRange.size() == targetValues.size());
871873
for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) {
872874
if (prev == cur) {
@@ -969,7 +971,8 @@ getPotentialTerminatorUsers(Operation *op, Value parent) {
969971

970972
if (auto termIface = dyn_cast<ADDataFlowOpInterface>(op->getParentOp())) {
971973
return termIface.getPotentialTerminatorUsers(op, parent);
972-
} else if (isa<RegionBranchOpInterface>(op->getParentOp())) {
974+
} else if (auto regionIface =
975+
dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
973976
if (auto termIface = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
974977
SmallVector<RegionSuccessor> successors;
975978
termIface.getSuccessorRegions(
@@ -980,9 +983,9 @@ getPotentialTerminatorUsers(Operation *op, Value parent) {
980983
SmallVector<Value> results;
981984
for (auto &successor : successors) {
982985
OperandRange operandRange = termIface.getSuccessorOperands(successor);
983-
ValueRange targetValues = successor.isParent()
984-
? parentOp->getResults()
985-
: successor.getSuccessorInputs();
986+
ValueRange targetValues =
987+
successor.isParent() ? parentOp->getResults()
988+
: regionIface.getSuccessorInputs(successor);
986989
assert(operandRange.size() == targetValues.size());
987990
for (auto &&[prev, post] : llvm::zip(operandRange, targetValues)) {
988991
if (prev == parent) {
@@ -1065,8 +1068,8 @@ static SmallVector<Value> getPotentialIncomingValues(OpResult res) {
10651068
block.getTerminator())) {
10661069
// TODO: the interface may also tell us which regions are allowed to
10671070
// yield parent op results, and which only branch to other regions.
1068-
auto successorOperands = llvm::to_vector(iface.getSuccessorOperands(
1069-
RegionSuccessor::parent(iface->getResults())));
1071+
auto successorOperands = llvm::to_vector(
1072+
iface.getSuccessorOperands(RegionSuccessor::parent()));
10701073
// TODO: understand/document the assumption of how operands flow.
10711074

10721075
if (successorOperands.size() != owner->getNumResults()) {
@@ -1131,7 +1134,8 @@ static SmallVector<Value> getPotentialIncomingValues(BlockArgument arg) {
11311134
continue;
11321135

11331136
unsigned operandOffset = static_cast<unsigned>(-1);
1134-
for (const auto &en : llvm::enumerate(successor.getSuccessorInputs())) {
1137+
for (const auto &en :
1138+
llvm::enumerate(iface.getSuccessorInputs(successor))) {
11351139
if (en.value() != arg)
11361140
continue;
11371141
operandOffset = en.index();
@@ -1146,7 +1150,7 @@ static SmallVector<Value> getPotentialIncomingValues(BlockArgument arg) {
11461150
// XXX: this assumes a contiguous slice of operands is mapped 1-1
11471151
// without swaps to a contiguous slice of entry block arguments.
11481152
assert(iface.getEntrySuccessorOperands(region).size() ==
1149-
successor.getSuccessorInputs().size());
1153+
iface.getSuccessorInputs(successor).size());
11501154
potentialSources.insert(
11511155
iface.getEntrySuccessorOperands(region)[operandOffset]);
11521156
} else {
@@ -1162,7 +1166,7 @@ static SmallVector<Value> getPotentialIncomingValues(BlockArgument arg) {
11621166
// 1-1 without swaps to a contiguous slice of entry block
11631167
// arguments.
11641168
assert(terminator.getSuccessorOperands(region).size() ==
1165-
successor.getSuccessorInputs().size());
1169+
iface.getSuccessorInputs(successor).size());
11661170
potentialSources.insert(
11671171
terminator.getSuccessorOperands(region)[operandOffset]);
11681172
} else {

enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,9 @@ void mlir::enzyme::detail::regionTerminatorForwardHandler(
268268

269269
llvm::SmallDenseSet<unsigned> operandsToShadow;
270270
auto termIface = dyn_cast<RegionBranchTerminatorOpInterface>(origTerminator);
271-
if (termIface &&
272-
isa<RegionBranchOpInterface>(origTerminator->getParentOp())) {
271+
auto regionBranchOp =
272+
dyn_cast<RegionBranchOpInterface>(origTerminator->getParentOp());
273+
if (termIface && regionBranchOp) {
273274

274275
SmallVector<RegionSuccessor> successors;
275276
termIface.getSuccessorRegions(
@@ -278,9 +279,9 @@ void mlir::enzyme::detail::regionTerminatorForwardHandler(
278279

279280
for (auto &successor : successors) {
280281
OperandRange operandRange = termIface.getSuccessorOperands(successor);
281-
ValueRange targetValues = successor.isParent()
282-
? parentOp->getResults()
283-
: successor.getSuccessorInputs();
282+
ValueRange targetValues =
283+
successor.isParent() ? parentOp->getResults()
284+
: regionBranchOp.getSuccessorInputs(successor);
284285
assert(operandRange.size() == targetValues.size());
285286
for (auto &&[i, target] : llvm::enumerate(targetValues)) {
286287
if (!gutils->isConstantValue(target))
@@ -337,9 +338,9 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
337338
OperandRange operandRange =
338339
regionBranchOp.getEntrySuccessorOperands(successor);
339340

340-
ValueRange targetValues = successor.isParent()
341-
? op->getResults()
342-
: successor.getSuccessorInputs();
341+
ValueRange targetValues =
342+
successor.isParent() ? op->getResults()
343+
: regionBranchOp.getSuccessorInputs(successor);
343344

344345
// Need to know which of the arguments are being forwarded to from
345346
// operands.

enzyme/test/MLIR/ReverseMode/scf_for_checkpointing.mlir

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ module {
2323
// CHECK-NEXT: %c1 = arith.constant 1 : index
2424
// CHECK-NEXT: %c9 = arith.constant 9 : index
2525
// CHECK-NEXT: %c0 = arith.constant 0 : index
26-
// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f32
2726
// CHECK-NEXT: %[[v0:.+]] = tensor.empty() : tensor<3xf32>
2827
// CHECK-NEXT: %[[v1:.+]]:2 = scf.for %arg2 = %c0 to %c9 step %c3 iter_args(%arg3 = %arg0, %[[arg5:.+]] = %[[v0]]) -> (f32, tensor<3xf32>) {
2928
// CHECK-NEXT: %[[idx:.+]] = arith.divui %arg2, %c3 : index
@@ -36,7 +35,7 @@ module {
3635
// CHECK-NEXT: scf.yield %[[v3]], %inserted : f32, tensor<3xf32>
3736
// CHECK-NEXT: }
3837
// CHECK-NEXT: %[[v3:.+]] = tensor.empty() : tensor<3xf32>
39-
// CHECK-NEXT: %[[v2:.+]]:4 = scf.for %arg2 = %c0 to %c3 step %c1 iter_args(%arg3 = %arg1, %arg4 = %[[zero]], %arg5 = %[[zero]], %arg6 = %[[zero]]) -> (f32, f32, f32, f32) {
38+
// CHECK-NEXT: %[[v2:.+]] = scf.for %arg2 = %c0 to %c3 step %c1 iter_args(%arg3 = %arg1) -> (f32) {
4039
// CHECK-NEXT: %[[ridx:.+]] = arith.subi %c2, %arg2 : index
4140
// CHECK-NEXT: %extracted = tensor.extract %[[v1]]#1[%[[ridx]]] : tensor<3xf32>
4241
// CHECK-NEXT: %[[v5:.+]]:2 = scf.for %[[arg8:.+]] = %c0 to %c3 step %c1 iter_args(%[[arg9:.+]] = %extracted, %[[arg10:.+]] = %[[v3]]) -> (f32, tensor<3xf32>) {
@@ -45,22 +44,19 @@ module {
4544
// CHECK-NEXT: %[[v9:.+]] = math.cos %[[v8]] : f32
4645
// CHECK-NEXT: scf.yield %[[v9]], %inserted : f32, tensor<3xf32>
4746
// CHECK-NEXT: }
48-
// CHECK-NEXT: %[[v6:.+]]:4 = scf.for %[[arg8:.+]] = %c0 to %c3 step %c1 iter_args(%[[arg9:.+]] = %arg3, %[[arg10:.+]] = %arg4, %[[arg11:.+]] = %arg5, %[[arg12:.+]] = %arg6) -> (f32, f32, f32, f32) {
47+
// CHECK-NEXT: %[[v6:.+]] = scf.for %[[arg8:.+]] = %c0 to %c3 step %c1 iter_args(%[[arg9:.+]] = %arg3) -> (f32) {
4948
// CHECK-NEXT: %[[ridx2:.+]] = arith.subi %c2, %[[arg8]] : index
5049
// CHECK-NEXT: %extracted_0 = tensor.extract %[[v5]]#1[%[[ridx2]]] : tensor<3xf32>
5150
// CHECK-NEXT: %[[r8:.+]] = arith.mulf %extracted_0, %extracted_0 : f32
52-
// CHECK-NEXT: %[[v8:.+]] = arith.addf %[[arg10]], %[[arg9]] : f32
5351
// CHECK-NEXT: %[[v9:.+]] = math.sin %[[r8]] : f32
5452
// CHECK-NEXT: %[[v10:.+]] = arith.negf %[[v9]] : f32
55-
// CHECK-NEXT: %[[v11:.+]] = arith.mulf %[[v8]], %[[v10]] : f32
56-
// CHECK-NEXT: %[[v12:.+]] = arith.addf %[[arg11]], %[[v11]] : f32
57-
// CHECK-NEXT: %[[v13:.+]] = arith.mulf %[[v12]], %extracted_0 : f32
58-
// CHECK-NEXT: %[[v14:.+]] = arith.addf %[[arg12]], %[[v13]] : f32
59-
// CHECK-NEXT: %[[v15:.+]] = arith.mulf %[[v12]], %extracted_0 : f32
60-
// CHECK-NEXT: %[[v16:.+]] = arith.addf %[[v14]], %[[v15]] : f32
61-
// CHECK-NEXT: scf.yield %[[v16]], %[[zero]], %[[zero]], %[[zero]] : f32, f32, f32, f32
53+
// CHECK-NEXT: %[[v11:.+]] = arith.mulf %[[arg9]], %[[v10]] : f32
54+
// CHECK-NEXT: %[[v12:.+]] = arith.mulf %[[v11]], %extracted_0 : f32
55+
// CHECK-NEXT: %[[v13:.+]] = arith.mulf %[[v11]], %extracted_0 : f32
56+
// CHECK-NEXT: %[[v14:.+]] = arith.addf %[[v12]], %[[v13]] : f32
57+
// CHECK-NEXT: scf.yield %[[v14]] : f32
6258
// CHECK-NEXT: }
63-
// CHECK-NEXT: scf.yield %[[v6]]#0, %[[v6]]#1, %[[v6]]#2, %[[v6]]#3 : f32, f32, f32, f32
59+
// CHECK-NEXT: scf.yield %[[v6]] : f32
6460
// CHECK-NEXT: }
65-
// CHECK-NEXT: return %[[v2]]#0 : f32
61+
// CHECK-NEXT: return %[[v2]] : f32
6662
// CHECK-NEXT: }

0 commit comments

Comments
 (0)