|
13 | 13 |
|
14 | 14 | #include "flang/Optimizer/Dialect/FIRDialect.h"
|
15 | 15 | #include "flang/Optimizer/Dialect/FIROpsSupport.h"
|
| 16 | +#include "flang/Optimizer/HLFIR/HLFIROps.h" |
16 | 17 | #include "flang/Optimizer/OpenMP/Passes.h"
|
17 | 18 |
|
18 | 19 | #include "mlir/Dialect/Func/IR/FuncOps.h"
|
19 | 20 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
20 | 21 | #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
|
21 | 22 | #include "mlir/IR/BuiltinOps.h"
|
| 23 | +#include "llvm/ADT/SetVector.h" |
22 | 24 | #include "llvm/ADT/SmallVector.h"
|
23 | 25 |
|
24 | 26 | namespace flangomp {
|
@@ -94,12 +96,298 @@ class FunctionFilteringPass
|
94 | 96 | funcOp.erase();
|
95 | 97 | return WalkResult::skip();
|
96 | 98 | }
|
| 99 | + |
| 100 | + if (failed(rewriteHostRegion(funcOp.getRegion()))) { |
| 101 | + funcOp.emitOpError() << "could not be rewritten for target device"; |
| 102 | + return WalkResult::interrupt(); |
| 103 | + } |
| 104 | + |
97 | 105 | if (declareTargetOp)
|
98 | 106 | declareTargetOp.setDeclareTarget(declareType,
|
99 | 107 | omp::DeclareTargetCaptureClause::to);
|
100 | 108 | }
|
101 | 109 | return WalkResult::advance();
|
102 | 110 | });
|
103 | 111 | }
|
| 112 | + |
| 113 | +private: |
| 114 | + /// Add the given \c omp.map.info to a sorted set while taking into account |
| 115 | + /// its dependencies. |
| 116 | + static void collectMapInfos(omp::MapInfoOp mapOp, Region ®ion, |
| 117 | + llvm::SetVector<omp::MapInfoOp> &mapInfos) { |
| 118 | + for (Value member : mapOp.getMembers()) |
| 119 | + collectMapInfos(cast<omp::MapInfoOp>(member.getDefiningOp()), region, |
| 120 | + mapInfos); |
| 121 | + |
| 122 | + if (region.isAncestor(mapOp->getParentRegion())) |
| 123 | + mapInfos.insert(mapOp); |
| 124 | + } |
| 125 | + |
| 126 | + /// Add the given value to a sorted set if it should be replaced by a |
| 127 | + /// placeholder when used as a pointer-like argument to an operation |
| 128 | + /// participating in the initialization of an \c omp.map.info. |
| 129 | + static void markPtrOperandForRewrite(Value value, |
| 130 | + llvm::SetVector<Value> &rewriteValues) { |
| 131 | + // We don't need to rewrite operands if they are defined by block arguments |
| 132 | + // of operations that will still remain after the region is rewritten. |
| 133 | + if (isa<BlockArgument>(value) && |
| 134 | + isa<func::FuncOp, omp::TargetDataOp>( |
| 135 | + cast<BlockArgument>(value).getOwner()->getParentOp())) |
| 136 | + return; |
| 137 | + |
| 138 | + rewriteValues.insert(value); |
| 139 | + } |
| 140 | + |
| 141 | + /// Rewrite the given host device region belonging to a function that contains |
| 142 | + /// \c omp.target operations, to remove host-only operations that are not used |
| 143 | + /// by device codegen. |
| 144 | + /// |
| 145 | + /// It is based on the expected form of the MLIR module as produced by Flang |
| 146 | + /// lowering and it performs the following mutations: |
| 147 | + /// - Replace all values returned by the function with \c fir.undefined. |
| 148 | + /// - Operations taking map-like clauses (e.g. \c omp.target, |
| 149 | + /// \c omp.target_data, etc) are moved to the end of the function. If they |
| 150 | + /// are nested inside of any other operations, they are hoisted out of |
| 151 | + /// them. If the region belongs to \c omp.target_data, these operations |
| 152 | + /// are hoisted to its top level, rather than to the parent function. |
| 153 | + /// - Only \c omp.map.info operations associated to these target regions are |
| 154 | + /// preserved. These are moved above all \c omp.target and sorted to |
| 155 | + /// satisfy dependencies among them. |
| 156 | + /// - \c bounds arguments are removed from \c omp.map.info operations. |
| 157 | + /// - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are |
| 158 | + /// handled as follows: |
| 159 | + /// - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset |
| 160 | + /// operation which is preserved. Otherwise, the pass will fail. |
| 161 | + /// - \c var_ptr can be defined by an \c hlfir.declare which is also |
| 162 | + /// preserved. If the \c var_ptr or \c hlfir.declare \c memref argument |
| 163 | + /// is a \c fir.address_of operation, that operation is also maintained. |
| 164 | + /// Otherwise, it is replaced by a placeholder \c fir.alloca and a |
| 165 | + /// \c fir.convert or kept unmodified when it is defined by an entry |
| 166 | + /// block argument. If it has \c shape or \c typeparams arguments, they |
| 167 | + /// are also replaced by applicable constants. \c dummy_scope arguments |
| 168 | + /// are discarded. |
| 169 | + /// - Every other operation not located inside of an \c omp.target is |
| 170 | + /// removed. |
| 171 | + LogicalResult rewriteHostRegion(Region ®ion) { |
| 172 | + // Extract parent op information. |
| 173 | + auto [funcOp, targetDataOp] = [®ion]() { |
| 174 | + Operation *parent = region.getParentOp(); |
| 175 | + return std::make_tuple(dyn_cast<func::FuncOp>(parent), |
| 176 | + dyn_cast<omp::TargetDataOp>(parent)); |
| 177 | + }(); |
| 178 | + assert((bool)funcOp != (bool)targetDataOp && |
| 179 | + "region must be defined by either func.func or omp.target_data"); |
| 180 | + |
| 181 | + // Collect operations that have mapping information associated to them. |
| 182 | + llvm::SmallVector< |
| 183 | + std::variant<omp::TargetOp, omp::TargetDataOp, omp::TargetEnterDataOp, |
| 184 | + omp::TargetExitDataOp, omp::TargetUpdateOp>> |
| 185 | + targetOps; |
| 186 | + |
| 187 | + WalkResult result = region.walk<WalkOrder::PreOrder>([&](Operation *op) { |
| 188 | + // Skip the inside of omp.target regions, since these contain device code. |
| 189 | + if (auto targetOp = dyn_cast<omp::TargetOp>(op)) { |
| 190 | + targetOps.push_back(targetOp); |
| 191 | + return WalkResult::skip(); |
| 192 | + } |
| 193 | + |
| 194 | + if (auto targetOp = dyn_cast<omp::TargetDataOp>(op)) { |
| 195 | + // Recursively rewrite omp.target_data regions as well. |
| 196 | + if (failed(rewriteHostRegion(targetOp.getRegion()))) { |
| 197 | + targetOp.emitOpError() << "rewrite for target device failed"; |
| 198 | + return WalkResult::interrupt(); |
| 199 | + } |
| 200 | + |
| 201 | + targetOps.push_back(targetOp); |
| 202 | + return WalkResult::skip(); |
| 203 | + } |
| 204 | + |
| 205 | + if (auto targetOp = dyn_cast<omp::TargetEnterDataOp>(op)) |
| 206 | + targetOps.push_back(targetOp); |
| 207 | + if (auto targetOp = dyn_cast<omp::TargetExitDataOp>(op)) |
| 208 | + targetOps.push_back(targetOp); |
| 209 | + if (auto targetOp = dyn_cast<omp::TargetUpdateOp>(op)) |
| 210 | + targetOps.push_back(targetOp); |
| 211 | + |
| 212 | + return WalkResult::advance(); |
| 213 | + }); |
| 214 | + |
| 215 | + if (result.wasInterrupted()) |
| 216 | + return failure(); |
| 217 | + |
| 218 | + // Make a temporary clone of the parent operation with an empty region, |
| 219 | + // and update all references to entry block arguments to those of the new |
| 220 | + // region. Users will later either be moved to the new region or deleted |
| 221 | + // when the original region is replaced by the new. |
| 222 | + OpBuilder builder(&getContext()); |
| 223 | + builder.setInsertionPointAfter(region.getParentOp()); |
| 224 | + Operation *newOp = builder.cloneWithoutRegions(*region.getParentOp()); |
| 225 | + Block &block = newOp->getRegion(0).emplaceBlock(); |
| 226 | + |
| 227 | + llvm::SmallVector<Location> locs; |
| 228 | + locs.reserve(region.getNumArguments()); |
| 229 | + llvm::transform(region.getArguments(), std::back_inserter(locs), |
| 230 | + [](const BlockArgument &arg) { return arg.getLoc(); }); |
| 231 | + block.addArguments(region.getArgumentTypes(), locs); |
| 232 | + |
| 233 | + for (auto [oldArg, newArg] : |
| 234 | + llvm::zip_equal(region.getArguments(), block.getArguments())) |
| 235 | + oldArg.replaceAllUsesWith(newArg); |
| 236 | + |
| 237 | + // Collect omp.map.info ops while satisfying interdependencies. This must be |
| 238 | + // updated whenever new map-like clauses are introduced or they are attached |
| 239 | + // to other operations. |
| 240 | + llvm::SetVector<omp::MapInfoOp> mapInfos; |
| 241 | + for (auto targetOp : targetOps) { |
| 242 | + std::visit( |
| 243 | + [®ion, &mapInfos](auto op) { |
| 244 | + for (Value mapVar : op.getMapVars()) |
| 245 | + collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), |
| 246 | + region, mapInfos); |
| 247 | + |
| 248 | + if constexpr (std::is_same_v<decltype(op), omp::TargetOp>) { |
| 249 | + for (Value mapVar : op.getHasDeviceAddrVars()) |
| 250 | + collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), |
| 251 | + region, mapInfos); |
| 252 | + } else if constexpr (std::is_same_v<decltype(op), |
| 253 | + omp::TargetDataOp>) { |
| 254 | + for (Value mapVar : op.getUseDeviceAddrVars()) |
| 255 | + collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), |
| 256 | + region, mapInfos); |
| 257 | + for (Value mapVar : op.getUseDevicePtrVars()) |
| 258 | + collectMapInfos(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), |
| 259 | + region, mapInfos); |
| 260 | + } |
| 261 | + }, |
| 262 | + targetOp); |
| 263 | + } |
| 264 | + |
| 265 | + // Move omp.map.info ops to the new block and collect dependencies. |
| 266 | + llvm::SetVector<hlfir::DeclareOp> declareOps; |
| 267 | + llvm::SetVector<fir::BoxOffsetOp> boxOffsets; |
| 268 | + llvm::SetVector<Value> rewriteValues; |
| 269 | + for (omp::MapInfoOp mapOp : mapInfos) { |
| 270 | + // Handle var_ptr: hlfir.declare. |
| 271 | + if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>( |
| 272 | + mapOp.getVarPtr().getDefiningOp())) { |
| 273 | + if (region.isAncestor(declareOp->getParentRegion())) |
| 274 | + declareOps.insert(declareOp); |
| 275 | + } else { |
| 276 | + markPtrOperandForRewrite(mapOp.getVarPtr(), rewriteValues); |
| 277 | + } |
| 278 | + |
| 279 | + // Handle var_ptr_ptr: fir.box_offset. |
| 280 | + if (Value varPtrPtr = mapOp.getVarPtrPtr()) { |
| 281 | + if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>( |
| 282 | + varPtrPtr.getDefiningOp())) { |
| 283 | + if (region.isAncestor(boxOffset->getParentRegion())) |
| 284 | + boxOffsets.insert(boxOffset); |
| 285 | + } else { |
| 286 | + return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported " |
| 287 | + "if defined by fir.box_offset"; |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + // Bounds are not used during target device codegen. |
| 292 | + mapOp.getBoundsMutable().clear(); |
| 293 | + mapOp->moveBefore(&block, block.end()); |
| 294 | + } |
| 295 | + |
| 296 | + // Create a temporary marker to simplify the op moving process below. |
| 297 | + builder.setInsertionPointToStart(&block); |
| 298 | + auto marker = builder.create<fir::UndefOp>(builder.getUnknownLoc(), |
| 299 | + builder.getNoneType()); |
| 300 | + builder.setInsertionPoint(marker); |
| 301 | + |
| 302 | + // Move dependencies of hlfir.declare ops. |
| 303 | + for (hlfir::DeclareOp declareOp : declareOps) { |
| 304 | + Value memref = declareOp.getMemref(); |
| 305 | + |
| 306 | + // If it's defined by fir.address_of, then we need to keep that op as well |
| 307 | + // because it might be pointing to a 'declare target' global. |
| 308 | + if (auto addressOf = |
| 309 | + dyn_cast_if_present<fir::AddrOfOp>(memref.getDefiningOp())) |
| 310 | + addressOf->moveBefore(marker); |
| 311 | + else |
| 312 | + markPtrOperandForRewrite(memref, rewriteValues); |
| 313 | + |
| 314 | + // Shape and typeparams aren't needed for target device codegen, but |
| 315 | + // removing them would break verifiers. |
| 316 | + Value zero; |
| 317 | + if (declareOp.getShape() || !declareOp.getTypeparams().empty()) |
| 318 | + zero = builder.create<arith::ConstantOp>(declareOp.getLoc(), |
| 319 | + builder.getI64IntegerAttr(0)); |
| 320 | + |
| 321 | + if (auto shape = declareOp.getShape()) { |
| 322 | + Operation *shapeOp = shape.getDefiningOp(); |
| 323 | + unsigned numArgs = shapeOp->getNumOperands(); |
| 324 | + if (isa<fir::ShapeShiftOp>(shapeOp)) |
| 325 | + numArgs /= 2; |
| 326 | + |
| 327 | + // Since the pre-cg rewrite pass requires the shape to be defined by one |
| 328 | + // of fir.shape, fir.shapeshift or fir.shift, we need to create one of |
| 329 | + // these. |
| 330 | + llvm::SmallVector<Value> extents(numArgs, zero); |
| 331 | + auto newShape = builder.create<fir::ShapeOp>(shape.getLoc(), extents); |
| 332 | + declareOp.getShapeMutable().assign(newShape); |
| 333 | + } |
| 334 | + |
| 335 | + for (OpOperand &typeParam : declareOp.getTypeparamsMutable()) |
| 336 | + typeParam.assign(zero); |
| 337 | + |
| 338 | + declareOp.getDummyScopeMutable().clear(); |
| 339 | + } |
| 340 | + |
| 341 | + // We don't actually need the proper local allocations, but rather maintain |
| 342 | + // the basic form of map operands. We create 1-bit placeholder allocas |
| 343 | + // that we "typecast" to the expected pointer type and replace all uses. |
| 344 | + // Using fir.undefined here instead is not possible because these variables |
| 345 | + // cannot be constants, as that would trigger different codegen for target |
| 346 | + // regions. |
| 347 | + for (Value value : rewriteValues) { |
| 348 | + Location loc = value.getLoc(); |
| 349 | + Value placeholder = |
| 350 | + builder.create<fir::AllocaOp>(loc, builder.getI1Type()); |
| 351 | + value.replaceAllUsesWith( |
| 352 | + builder.create<fir::ConvertOp>(loc, value.getType(), placeholder)); |
| 353 | + } |
| 354 | + |
| 355 | + // Move omp.map.info dependencies. |
| 356 | + for (hlfir::DeclareOp declareOp : declareOps) |
| 357 | + declareOp->moveBefore(marker); |
| 358 | + |
| 359 | + // The box_ref argument of fir.box_offset is expected to be the same value |
| 360 | + // that was passed as var_ptr to the corresponding omp.map.info, so we don't |
| 361 | + // need to move its defining op here. |
| 362 | + for (fir::BoxOffsetOp boxOffset : boxOffsets) |
| 363 | + boxOffset->moveBefore(marker); |
| 364 | + |
| 365 | + marker->erase(); |
| 366 | + |
| 367 | + // Move mapping information users to the end of the new block. |
| 368 | + for (auto targetOp : targetOps) |
| 369 | + std::visit([&block](auto op) { op->moveBefore(&block, block.end()); }, |
| 370 | + targetOp); |
| 371 | + |
| 372 | + // Add terminator to the new block. |
| 373 | + builder.setInsertionPointToEnd(&block); |
| 374 | + if (funcOp) { |
| 375 | + llvm::SmallVector<Value> returnValues; |
| 376 | + returnValues.reserve(funcOp.getNumResults()); |
| 377 | + for (auto type : funcOp.getResultTypes()) |
| 378 | + returnValues.push_back( |
| 379 | + builder.create<fir::UndefOp>(funcOp.getLoc(), type)); |
| 380 | + |
| 381 | + builder.create<func::ReturnOp>(funcOp.getLoc(), returnValues); |
| 382 | + } else { |
| 383 | + builder.create<omp::TerminatorOp>(targetDataOp.getLoc()); |
| 384 | + } |
| 385 | + |
| 386 | + // Replace old (now missing ops) region with the new one and remove the |
| 387 | + // temporary clone. |
| 388 | + region.takeBody(newOp->getRegion(0)); |
| 389 | + newOp->erase(); |
| 390 | + return success(); |
| 391 | + } |
104 | 392 | };
|
105 | 393 | } // namespace
|
0 commit comments