Skip to content

Commit a304c4f

Browse files
authored
Fix issues in lift-array-alloc (#2570)
* Fix issues in lift-array-alloc * DCO Remediation Commit for Anna Gringauze <agringauze@nvidia.com> I, Anna Gringauze <agringauze@nvidia.com>, hereby add my Signed-off-by to this commit: c1592b8 Signed-off-by: Anna Gringauze <agringauze@nvidia.com> * Addressed CR comments Signed-off-by: Anna Gringauze <agringauze@nvidia.com> * Address CR comments Signed-off-by: Anna Gringauze <agringauze@nvidia.com> * Add new pass Signed-off-by: Anna Gringauze <agringauze@nvidia.com> * Fix null deref Signed-off-by: Anna Gringauze <agringauze@nvidia.com> --------- Signed-off-by: Anna Gringauze <agringauze@nvidia.com>
1 parent eaae1c6 commit a304c4f

File tree

3 files changed

+85
-9
lines changed

3 files changed

+85
-9
lines changed

lib/Optimizer/Transforms/LiftArrayAlloc.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
4040
return failure();
4141

4242
LLVM_DEBUG(llvm::dbgs() << "Candidate was found\n");
43-
auto eleTy = alloc.getElementType();
44-
auto arrTy = cast<cudaq::cc::ArrayType>(eleTy);
43+
auto allocTy = alloc.getElementType();
44+
auto arrTy = cast<cudaq::cc::ArrayType>(allocTy);
45+
auto eleTy = arrTy.getElementType();
46+
4547
SmallVector<Attribute> values;
4648

4749
// Every element of `stores` must be a cc::StoreOp with a ConstantOp as the
@@ -89,6 +91,8 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
8991
cannotEraseAlloc = isLive = true;
9092
} else {
9193
for (auto *useuser : user->getUsers()) {
94+
if (!useuser)
95+
continue;
9296
if (auto load = dyn_cast<cudaq::cc::LoadOp>(useuser)) {
9397
rewriter.setInsertionPointAfter(useuser);
9498
LLVM_DEBUG(llvm::dbgs() << "replaced load\n");
@@ -160,14 +164,13 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
160164
if (!u)
161165
return nullptr;
162166
if (auto store = dyn_cast<cudaq::cc::StoreOp>(u)) {
163-
if (op.getOperation() == store.getPtrvalue().getDefiningOp() &&
164-
isa_and_present<arith::ConstantOp, complex::ConstantOp>(
165-
store.getValue().getDefiningOp())) {
167+
if (op.getOperation() == store.getPtrvalue().getDefiningOp()) {
166168
if (theStore) {
167169
LLVM_DEBUG(llvm::dbgs()
168170
<< "more than 1 store to element of array\n");
169171
return nullptr;
170172
}
173+
LLVM_DEBUG(llvm::dbgs() << "found store: " << store << "\n");
171174
theStore = u;
172175
}
173176
continue;
@@ -182,7 +185,13 @@ class AllocaPattern : public OpRewritePattern<cudaq::cc::AllocaOp> {
182185
}
183186
return nullptr;
184187
}
185-
return theStore;
188+
return theStore &&
189+
isa_and_present<arith::ConstantOp, complex::ConstantOp>(
190+
dyn_cast<cudaq::cc::StoreOp>(theStore)
191+
.getValue()
192+
.getDefiningOp())
193+
? theStore
194+
: nullptr;
186195
};
187196

188197
auto unsizedArrTy = cudaq::cc::ArrayType::get(arrEleTy);

targettests/execution/state_preparation_vector_sizes.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
#include <cudaq.h>
2222
#include <iostream>
2323

24-
#include <cudaq.h>
25-
#include <iostream>
26-
2724
__qpu__ void test(std::vector<cudaq::complex> inState) {
2825
cudaq::qvector q1 = inState;
2926
}

test/Quake/lift_array.qke

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,73 @@ func.func @test2() -> !quake.veq<2> {
125125
// GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_test_complex_constant_array._Z27test_complex_constant_arrayv.rodata_{{[0-9]+}} (dense<[(0.707106769,0.000000e+00), (0.707106769,0.000000e+00), (0.000000e+00,0.000000e+00), (0.000000e+00,0.000000e+00)]> : tensor<4xcomplex<f32>>) : !cc.array<complex<f32> x 4>
126126
// GLOBAL-DAG: cc.global constant private @__nvqpp__mlirgen__function_custom_h_generator_1._Z20custom_h_generator_1v.rodata_{{[0-9]+}} (dense<[(0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (0.70710678118654757,0.000000e+00), (-0.70710678118654757,0.000000e+00)]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
127127
// GLOBAL-DAG: cc.global constant private @test2.rodata_{{[0-9]+}} (dense<[1.000000e+00, 2.000000e+00, 6.000000e+00, 9.000000e+00]>" : tensor<4xf64>) : !cc.array<f64 x 4>
128+
129+
// Test input with two one-element i64 arrays. The first array (%1) receives a
// single constant store and is then loaded from. The second array (%4) has its
// only element stored to twice: first with a constant, then with the value
// loaded from the first array. The loaded result of the second array is used
// as the index into the qubit vector.
func.func @test_two_stores() {
130+
%c0_i64 = arith.constant 0 : i64
131+
%c1_i64 = arith.constant 1 : i64
132+
133+
// qubits = cudaq.qvector(2)
134+
%0 = quake.alloca !quake.veq<2>
135+
136+
// arr1 = [1]
137+
%1 = cc.alloca !cc.array<i64 x 1>
138+
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
139+
cc.store %c1_i64, %2 : !cc.ptr<i64>
140+
141+
// t = arr1[0]
142+
%3 = cc.load %2 : !cc.ptr<i64>
143+
144+
// arr2 = [0]
145+
%4 = cc.alloca !cc.array<i64 x 1>
146+
%5 = cc.cast %4 : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
147+
cc.store %c0_i64, %5 : !cc.ptr<i64> // Dominates the next store, don't lift
148+
149+
// arr2[0] = t
150+
cc.store %3, %5 : !cc.ptr<i64>
151+
152+
// b = arr2[0]
153+
%6 = cc.load %5 : !cc.ptr<i64>
154+
155+
// x(qubits[b])
156+
%7 = quake.extract_ref %0[%6] : (!quake.veq<2>, i64) -> !quake.ref
157+
quake.x %7 : (!quake.ref) -> ()
158+
return
159+
}
160+
161+
// CHECK-LABEL: func.func @test_two_stores() {
162+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i64
163+
// CHECK: %[[VAL_1:.*]] = quake.alloca !quake.veq<2>
164+
// CHECK: %[[VAL_2:.*]] = cc.const_array [1] : !cc.array<i64 x 1>
165+
// CHECK: %[[VAL_3:.*]] = cc.extract_value %[[VAL_2]][0] : (!cc.array<i64 x 1>) -> i64
166+
// CHECK: %[[VAL_4:.*]] = cc.alloca !cc.array<i64 x 1>
167+
// CHECK: %[[VAL_5:.*]] = cc.cast %[[VAL_4]] : (!cc.ptr<!cc.array<i64 x 1>>) -> !cc.ptr<i64>
168+
// CHECK: cc.store %[[VAL_0]], %[[VAL_5]] : !cc.ptr<i64>
169+
// CHECK: cc.store %[[VAL_3]], %[[VAL_5]] : !cc.ptr<i64>
170+
// CHECK: %[[VAL_6:.*]] = cc.load %[[VAL_5]] : !cc.ptr<i64>
171+
// CHECK: %[[VAL_7:.*]] = quake.extract_ref %[[VAL_1]][%[[VAL_6]]] : (!quake.veq<2>, i64) -> !quake.ref
172+
// CHECK: quake.x %[[VAL_7]] : (!quake.ref) -> ()
173+
// CHECK: return
174+
// CHECK: }
175+
176+
// Test input with a two-element complex<f32> array. Element 0 is addressed
// through a cc.cast of the alloca (%1) and element 1 through cc.compute_ptr
// (%2); each receives one complex constant store. The element-0 pointer is
// then passed to quake.init_state as the state data for a one-qubit veq.
func.func @test_complex_array() {
177+
%cst = complex.constant [0.000000e+00 : f32, 1.000000e+00 : f32] : complex<f32>
178+
%cst_0 = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
179+
%0 = cc.alloca !cc.array<complex<f32> x 2>
180+
%1 = cc.cast %0 : (!cc.ptr<!cc.array<complex<f32> x 2>>) -> !cc.ptr<complex<f32>>
181+
cc.store %cst_0, %1 : !cc.ptr<complex<f32>>
182+
%2 = cc.compute_ptr %0[1] : (!cc.ptr<!cc.array<complex<f32> x 2>>) -> !cc.ptr<complex<f32>>
183+
cc.store %cst, %2 : !cc.ptr<complex<f32>>
184+
%3 = quake.alloca !quake.veq<1>
185+
%4 = quake.init_state %3, %1 : (!quake.veq<1>, !cc.ptr<complex<f32>>) -> !quake.veq<1>
186+
return
187+
}
188+
189+
// CHECK-LABEL: func.func @test_complex_array() {
190+
// CHECK: %[[VAL_0:.*]] = cc.const_array {{\[}}[1.000000e+00 : f32, 0.000000e+00 : f32], [0.000000e+00 : f32, 1.000000e+00 : f32]{{\]}} : !cc.array<complex<f32> x 2>
191+
// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array<complex<f32> x 2>
192+
// CHECK: cc.store %[[VAL_0]], %[[VAL_1]] : !cc.ptr<!cc.array<complex<f32> x 2>>
193+
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<complex<f32> x 2>>) -> !cc.ptr<complex<f32>>
194+
// CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<1>
195+
// CHECK: %[[VAL_4:.*]] = quake.init_state %[[VAL_3]], %[[VAL_2]] : (!quake.veq<1>, !cc.ptr<complex<f32>>) -> !quake.veq<1>
196+
// CHECK: return
197+
// CHECK: }

0 commit comments

Comments
 (0)