forked from NVIDIA/cuda-quantum
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path LiftArrayAlloc.cpp
More file actions
56 lines (46 loc) · 1.97 KB
/
LiftArrayAlloc.cpp
File metadata and controls
56 lines (46 loc) · 1.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/*******************************************************************************
* Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/
#include "PassDetails.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
namespace cudaq::opt {
#define GEN_PASS_DEF_LIFTARRAYALLOC
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt
#define DEBUG_TYPE "lift-array-alloc"
using namespace mlir;
#include "LiftArrayAllocPatterns.inc"
namespace {
/// Pass that runs `AllocaPattern` (brought in by LiftArrayAllocPatterns.inc)
/// over a single operation using MLIR's greedy pattern rewrite driver.
class LiftArrayAllocPass
    : public cudaq::opt::impl::LiftArrayAllocBase<LiftArrayAllocPass> {
public:
  // Reuse the constructors of the tablegen-generated base class.
  using LiftArrayAllocBase::LiftArrayAllocBase;

  /// Builds the pattern set and applies it greedily (with folding) to the
  /// current operation. If the greedy driver reports failure, the pass is
  /// marked as failed.
  void runOnOperation() override {
    auto funcOp = getOperation();
    auto *context = &getContext();

    // Dominance is computed once, up front, before any rewriting happens; the
    // same object is handed to the pattern along with the function's name.
    DominanceInfo dominance(funcOp);
    StringRef name = funcOp.getName();

    RewritePatternSet rewriteSet(context);
    rewriteSet.insert<AllocaPattern>(context, dominance, name);

    LLVM_DEBUG(llvm::dbgs()
               << "Before lifting constant array: " << funcOp << '\n');

    const bool rewriteFailed =
        failed(applyPatternsAndFoldGreedily(funcOp, std::move(rewriteSet)));
    if (rewriteFailed)
      signalPassFailure();

    LLVM_DEBUG(llvm::dbgs()
               << "After lifting constant array: " << funcOp << '\n');
  }
};
} // namespace