Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions lib/Optimizer/Transforms/WriteAfterWriteElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,11 @@ class SimplifyWritesAnalysis {
for (const auto &[ptr, stores] : ptrToStores) {
if (stores.size() > 1) {
auto replacement = stores.back();
for (auto it = stores.rend(); it != stores.rbegin(); it++) {
auto store = *it;
if (isReplacement(ptr, *store, *replacement)) {
LLVM_DEBUG(llvm::dbgs() << "replacing store " << store
<< " by: " << replacement << '\n');
toErase.push_back(store->getOperation());
for (auto *store : stores) {
if (isReplacement(ptr, store, replacement)) {
LLVM_DEBUG(llvm::dbgs() << "replacing store " << *store
<< " by: " << *replacement << '\n');
toErase.push_back(store);
}
}
}
Expand All @@ -72,13 +71,18 @@ class SimplifyWritesAnalysis {

private:
/// Detect whether `store` can be replaced by `replacement`, i.e. no
/// conflicting use of the pointer lies between the two stores.
bool isReplacement(Value ptr, cudaq::cc::StoreOp store,
cudaq::cc::StoreOp replacement) const {
// Check that there are no stores dominated by the store and not dominated
// by the replacement (i.e. used in between the store and the replacement)
for (auto *user : ptr.getUsers()) {
bool isReplacement(Operation *ptr, Operation *store,
Operation *replacement) const {
if (store == replacement)
return false;

// Check that there are no non-store uses dominated by the store and
// not dominated by the replacement, i.e. only uses between the two
// stores are other stores to the same pointer.
for (auto *user : ptr->getUsers()) {
if (user != store && user != replacement) {
if (dom.dominates(store, user) && !dom.dominates(replacement, user)) {
if (!isStoreToPtr(user, ptr) && dom.dominates(store, user) &&
!dom.dominates(replacement, user)) {
LLVM_DEBUG(llvm::dbgs() << "store " << replacement
<< " is used before: " << store << '\n');
return false;
Expand All @@ -88,6 +92,13 @@ class SimplifyWritesAnalysis {
return true;
}

/// Detects a store to the pointer.
///
/// \param op   Candidate operation; a null pointer is tolerated.
/// \param ptr  Operation that defines the pointer value of interest.
/// \returns true iff `op` is a `cudaq::cc::StoreOp` whose pointer operand is
///          defined by `ptr`.
static bool isStoreToPtr(Operation *op, Operation *ptr) {
  // Guard against null explicitly so a plain dyn_cast is safe below.
  if (!op)
    return false;
  // One dyn_cast both checks the operation kind and yields the typed op,
  // avoiding a separate isa<> test followed by a second cast.
  auto store = dyn_cast<cudaq::cc::StoreOp>(op);
  return store && store.getPtrvalue().getDefiningOp() == ptr;
}

/// Collect all stores to a pointer for a block.
void collectBlockInfo(Block *block) {
for (auto &op : *block) {
Expand All @@ -96,11 +107,11 @@ class SimplifyWritesAnalysis {
collectBlockInfo(&b);

if (auto store = dyn_cast<cudaq::cc::StoreOp>(&op)) {
auto ptr = store.getPtrvalue();
auto ptr = store.getPtrvalue().getDefiningOp();
if (isStoreToStack(store)) {
auto ptrToStores = blockInfo.FindAndConstruct(block).second;
auto stores = ptrToStores.FindAndConstruct(ptr).second;
stores.push_back(&store);
auto &[b, ptrToStores] = blockInfo.FindAndConstruct(block);
auto &[p, stores] = ptrToStores.FindAndConstruct(ptr);
stores.push_back(&op);
}
}
}
Expand Down Expand Up @@ -128,8 +139,7 @@ class SimplifyWritesAnalysis {
}

DominanceInfo &dom;
DenseMap<Block *, DenseMap<Value, SmallVector<cudaq::cc::StoreOp *>>>
blockInfo;
DenseMap<Block *, DenseMap<Operation *, SmallVector<Operation *>>> blockInfo;
};

class WriteAfterWriteEliminationPass
Expand Down
22 changes: 13 additions & 9 deletions test/Quake/write_after_write_elimination.qke
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

// RUN: cudaq-opt -write-after-write-elimination %s | FileCheck %s


func.func @test_two_stores_same_pointer() {
%c0_i64 = arith.constant 0 : i64
%0 = quake.alloca !quake.veq<2>
Expand Down Expand Up @@ -63,23 +64,26 @@ func.func @test_two_stores_different_pointers() {
func.func @test_two_stores_same_pointer_interleaving() {
%c0_i64 = arith.constant 0 : i64
%c1_i64 = arith.constant 1 : i64
%1 = cc.alloca !cc.array<i64 x 2>
%2 = cc.cast %1 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
%c2_i64 = arith.constant 2 : i64
%0 = cc.alloca !cc.array<i64 x 2>
%1 = cc.cast %0 : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %1 : !cc.ptr<i64>
%2 = cc.compute_ptr %0[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c2_i64, %1 : !cc.ptr<i64>
cc.store %c0_i64, %2 : !cc.ptr<i64>
%3 = cc.compute_ptr %1[1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
cc.store %c0_i64, %3 : !cc.ptr<i64>
cc.store %c1_i64, %1 : !cc.ptr<i64>
cc.store %c1_i64, %2 : !cc.ptr<i64>
cc.store %c1_i64, %3 : !cc.ptr<i64>
return
}

// CHECK-LABEL: func.func @test_two_stores_same_pointer_interleaving() {
// CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_c0:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_c1:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_c2:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_1:.*]] = cc.alloca !cc.array<i64 x 2>
// CHECK: %[[VAL_2:.*]] = cc.cast %[[VAL_1]] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: %[[VAL_3:.*]] = cc.compute_ptr %[[VAL_1]][1] : (!cc.ptr<!cc.array<i64 x 2>>) -> !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_0]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_c1]], %[[VAL_2]] : !cc.ptr<i64>
// CHECK: cc.store %[[VAL_c1]], %[[VAL_3]] : !cc.ptr<i64>
// CHECK: return
// CHECK: }

Loading