Skip to content

Commit a073bb5

Browse files
authored
[mlir][acc] Add LegalizeDataValues support for DeclareEnterOp (#138008)
The patch extends the existing LegalizeDataValues to support DeclareEnter and DeclareExit pair. Since unlike other ops, DeclareEnter and DeclareExit don't have a region defined, we use dominance/post dominance information to ensure only the uses within the region dominated by DeclareEnter and post dominated by DeclareExit are updated with data on device.
1 parent 4fdb8cb commit a073bb5

File tree

2 files changed

+106
-7
lines changed

2 files changed

+106
-7
lines changed

mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp

+78-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Dialect/Func/IR/FuncOps.h"
1212
#include "mlir/Dialect/OpenACC/OpenACC.h"
13+
#include "mlir/IR/Dominance.h"
1314
#include "mlir/Pass/Pass.h"
1415
#include "mlir/Transforms/RegionUtils.h"
1516
#include "llvm/Support/ErrorHandling.h"
@@ -71,7 +72,55 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
7172
}
7273

7374
template <typename Op>
74-
static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
75+
static void replaceAllUsesInUnstructuredComputeRegionWith(
76+
Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
77+
DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78+
79+
SmallVector<Operation *> exitOps;
80+
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81+
// For declare enter/exit pairs, collect all exit ops
82+
for (auto *user : op.getToken().getUsers()) {
83+
if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84+
exitOps.push_back(declareExit);
85+
}
86+
if (exitOps.empty())
87+
return;
88+
}
89+
90+
for (auto p : values) {
91+
Value hostVal = std::get<0>(p);
92+
Value deviceVal = std::get<1>(p);
93+
for (auto &use : llvm::make_early_inc_range(hostVal.getUses())) {
94+
Operation *owner = use.getOwner();
95+
96+
// Check It's the case that the acc entry operation dominates the use.
97+
if (!domInfo.dominates(op.getOperation(), owner))
98+
continue;
99+
100+
// Check It's the case that at least one of the acc exit operations
101+
// post-dominates the use
102+
bool hasPostDominatingExit = false;
103+
for (auto *exit : exitOps) {
104+
if (postDomInfo.postDominates(exit, owner)) {
105+
hasPostDominatingExit = true;
106+
break;
107+
}
108+
}
109+
110+
if (!hasPostDominatingExit)
111+
continue;
112+
113+
if (insideAccComputeRegion(owner))
114+
use.set(deviceVal);
115+
}
116+
}
117+
}
118+
119+
template <typename Op>
120+
static void
121+
collectAndReplaceInRegion(Op &op, bool hostToDevice,
122+
DominanceInfo *domInfo = nullptr,
123+
PostDominanceInfo *postDomInfo = nullptr) {
75124
llvm::SmallVector<std::pair<Value, Value>> values;
76125

77126
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +131,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
82131
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83132
!std::is_same_v<Op, acc::DataOp> &&
84133
!std::is_same_v<Op, acc::DeclareOp> &&
85-
!std::is_same_v<Op, acc::HostDataOp>) {
134+
!std::is_same_v<Op, acc::HostDataOp> &&
135+
!std::is_same_v<Op, acc::DeclareEnterOp>) {
86136
collectVars(op.getReductionOperands(), values, hostToDevice);
87137
collectVars(op.getPrivateOperands(), values, hostToDevice);
88138
collectVars(op.getFirstprivateOperands(), values, hostToDevice);
89139
}
90140
}
91141

92-
for (auto p : values)
93-
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
94-
op.getRegion());
142+
if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143+
assert(domInfo && postDomInfo &&
144+
"Dominance info required for DeclareEnterOp");
145+
replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
146+
*postDomInfo);
147+
} else {
148+
for (auto p : values) {
149+
replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0>(p), std::get<1>(p),
150+
op.getRegion());
151+
}
152+
}
95153
}
96154

97155
class LegalizeDataValuesInRegion
@@ -105,10 +163,16 @@ class LegalizeDataValuesInRegion
105163
func::FuncOp funcOp = getOperation();
106164
bool replaceHostVsDevice = this->hostToDevice.getValue();
107165

166+
// Initialize dominance info
167+
DominanceInfo domInfo;
168+
PostDominanceInfo postDomInfo;
169+
bool computedDomInfo = false;
170+
108171
funcOp.walk([&](Operation *op) {
109172
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110173
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
111-
applyToAccDataConstruct))
174+
applyToAccDataConstruct) &&
175+
!isa<acc::DeclareEnterOp>(*op))
112176
return;
113177

114178
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -125,6 +189,14 @@ class LegalizeDataValuesInRegion
125189
collectAndReplaceInRegion(declareOp, replaceHostVsDevice);
126190
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127191
collectAndReplaceInRegion(hostDataOp, replaceHostVsDevice);
192+
} else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193+
if (!computedDomInfo) {
194+
domInfo = DominanceInfo(funcOp);
195+
postDomInfo = PostDominanceInfo(funcOp);
196+
computedDomInfo = true;
197+
}
198+
collectAndReplaceInRegion(declareEnterOp, replaceHostVsDevice, &domInfo,
199+
&postDomInfo);
128200
} else {
129201
llvm_unreachable("unsupported acc region op");
130202
}

mlir/test/Dialect/OpenACC/legalize-data.mlir

+28-1
Original file line numberDiff line numberDiff line change
@@ -245,4 +245,31 @@ func.func private @foo(memref<10xf32>)
245245
// CHECK: acc.host_data dataOperands(%[[USE_DEVICE]] : memref<10xf32>) {
246246
// DEVICE: func.call @foo(%[[USE_DEVICE]]) : (memref<10xf32>) -> ()
247247
// CHECK: acc.terminator
248-
// CHECK: }
248+
// CHECK: }
249+
250+
// -----
251+
252+
func.func @test(%a: memref<10xf32>) {
253+
%declare = acc.create varPtr(%a : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
254+
%token = acc.declare_enter dataOperands(%declare : memref<10xf32>)
255+
acc.kernels dataOperands(%declare : memref<10xf32>) {
256+
%c0 = arith.constant 0 : index
257+
%c1 = arith.constant 1.000000e+00 : f32
258+
memref.store %c1, %a[%c0] : memref<10xf32>
259+
acc.terminator
260+
}
261+
acc.declare_exit token(%token) dataOperands(%declare : memref<10xf32>)
262+
return
263+
}
264+
265+
// CHECK-LABEL: func.func @test
266+
// CHECK-SAME: (%[[A:.*]]: memref<10xf32>)
267+
// CHECK: %[[DECLARE:.*]] = acc.create varPtr(%[[A]] : memref<10xf32>) varType(tensor<10xf32>) -> memref<10xf32> {name = "arr"}
268+
// CHECK: %[[TOKEN:.*]] = acc.declare_enter dataOperands(%[[DECLARE]] : memref<10xf32>)
269+
// CHECK: acc.kernels dataOperands(%[[DECLARE]] : memref<10xf32>) {
270+
// DEVICE: memref.store %{{.*}}, %[[DECLARE]][%{{.*}}] : memref<10xf32>
271+
// HOST: memref.store %{{.*}}, %[[A]][%{{.*}}] : memref<10xf32>
272+
// CHECK: acc.terminator
273+
// CHECK: }
274+
// CHECK: acc.declare_exit token(%[[TOKEN]]) dataOperands(%[[DECLARE]] : memref<10xf32>)
275+
// CHECK: return

0 commit comments

Comments
 (0)