10
10
11
11
#include " mlir/Dialect/Func/IR/FuncOps.h"
12
12
#include " mlir/Dialect/OpenACC/OpenACC.h"
13
+ #include " mlir/IR/Dominance.h"
13
14
#include " mlir/Pass/Pass.h"
14
15
#include " mlir/Transforms/RegionUtils.h"
15
16
#include " llvm/Support/ErrorHandling.h"
@@ -71,7 +72,55 @@ static void replaceAllUsesInAccComputeRegionsWith(Value orig, Value replacement,
71
72
}
72
73
73
74
template <typename Op>
74
- static void collectAndReplaceInRegion (Op &op, bool hostToDevice) {
75
+ static void replaceAllUsesInUnstructuredComputeRegionWith (
76
+ Op &op, llvm::SmallVector<std::pair<Value, Value>> &values,
77
+ DominanceInfo &domInfo, PostDominanceInfo &postDomInfo) {
78
+
79
+ SmallVector<Operation *> exitOps;
80
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
81
+ // For declare enter/exit pairs, collect all exit ops
82
+ for (auto *user : op.getToken ().getUsers ()) {
83
+ if (auto declareExit = dyn_cast<acc::DeclareExitOp>(user))
84
+ exitOps.push_back (declareExit);
85
+ }
86
+ if (exitOps.empty ())
87
+ return ;
88
+ }
89
+
90
+ for (auto p : values) {
91
+ Value hostVal = std::get<0 >(p);
92
+ Value deviceVal = std::get<1 >(p);
93
+ for (auto &use : llvm::make_early_inc_range (hostVal.getUses ())) {
94
+ Operation *owner = use.getOwner ();
95
+
96
+ // Check It's the case that the acc entry operation dominates the use.
97
+ if (!domInfo.dominates (op.getOperation (), owner))
98
+ continue ;
99
+
100
+ // Check It's the case that at least one of the acc exit operations
101
+ // post-dominates the use
102
+ bool hasPostDominatingExit = false ;
103
+ for (auto *exit : exitOps) {
104
+ if (postDomInfo.postDominates (exit , owner)) {
105
+ hasPostDominatingExit = true ;
106
+ break ;
107
+ }
108
+ }
109
+
110
+ if (!hasPostDominatingExit)
111
+ continue ;
112
+
113
+ if (insideAccComputeRegion (owner))
114
+ use.set (deviceVal);
115
+ }
116
+ }
117
+ }
118
+
119
+ template <typename Op>
120
+ static void
121
+ collectAndReplaceInRegion (Op &op, bool hostToDevice,
122
+ DominanceInfo *domInfo = nullptr ,
123
+ PostDominanceInfo *postDomInfo = nullptr ) {
75
124
llvm::SmallVector<std::pair<Value, Value>> values;
76
125
77
126
if constexpr (std::is_same_v<Op, acc::LoopOp>) {
@@ -82,16 +131,25 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
82
131
if constexpr (!std::is_same_v<Op, acc::KernelsOp> &&
83
132
!std::is_same_v<Op, acc::DataOp> &&
84
133
!std::is_same_v<Op, acc::DeclareOp> &&
85
- !std::is_same_v<Op, acc::HostDataOp>) {
134
+ !std::is_same_v<Op, acc::HostDataOp> &&
135
+ !std::is_same_v<Op, acc::DeclareEnterOp>) {
86
136
collectVars (op.getReductionOperands (), values, hostToDevice);
87
137
collectVars (op.getPrivateOperands (), values, hostToDevice);
88
138
collectVars (op.getFirstprivateOperands (), values, hostToDevice);
89
139
}
90
140
}
91
141
92
- for (auto p : values)
93
- replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
94
- op.getRegion ());
142
+ if constexpr (std::is_same_v<Op, acc::DeclareEnterOp>) {
143
+ assert (domInfo && postDomInfo &&
144
+ " Dominance info required for DeclareEnterOp" );
145
+ replaceAllUsesInUnstructuredComputeRegionWith<Op>(op, values, *domInfo,
146
+ *postDomInfo);
147
+ } else {
148
+ for (auto p : values) {
149
+ replaceAllUsesInAccComputeRegionsWith<Op>(std::get<0 >(p), std::get<1 >(p),
150
+ op.getRegion ());
151
+ }
152
+ }
95
153
}
96
154
97
155
class LegalizeDataValuesInRegion
@@ -105,10 +163,16 @@ class LegalizeDataValuesInRegion
105
163
func::FuncOp funcOp = getOperation ();
106
164
bool replaceHostVsDevice = this ->hostToDevice .getValue ();
107
165
166
+ // Initialize dominance info
167
+ DominanceInfo domInfo;
168
+ PostDominanceInfo postDomInfo;
169
+ bool computedDomInfo = false ;
170
+
108
171
funcOp.walk ([&](Operation *op) {
109
172
if (!isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(*op) &&
110
173
!(isa<ACC_DATA_CONSTRUCT_STRUCTURED_OPS>(*op) &&
111
- applyToAccDataConstruct))
174
+ applyToAccDataConstruct) &&
175
+ !isa<acc::DeclareEnterOp>(*op))
112
176
return ;
113
177
114
178
if (auto parallelOp = dyn_cast<acc::ParallelOp>(*op)) {
@@ -125,6 +189,14 @@ class LegalizeDataValuesInRegion
125
189
collectAndReplaceInRegion (declareOp, replaceHostVsDevice);
126
190
} else if (auto hostDataOp = dyn_cast<acc::HostDataOp>(*op)) {
127
191
collectAndReplaceInRegion (hostDataOp, replaceHostVsDevice);
192
+ } else if (auto declareEnterOp = dyn_cast<acc::DeclareEnterOp>(*op)) {
193
+ if (!computedDomInfo) {
194
+ domInfo = DominanceInfo (funcOp);
195
+ postDomInfo = PostDominanceInfo (funcOp);
196
+ computedDomInfo = true ;
197
+ }
198
+ collectAndReplaceInRegion (declareEnterOp, replaceHostVsDevice, &domInfo,
199
+ &postDomInfo);
128
200
} else {
129
201
llvm_unreachable (" unsupported acc region op" );
130
202
}
0 commit comments