Skip to content

Commit f09850a

Browse files
authored
[flang][CodeGen] Fix address space mismatch for CUF globals in AddrOfOpConversion (llvm#190408)
`AddrOfOpConversion` in CodeGen.cpp consulted only the already-converted `LLVM::GlobalOp` when determining the address space for `llvm.mlir.addressof`. When the global was still a `fir::GlobalOp` (not yet converted), it fell back to address space 0, breaking CUF constant globals (addr_space 4) and AMDGPU targets (global addr_space 1). This extends the upstream fix (llvm#192111, which only covered the Constant data attribute) to also handle the Shared and Managed CUF data attributes. The new `getCUFAddrSpace` helper returns `std::nullopt` instead of 0 for non-CUF globals, so the caller falls back to the target's default global address space and that default is preserved.
1 parent 8d0997f commit f09850a

File tree

2 files changed

+77
-21
lines changed

2 files changed

+77
-21
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 34 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -196,26 +196,38 @@ mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
196196
return mlir::LLVM::AddressOfOp::create(rewriter, loc, type, symName);
197197
}
198198

199-
static std::uint64_t getAddressSpace(fir::AddrOfOp addr,
200-
mlir::ConversionPatternRewriter &rewriter,
201-
std::uint64_t defaultAS) {
202-
auto global = addr->getParentOfType<mlir::ModuleOp>()
203-
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
204-
if (global)
205-
return global.getAddrSpace();
206-
auto firGlobal =
207-
addr->getParentOfType<mlir::ModuleOp>().lookupSymbol<fir::GlobalOp>(
208-
addr.getSymbol());
209-
if (firGlobal && firGlobal.getDataAttr() &&
210-
*firGlobal.getDataAttr() == cuf::DataAttribute::Constant)
211-
return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant);
212-
return defaultAS;
199+
/// Return the NVVM address space implied by a CUF data attribute on a
200+
/// fir::GlobalOp that has not yet been converted to llvm.mlir.global.
201+
/// Returns std::nullopt if no CUF-specific address space applies.
202+
static std::optional<unsigned> getCUFAddrSpace(fir::GlobalOp global) {
203+
if (auto dataAttr = global.getDataAttr()) {
204+
if (*dataAttr == cuf::DataAttribute::Constant)
205+
return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Constant);
206+
if (*dataAttr == cuf::DataAttribute::Shared)
207+
return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Shared);
208+
if (*dataAttr == cuf::DataAttribute::Managed)
209+
return static_cast<unsigned>(mlir::NVVM::NVVMMemorySpace::Global);
210+
}
211+
return std::nullopt;
213212
}
214213

215214
/// Lower `fir.address_of` operation to `llvm.address_of` operation.
216215
struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
217216
using FIROpConversion::FIROpConversion;
218217

218+
/// Look up the address space for a symbol in \p mod, handling both
219+
/// already-converted llvm.mlir.global and not-yet-converted fir.global.
220+
template <typename ModOp>
221+
unsigned getAddrSpaceForGlobal(ModOp mod, mlir::SymbolRefAttr sym,
222+
unsigned fallback) const {
223+
if (auto g = mod.template lookupSymbol<mlir::LLVM::GlobalOp>(sym))
224+
return g.getAddrSpace();
225+
if (auto g = mod.template lookupSymbol<fir::GlobalOp>(sym))
226+
if (auto as = getCUFAddrSpace(g))
227+
return *as;
228+
return fallback;
229+
}
230+
219231
llvm::LogicalResult
220232
matchAndRewrite(fir::AddrOfOp addr, OpAdaptor adaptor,
221233
mlir::ConversionPatternRewriter &rewriter) const override {
@@ -224,20 +236,22 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
224236
auto global = gpuMod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
225237
replaceWithAddrOfOrASCast(
226238
rewriter, addr->getLoc(),
227-
global ? global.getAddrSpace() : getGlobalAddressSpace(rewriter),
239+
getAddrSpaceForGlobal(gpuMod, addr.getSymbol(),
240+
getGlobalAddressSpace(rewriter)),
228241
getProgramAddressSpace(rewriter),
229242
global ? global.getSymName()
230243
: addr.getSymbol().getRootReference().getValue(),
231244
convertType(addr.getType()), addr);
232245
return mlir::success();
233246
}
234247

235-
std::uint64_t globalAS =
236-
getAddressSpace(addr, rewriter, getGlobalAddressSpace(rewriter));
237-
auto global = addr->getParentOfType<mlir::ModuleOp>()
238-
.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
248+
auto mod = addr->getParentOfType<mlir::ModuleOp>();
249+
auto global = mod.lookupSymbol<mlir::LLVM::GlobalOp>(addr.getSymbol());
239250
replaceWithAddrOfOrASCast(
240-
rewriter, addr->getLoc(), globalAS, getProgramAddressSpace(rewriter),
251+
rewriter, addr->getLoc(),
252+
getAddrSpaceForGlobal(mod, addr.getSymbol(),
253+
getGlobalAddressSpace(rewriter)),
254+
getProgramAddressSpace(rewriter),
241255
global ? global.getSymName()
242256
: addr.getSymbol().getRootReference().getValue(),
243257
convertType(addr.getType()), addr);

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 43 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -349,7 +349,7 @@ func.func @sub16_(%arg0: !fir.ref<f32> {fir.bindc_name = "h16"}) attributes {fir
349349
return
350350
}
351351
fir.global linkonce @_QQclXc8657e47c19bb9e89730387c9d99c2da constant : !fir.char<1,38> {
352-
%0 = fir.string_lit "/local/home/vclement/lorado/dummy.cuf\00"(38) : !fir.char<1,38>
352+
%0 = fir.string_lit "/path/to/source/test/dummy_module.cuf\00"(38) : !fir.char<1,38>
353353
fir.has_value %0 : !fir.char<1,38>
354354
}
355355
fir.global @_QMdevice_dataEd16 {data_attr = #cuf.cuda<constant>} : f32 {
@@ -360,3 +360,45 @@ func.func private @_FortranACUFGetDeviceAddress(!fir.llvm_ptr<i8>, !fir.ref<i8>,
360360

361361
// CHECK-LABEL: llvm.func @sub16_
362362
// CHECK: llvm.mlir.addressof @_QMdevice_dataEd16 : !llvm.ptr<4>
363+
364+
// -----
365+
366+
// Test that a host-side fir.address_of referencing a fir.global with CUF
367+
// shared data_attr produces an addrspacecast from ptr<3> to ptr.
368+
369+
fir.global @_QMmodEsval {data_attr = #cuf.cuda<shared>} : i32 {
370+
%0 = fir.zero_bits i32
371+
fir.has_value %0 : i32
372+
}
373+
func.func @_QQhost_shared() {
374+
%0 = fir.address_of(@_QMmodEsval) : !fir.ref<i32>
375+
return
376+
}
377+
378+
// CHECK: llvm.mlir.global external @_QMmodEsval() {addr_space = 3 : i32} : i32
379+
// CHECK-LABEL: llvm.func @_QQhost_shared()
380+
// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @_QMmodEsval : !llvm.ptr<3>
381+
// CHECK: %{{.*}} = llvm.addrspacecast %[[ADDR]] : !llvm.ptr<3> to !llvm.ptr
382+
383+
// -----
384+
385+
// Test that fir.address_of inside gpu.module referencing a managed fir.global
386+
// produces an addressof with ptr<1> and an addrspacecast from ptr<1> to ptr.
387+
388+
module attributes {gpu.container_module} {
389+
gpu.module @cuda_device_mod {
390+
fir.global @_QMmodEmval {data_attr = #cuf.cuda<managed>} : i32 {
391+
%0 = fir.zero_bits i32
392+
fir.has_value %0 : i32
393+
}
394+
gpu.func @_QMkernelsPuse_managed() kernel {
395+
%0 = fir.address_of(@_QMmodEmval) : !fir.ref<i32>
396+
gpu.return
397+
}
398+
}
399+
}
400+
401+
// CHECK: llvm.mlir.global external @_QMmodEmval() {addr_space = 1 : i32, nvvm.managed} : i32
402+
// CHECK-LABEL: gpu.func @_QMkernelsPuse_managed()
403+
// CHECK: %[[ADDR:.*]] = llvm.mlir.addressof @_QMmodEmval : !llvm.ptr<1>
404+
// CHECK: %{{.*}} = llvm.addrspacecast %[[ADDR]] : !llvm.ptr<1> to !llvm.ptr

0 commit comments

Comments
 (0)