Skip to content

Commit 6841450

Browse files
committed
[NFC] Share body generation callback between task and taskloop
1 parent 257aaea commit 6841450

File tree

1 file changed

+101
-172
lines changed

1 file changed

+101
-172
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 101 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -2350,11 +2350,105 @@ void TaskContextStructManager::freeStructPtr() {
23502350
builder.CreateFree(structPtr);
23512351
}
23522352

2353+
using TaskLikeBodyGenCallbackTy =
2354+
std::function<llvm::Error(llvm::OpenMPIRBuilder::InsertPointTy allocaIP,
2355+
llvm::OpenMPIRBuilder::InsertPointTy codegenIP)>;
2356+
2357+
/// Build the body generation callback shared by task-like constructs (task and
2358+
/// taskloop).
2359+
static TaskLikeBodyGenCallbackTy buildTaskLikeBodyGenCallback(
2360+
Operation *opInst, Region &region, StringRef regionName,
2361+
llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
2362+
PrivateVarsInfo &privateVarsInfo, TaskContextStructManager &taskStructMgr) {
2363+
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2364+
return [&, regionName](InsertPointTy allocaIP,
2365+
InsertPointTy codegenIP) -> llvm::Error {
2366+
// Save the alloca insertion point on ModuleTranslation stack for use in
2367+
// nested regions.
2368+
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
2369+
moduleTranslation, allocaIP);
2370+
2371+
// translate the body of the task:
2372+
builder.restoreIP(codegenIP);
2373+
2374+
llvm::BasicBlock *privInitBlock = nullptr;
2375+
privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2376+
for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2377+
privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2378+
privateVarsInfo.mlirVars))) {
2379+
auto [blockArg, privDecl, mlirPrivVar] = zip;
2380+
// This is handled before the task executes
2381+
if (privDecl.readsFromMold())
2382+
continue;
2383+
2384+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
2385+
llvm::Type *llvmAllocType =
2386+
moduleTranslation.convertType(privDecl.getType());
2387+
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2388+
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2389+
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2390+
2391+
llvm::Expected<llvm::Value *> privateVarOrError =
2392+
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2393+
blockArg, llvmPrivateVar, privInitBlock);
2394+
if (!privateVarOrError)
2395+
return privateVarOrError.takeError();
2396+
moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2397+
privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2398+
}
2399+
2400+
taskStructMgr.createGEPsToPrivateVars();
2401+
for (auto [i, llvmPrivVar] :
2402+
llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2403+
if (!llvmPrivVar) {
2404+
assert(privateVarsInfo.llvmVars[i] &&
2405+
"This is added in the loop above");
2406+
continue;
2407+
}
2408+
privateVarsInfo.llvmVars[i] = llvmPrivVar;
2409+
}
2410+
2411+
// Find and map the addresses of each variable within the task context
2412+
// structure
2413+
for (auto [blockArg, llvmPrivateVar, privateDecl] :
2414+
llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2415+
privateVarsInfo.privatizers)) {
2416+
// This was handled above.
2417+
if (!privateDecl.readsFromMold())
2418+
continue;
2419+
// Fix broken pass-by-value case for Fortran character boxes
2420+
if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2421+
llvmPrivateVar = builder.CreateLoad(
2422+
moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2423+
}
2424+
assert(llvmPrivateVar->getType() ==
2425+
moduleTranslation.convertType(blockArg.getType()));
2426+
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2427+
}
2428+
2429+
auto continuationBlockOrError =
2430+
convertOmpOpRegions(region, regionName, builder, moduleTranslation);
2431+
if (failed(handleError(continuationBlockOrError, *opInst)))
2432+
return llvm::make_error<PreviouslyReportedError>();
2433+
2434+
builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2435+
2436+
if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst->getLoc(),
2437+
privateVarsInfo.llvmVars,
2438+
privateVarsInfo.privatizers)))
2439+
return llvm::make_error<PreviouslyReportedError>();
2440+
2441+
// Free heap allocated task context structure at the end of the task.
2442+
taskStructMgr.freeStructPtr();
2443+
2444+
return llvm::Error::success();
2445+
};
2446+
}
2447+
23532448
/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
23542449
static LogicalResult
23552450
convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
23562451
LLVM::ModuleTranslation &moduleTranslation) {
2357-
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
23582452
if (failed(checkImplementationStatus(*taskOp)))
23592453
return failure();
23602454

@@ -2467,88 +2561,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
24672561
// Set up for call to createTask()
24682562
builder.SetInsertPoint(taskStartBlock);
24692563

2470-
auto bodyCB = [&](InsertPointTy allocaIP,
2471-
InsertPointTy codegenIP) -> llvm::Error {
2472-
// Save the alloca insertion point on ModuleTranslation stack for use in
2473-
// nested regions.
2474-
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
2475-
moduleTranslation, allocaIP);
2476-
2477-
// translate the body of the task:
2478-
builder.restoreIP(codegenIP);
2479-
2480-
llvm::BasicBlock *privInitBlock = nullptr;
2481-
privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2482-
for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2483-
privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2484-
privateVarsInfo.mlirVars))) {
2485-
auto [blockArg, privDecl, mlirPrivVar] = zip;
2486-
// This is handled before the task executes
2487-
if (privDecl.readsFromMold())
2488-
continue;
2489-
2490-
llvm::IRBuilderBase::InsertPointGuard guard(builder);
2491-
llvm::Type *llvmAllocType =
2492-
moduleTranslation.convertType(privDecl.getType());
2493-
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2494-
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2495-
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2496-
2497-
llvm::Expected<llvm::Value *> privateVarOrError =
2498-
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2499-
blockArg, llvmPrivateVar, privInitBlock);
2500-
if (!privateVarOrError)
2501-
return privateVarOrError.takeError();
2502-
moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2503-
privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2504-
}
2505-
2506-
taskStructMgr.createGEPsToPrivateVars();
2507-
for (auto [i, llvmPrivVar] :
2508-
llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2509-
if (!llvmPrivVar) {
2510-
assert(privateVarsInfo.llvmVars[i] &&
2511-
"This is added in the loop above");
2512-
continue;
2513-
}
2514-
privateVarsInfo.llvmVars[i] = llvmPrivVar;
2515-
}
2516-
2517-
// Find and map the addresses of each variable within the task context
2518-
// structure
2519-
for (auto [blockArg, llvmPrivateVar, privateDecl] :
2520-
llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2521-
privateVarsInfo.privatizers)) {
2522-
// This was handled above.
2523-
if (!privateDecl.readsFromMold())
2524-
continue;
2525-
// Fix broken pass-by-value case for Fortran character boxes
2526-
if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2527-
llvmPrivateVar = builder.CreateLoad(
2528-
moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2529-
}
2530-
assert(llvmPrivateVar->getType() ==
2531-
moduleTranslation.convertType(blockArg.getType()));
2532-
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2533-
}
2534-
2535-
auto continuationBlockOrError = convertOmpOpRegions(
2536-
taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2537-
if (failed(handleError(continuationBlockOrError, *taskOp)))
2538-
return llvm::make_error<PreviouslyReportedError>();
2539-
2540-
builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2541-
2542-
if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2543-
privateVarsInfo.llvmVars,
2544-
privateVarsInfo.privatizers)))
2545-
return llvm::make_error<PreviouslyReportedError>();
2546-
2547-
// Free heap allocated task context structure at the end of the task.
2548-
taskStructMgr.freeStructPtr();
2549-
2550-
return llvm::Error::success();
2551-
};
2564+
auto bodyCB = buildTaskLikeBodyGenCallback(
2565+
taskOp, taskOp.getRegion(), "omp.task.region", builder, moduleTranslation,
2566+
privateVarsInfo, taskStructMgr);
25522567

25532568
llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
25542569
SmallVector<llvm::BranchInst *> cancelTerminators;
@@ -2669,95 +2684,9 @@ convertOmpTaskloopOp(Operation &opInst, llvm::IRBuilderBase &builder,
26692684
// Set up inserttion point for call to createTaskloop()
26702685
builder.SetInsertPoint(taskloopStartBlock);
26712686

2672-
auto bodyCB = [&](InsertPointTy allocaIP,
2673-
InsertPointTy codegenIP) -> llvm::Error {
2674-
// Save the alloca insertion point on ModuleTranslation stack for use in
2675-
// nested regions.
2676-
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
2677-
moduleTranslation, allocaIP);
2678-
2679-
// translate the body of the taskloop:
2680-
builder.restoreIP(codegenIP);
2681-
2682-
llvm::BasicBlock *privInitBlock = nullptr;
2683-
privateVarsInfo.llvmVars.resize(privateVarsInfo.blockArgs.size());
2684-
for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2685-
privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2686-
privateVarsInfo.mlirVars))) {
2687-
auto [blockArg, privDecl, mlirPrivVar] = zip;
2688-
// This is handled before the task executes
2689-
if (privDecl.readsFromMold())
2690-
continue;
2691-
2692-
llvm::IRBuilderBase::InsertPointGuard guard(builder);
2693-
llvm::Type *llvmAllocType =
2694-
moduleTranslation.convertType(privDecl.getType());
2695-
builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2696-
llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2697-
llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2698-
2699-
llvm::Expected<llvm::Value *> privateVarOrError =
2700-
initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2701-
blockArg, llvmPrivateVar, privInitBlock);
2702-
if (!privateVarOrError)
2703-
return privateVarOrError.takeError();
2704-
moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2705-
privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2706-
}
2707-
2708-
taskStructMgr.createGEPsToPrivateVars();
2709-
for (auto [i, llvmPrivVar] :
2710-
llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2711-
if (!llvmPrivVar) {
2712-
assert(privateVarsInfo.llvmVars[i] &&
2713-
"This is added in the loop above");
2714-
continue;
2715-
}
2716-
privateVarsInfo.llvmVars[i] = llvmPrivVar;
2717-
}
2718-
2719-
// Find and map the addresses of each variable within the taskloop context
2720-
// structure
2721-
for (auto [blockArg, llvmPrivateVar, privateDecl] :
2722-
llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2723-
privateVarsInfo.privatizers)) {
2724-
// This was handled above.
2725-
if (!privateDecl.readsFromMold())
2726-
continue;
2727-
// Fix broken pass-by-value case for Fortran character boxes
2728-
if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2729-
llvmPrivateVar = builder.CreateLoad(
2730-
moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2731-
}
2732-
assert(llvmPrivateVar->getType() ==
2733-
moduleTranslation.convertType(blockArg.getType()));
2734-
moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2735-
}
2736-
2737-
auto continuationBlockOrError =
2738-
convertOmpOpRegions(taskloopOp.getRegion(), "omp.taskloop.region",
2739-
builder, moduleTranslation);
2740-
2741-
if (failed(handleError(continuationBlockOrError, opInst)))
2742-
return llvm::make_error<PreviouslyReportedError>();
2743-
2744-
builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2745-
2746-
// This is freeing the private variables as mapped inside of the task: these
2747-
// will be per-task private copies possibly after task duplication. This is
2748-
// handled transparently by how these are passed to the structure passed
2749-
// into the outlined function. When the task is duplicated, that structure
2750-
// is duplicated too.
2751-
if (failed(cleanupPrivateVars(builder, moduleTranslation,
2752-
taskloopOp.getLoc(), privateVarsInfo.llvmVars,
2753-
privateVarsInfo.privatizers)))
2754-
return llvm::make_error<PreviouslyReportedError>();
2755-
// Similarly, the task context structure freed inside the task is the
2756-
// per-task copy after task duplication.
2757-
taskStructMgr.freeStructPtr();
2758-
2759-
return llvm::Error::success();
2760-
};
2687+
auto bodyCB = buildTaskLikeBodyGenCallback(
2688+
&opInst, taskloopOp.getRegion(), "omp.taskloop.region", builder,
2689+
moduleTranslation, privateVarsInfo, taskStructMgr);
27612690

27622691
// Taskloop divides into an appropriate number of tasks by repeatedly
27632692
// duplicating the original task. Each time this is done, the task context

0 commit comments

Comments
 (0)