Skip to content

[CINN]Add remove_redundant_full_ops_pass #72554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
5 changes: 2 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/reduce_as_to_sum_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_assign_out_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_full_int_array_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_full_ops_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_group_output_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_zero_scale_to_full_pass.h"
Expand Down Expand Up @@ -286,8 +286,7 @@ void ApplyCinnLowerPass(
}
pass_manager->AddPass(
cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass());
pass_manager->AddPass(
cinn::dialect::ir::CreateRemoveRedundantFullIntArrayPass());
pass_manager->AddPass(cinn::dialect::ir::CreateRemoveRedundantFullOpsPass());
pass_manager->Run(program);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_full_int_array_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_redundant_full_ops_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand All @@ -23,23 +23,23 @@ namespace cinn {
namespace dialect {
namespace ir {

class RemoveRedundantFullIntArrayPattern
: public pir::OpRewritePattern<paddle::dialect::FullIntArrayOp> {
template <class OPTYPE>
class RemoveRedundantFullOpsPattern : public pir::OpRewritePattern<OPTYPE> {
public:
using pir::OpRewritePattern<
paddle::dialect::FullIntArrayOp>::OpRewritePattern;
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::FullIntArrayOp op,
bool MatchAndRewrite(OPTYPE op,
pir::PatternRewriter& rewriter) const override {
auto block = op->GetParent();
auto* block = op->GetParent();
if (!block) return false;

pir::AttributeMap attrs = op->attributes();
auto dtype = attrs.at("dtype");
auto value = attrs.at("value");
auto place = attrs.at("place");

for (auto other_op = block->begin(); other_op != block->end(); ++other_op) {
if (!other_op->isa<paddle::dialect::FullIntArrayOp>()) continue;
if (!other_op->template isa<OPTYPE>()) continue;
if (op.operation() == &(*other_op)) break;
pir::AttributeMap other_attrs = other_op->attributes();
if (dtype != other_attrs.at("dtype") || place != other_attrs.at("place"))
Expand All @@ -56,14 +56,18 @@ class RemoveRedundantFullIntArrayPattern
}
};

class RemoveRedundantFullIntArrayPass : public pir::PatternRewritePass {
class RemoveRedundantFullOpsPass : public pir::PatternRewritePass {
public:
RemoveRedundantFullIntArrayPass()
: pir::PatternRewritePass("remove_redundant_full_int_array_pass", 1) {}
RemoveRedundantFullOpsPass()
: pir::PatternRewritePass("remove_redundant_full_ops_pass", 1) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
pir::RewritePatternSet ps(context);
ps.Add<RemoveRedundantFullIntArrayPattern>(context);

ps.Add<RemoveRedundantFullOpsPattern<paddle::dialect::FullOp>>(context);
ps.Add<RemoveRedundantFullOpsPattern<paddle::dialect::FullIntArrayOp>>(
context);

return ps;
}

Expand All @@ -72,8 +76,8 @@ class RemoveRedundantFullIntArrayPass : public pir::PatternRewritePass {
}
};

std::unique_ptr<pir::Pass> CreateRemoveRedundantFullIntArrayPass() {
return std::make_unique<RemoveRedundantFullIntArrayPass>();
std::unique_ptr<pir::Pass> CreateRemoveRedundantFullOpsPass() {
return std::make_unique<RemoveRedundantFullOpsPass>();
}

} // namespace ir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace cinn {
namespace dialect {
namespace ir {

std::unique_ptr<pir::Pass> CreateRemoveRedundantFullIntArrayPass();
std::unique_ptr<pir::Pass> CreateRemoveRedundantFullOpsPass();

} // namespace ir
} // namespace dialect
Expand Down
Loading