@@ -196,26 +196,38 @@ mlir::Value replaceWithAddrOfOrASCast(mlir::ConversionPatternRewriter &rewriter,
196196 return mlir::LLVM::AddressOfOp::create (rewriter, loc, type, symName);
197197}
198198
199- static std::uint64_t getAddressSpace (fir::AddrOfOp addr,
200- mlir::ConversionPatternRewriter &rewriter,
201- std::uint64_t defaultAS) {
202- auto global = addr->getParentOfType <mlir::ModuleOp>()
203- .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
204- if (global)
205- return global.getAddrSpace ();
206- auto firGlobal =
207- addr->getParentOfType <mlir::ModuleOp>().lookupSymbol <fir::GlobalOp>(
208- addr.getSymbol ());
209- if (firGlobal && firGlobal.getDataAttr () &&
210- *firGlobal.getDataAttr () == cuf::DataAttribute::Constant)
211- return static_cast <unsigned >(mlir::NVVM::NVVMMemorySpace::Constant);
212- return defaultAS;
199+ // / Return the NVVM address space implied by a CUF data attribute on a
200+ // / fir::GlobalOp that has not yet been converted to llvm.mlir.global.
201+ // / Returns std::nullopt if no CUF-specific address space applies.
202+ static std::optional<unsigned > getCUFAddrSpace (fir::GlobalOp global) {
203+ if (auto dataAttr = global.getDataAttr ()) {
204+ if (*dataAttr == cuf::DataAttribute::Constant)
205+ return static_cast <unsigned >(mlir::NVVM::NVVMMemorySpace::Constant);
206+ if (*dataAttr == cuf::DataAttribute::Shared)
207+ return static_cast <unsigned >(mlir::NVVM::NVVMMemorySpace::Shared);
208+ if (*dataAttr == cuf::DataAttribute::Managed)
209+ return static_cast <unsigned >(mlir::NVVM::NVVMMemorySpace::Global);
210+ }
211+ return std::nullopt ;
213212}
214213
215214// / Lower `fir.address_of` operation to `llvm.address_of` operation.
216215struct AddrOfOpConversion : public fir ::FIROpConversion<fir::AddrOfOp> {
217216 using FIROpConversion::FIROpConversion;
218217
218+ // / Look up the address space for a symbol in \p mod, handling both
219+ // / already-converted llvm.mlir.global and not-yet-converted fir.global.
220+ template <typename ModOp>
221+ unsigned getAddrSpaceForGlobal (ModOp mod, mlir::SymbolRefAttr sym,
222+ unsigned fallback) const {
223+ if (auto g = mod.template lookupSymbol <mlir::LLVM::GlobalOp>(sym))
224+ return g.getAddrSpace ();
225+ if (auto g = mod.template lookupSymbol <fir::GlobalOp>(sym))
226+ if (auto as = getCUFAddrSpace (g))
227+ return *as;
228+ return fallback;
229+ }
230+
219231 llvm::LogicalResult
220232 matchAndRewrite (fir::AddrOfOp addr, OpAdaptor adaptor,
221233 mlir::ConversionPatternRewriter &rewriter) const override {
@@ -224,20 +236,22 @@ struct AddrOfOpConversion : public fir::FIROpConversion<fir::AddrOfOp> {
224236 auto global = gpuMod.lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
225237 replaceWithAddrOfOrASCast (
226238 rewriter, addr->getLoc (),
227- global ? global.getAddrSpace () : getGlobalAddressSpace (rewriter),
239+ getAddrSpaceForGlobal (gpuMod, addr.getSymbol (),
240+ getGlobalAddressSpace (rewriter)),
228241 getProgramAddressSpace (rewriter),
229242 global ? global.getSymName ()
230243 : addr.getSymbol ().getRootReference ().getValue (),
231244 convertType (addr.getType ()), addr);
232245 return mlir::success ();
233246 }
234247
235- std::uint64_t globalAS =
236- getAddressSpace (addr, rewriter, getGlobalAddressSpace (rewriter));
237- auto global = addr->getParentOfType <mlir::ModuleOp>()
238- .lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
248+ auto mod = addr->getParentOfType <mlir::ModuleOp>();
249+ auto global = mod.lookupSymbol <mlir::LLVM::GlobalOp>(addr.getSymbol ());
239250 replaceWithAddrOfOrASCast (
240- rewriter, addr->getLoc (), globalAS, getProgramAddressSpace (rewriter),
251+ rewriter, addr->getLoc (),
252+ getAddrSpaceForGlobal (mod, addr.getSymbol (),
253+ getGlobalAddressSpace (rewriter)),
254+ getProgramAddressSpace (rewriter),
241255 global ? global.getSymName ()
242256 : addr.getSymbol ().getRootReference ().getValue (),
243257 convertType (addr.getType ()), addr);
0 commit comments