Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/CodeGen_PTX_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,16 @@ void CodeGen_PTX_Dev::visit(const Call *op) {
auto fence_type_ptr = as_const_int(op->args[0]);
internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n";

llvm::Function *barrier0 = module->getFunction("llvm.nvvm.barrier0");
internal_assert(barrier0) << "Could not find PTX barrier intrinsic (llvm.nvvm.barrier0)\n";
builder->CreateCall(barrier0);
llvm::Function *barrier;
if ((barrier = module->getFunction("llvm.nvvm.barrier.cta.sync.aligned.all")) && barrier->getIntrinsicID() != 0) {
// LLVM 20.1.6 and above: https://github.com/llvm/llvm-project/pull/140615
builder->CreateCall(barrier, builder->getInt32(0));
} else if ((barrier = module->getFunction("llvm.nvvm.barrier0")) && barrier->getIntrinsicID() != 0) {
// LLVM 21.1.5 and below: Testing for llvm.nvvm.barrier0 can be removed once we drop support for LLVM 20
builder->CreateCall(barrier);
} else {
internal_error << "Could not find PTX barrier intrinsic llvm.nvvm.barrier0 nor llvm.nvvm.barrier.cta.sync.aligned.all\n";
}
value = ConstantInt::get(i32_t, 0);
return;
}
Expand Down
9 changes: 8 additions & 1 deletion src/runtime/ptx_dev.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
declare void @llvm.nvvm.barrier0()
; The two forward declared intrinsics below refer to the same thing.
; LLVM 20.1.6 introduced a new naming scheme for these intrinsics
; We have to declare both, such that we can access them from the Module's
; getFunction(), but one of those will map to an intrinsic, which we
; will use to determine which intrinsic is supported by LLVM.
declare void @llvm.nvvm.barrier0() ; LLVM <=20.1.5
declare void @llvm.nvvm.barrier.cta.sync.aligned.all(i32) ; LLVM >=20.1.6

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
Expand Down
Loading