@@ -87,6 +87,23 @@ convertArrayAttrToGlobalConstant(MLIRContext *ctx, Location loc,
8787}
8888
8989namespace {
90+
91+ // This pattern replaces a cc.const_array with a global constant. It can
92+ // recognize a couple of usage patterns and will generate efficient IR in those
93+ // cases.
94+ //
95+ // Pattern 1: The entire constant array is stored to a stack variable(s). Here
96+ // we can eliminate the stack allocation and use the global constant.
97+ //
98+ // Pattern 2: Individual elements at dynamic offsets are extracted from the
99+ // constant array and used. This can be replaced with a compute pointer
100+ // operation using the global constant and a load of the element at the computed
101+ // offset.
102+ //
103+ // Default: If the usage is not recognized, the constant array value is replaced
104+ // with a load of the entire global variable. In this case, LLVM's optimizations
105+ // are counted on to help demote the (large?) sequence value to primitive memory
106+ // address arithmetic.
90107struct ConstantArrayPattern
91108 : public OpRewritePattern<cudaq::cc::ConstantArrayOp> {
92109 explicit ConstantArrayPattern (MLIRContext *ctx, ModuleOp module ,
@@ -95,21 +112,30 @@ struct ConstantArrayPattern
95112
96113 LogicalResult matchAndRewrite (cudaq::cc::ConstantArrayOp conarr,
97114 PatternRewriter &rewriter) const override {
115+ auto func = conarr->getParentOfType <func::FuncOp>();
116+ if (!func)
117+ return failure ();
118+
98119 SmallVector<cudaq::cc::AllocaOp> allocas;
99120 SmallVector<cudaq::cc::StoreOp> stores;
121+ SmallVector<cudaq::cc::ExtractValueOp> extracts;
122+ bool loadAsValue = false ;
100123 for (auto *usr : conarr->getUsers ()) {
101124 auto store = dyn_cast<cudaq::cc::StoreOp>(usr);
102- if (!store)
103- return failure ();
104- auto alloca = store.getPtrvalue ().getDefiningOp <cudaq::cc::AllocaOp>();
105- if (!alloca)
106- return failure ();
107- stores.push_back (store);
108- allocas.push_back (alloca);
125+ auto extract = dyn_cast<cudaq::cc::ExtractValueOp>(usr);
126+ if (store) {
127+ auto alloca = store.getPtrvalue ().getDefiningOp <cudaq::cc::AllocaOp>();
128+ if (alloca) {
129+ stores.push_back (store);
130+ allocas.push_back (alloca);
131+ continue ;
132+ }
133+ } else if (extract) {
134+ extracts.push_back (extract);
135+ continue ;
136+ }
137+ loadAsValue = true ;
109138 }
110- auto func = conarr->getParentOfType <func::FuncOp>();
111- if (!func)
112- return failure ();
113139 std::string globalName =
114140 func.getName ().str () + " .rodata_" + std::to_string (counter++);
115141 auto *ctx = rewriter.getContext ();
@@ -118,12 +144,39 @@ struct ConstantArrayPattern
118144 if (failed (convertArrayAttrToGlobalConstant (ctx, conarr.getLoc (), valueAttr,
119145 module , globalName, eleTy)))
120146 return failure ();
121- for (auto alloca : allocas)
122- rewriter.replaceOpWithNewOp <cudaq::cc::AddressOfOp>(
123- alloca, alloca.getType (), globalName);
124- for (auto store : stores)
125- rewriter.eraseOp (store);
126- rewriter.eraseOp (conarr);
147+ auto loc = conarr.getLoc ();
148+ if (!extracts.empty ()) {
149+ auto base = rewriter.create <cudaq::cc::AddressOfOp>(
150+ loc, cudaq::cc::PointerType::get (conarr.getType ()), globalName);
151+ auto elePtrTy = cudaq::cc::PointerType::get (eleTy);
152+ for (auto extract : extracts) {
153+ SmallVector<cudaq::cc::ComputePtrArg> args;
154+ unsigned i = 0 ;
155+ for (auto arg : extract.getRawConstantIndices ()) {
156+ if (arg == cudaq::cc::ExtractValueOp::getDynamicIndexValue ())
157+ args.push_back (extract.getDynamicIndices ()[i++]);
158+ else
159+ args.push_back (arg);
160+ }
161+ OpBuilder::InsertionGuard guard (rewriter);
162+ rewriter.setInsertionPoint (extract);
163+ auto addrVal =
164+ rewriter.create <cudaq::cc::ComputePtrOp>(loc, elePtrTy, base, args);
165+ rewriter.replaceOpWithNewOp <cudaq::cc::LoadOp>(extract, addrVal);
166+ }
167+ }
168+ if (!stores.empty ()) {
169+ for (auto alloca : allocas)
170+ rewriter.replaceOpWithNewOp <cudaq::cc::AddressOfOp>(
171+ alloca, alloca.getType (), globalName);
172+ for (auto store : stores)
173+ rewriter.eraseOp (store);
174+ }
175+ if (loadAsValue) {
176+ auto base = rewriter.create <cudaq::cc::AddressOfOp>(
177+ loc, cudaq::cc::PointerType::get (conarr.getType ()), globalName);
178+ rewriter.replaceOpWithNewOp <cudaq::cc::LoadOp>(conarr, base);
179+ }
127180 return success ();
128181 }
129182
0 commit comments