Skip to content

[CINN] Split complex index #72568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ ExprSetFinder ChildIfThenElses =
// Collects every sub-expression of the visited Expr that is a variable.
ExprSetFinder ChildVars =
    Collector([](const ir::Expr* e) { return e->as_var(); }, "ChildVars");

// Collects every sub-expression of the visited Expr that is an ir::Block
// (used e.g. to reach a schedule-block body so statements can be inserted).
ExprSetFinder ChildBlocks = Collector(
    [](const ir::Expr* e) { return e->As<ir::Block>(); }, "ChildBlocks");

ExprSetFinder FindFather(const ir::Expr& root) {
const auto& f = [root](const auto& child) -> ExprSet {
ExprSetFinder find_child =
Expand Down Expand Up @@ -1224,6 +1227,60 @@ void InlineGlobalVarCompute(const std::vector<ir::Expr>& roots,
}
}

// Returns a monotonically increasing id (starting at 1) used to build unique
// names ("index_var_<N>") for the temporary variables introduced by
// SplitComplexIndexExpr.
//
// The counter is `thread_local`, so each thread owns a private instance and
// no synchronization is needed; the previous `std::atomic` wrapper added
// atomic-RMW cost without providing cross-thread uniqueness (two threads
// would still both start at 1).
int IndexVarCounter() {
  static thread_local int counter = 1;
  return counter++;
}

// For every non-reduce root, hoists each tensor-load subscript that is not a
// plain index expression (`!index.is_index()`) into a fresh `Let`-bound
// variable inserted into the schedule block body, then rewrites the load to
// subscript with that variable instead.
void SplitComplexIndexExpr(const std::vector<ir::Expr>& roots) {
  auto search_body_blocks = ExprSetFinderUtils::ChildScheduleBlockRealizes *
                            ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit *
                            ExprSetFinderUtils::ChildBlocks;
  auto search_tensor_loads = ExprSetFinderUtils::ChildStores *
                             ExprSetFinderUtils::Store2Value *
                             ExprSetFinderUtils::ChildTensorLoads;
  // const ref: copying an ir::Expr per iteration is avoidable work.
  for (const auto& root : roots) {
    if (IsReducePattern(root)) continue;
    auto body_block_expr = search_body_blocks(root).back();
    auto body_block = body_block_expr.As<ir::Block>();
    // Visit every tensor load in reverse topological order. Once a tensor
    // load's indices are splited, other tensor loads may be affected, so we
    // need to search tensor loads again until no tensor load index is splited.
    std::unordered_set<ir::Expr> visited_tensor_loads;
    auto tensor_loads = search_tensor_loads(body_block);
    while (!tensor_loads.empty()) {
      auto tensor_load = tensor_loads.back();
      tensor_loads.pop_back();
      if (visited_tensor_loads.count(tensor_load)) continue;
      visited_tensor_loads.insert(tensor_load);

      // Reference, not a copy: the load node is only replaced (via IRCopy +
      // mutator) after we are done reading its indices.
      const auto& indices = tensor_load.As<ir::Load>()->indices;
      bool need_split = false;
      std::vector<ir::Expr> new_indices;
      new_indices.reserve(indices.size());
      for (const auto& index : indices) {
        if (index.is_index()) {
          new_indices.push_back(index);
          continue;
        }
        need_split = true;
        ir::Var index_var = ir::_Var_::Make(
            "index_var_" + std::to_string(IndexVarCounter()), index.type());
        auto index_let_expr = ir::Let::Make(index_var, index);
        // Insert the Let just before the block's trailing statement.
        // NOTE(review): assumes body_block->stmts is non-empty — confirm the
        // finder chain guarantees this for non-init schedule blocks.
        body_block->stmts.insert(body_block->stmts.end() - 1, index_let_expr);
        new_indices.push_back(index_var);
      }
      // Replace index with index_var
      if (!need_split) continue;
      auto new_tensor_load = ir::ir_utils::IRCopy(tensor_load);
      new_tensor_load.As<ir::Load>()->indices = new_indices;
      ComposeUtils::MappingTargetExprToDestExprMutator(
          tensor_load, new_tensor_load)(&body_block_expr);
      // Splitting may have changed the tree; rescan until a full pass finds
      // only already-visited loads.
      tensor_loads = search_tensor_loads(body_block);
    }
  }
}

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/framework/pir/trivial_op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ ir::Tensor GetOutputTensor(const ir::Expr& root);
void InlineGlobalVarCompute(const std::vector<ir::Expr>& roots,
const std::set<std::string>& global_var_names);

void SplitComplexIndexExpr(const std::vector<ir::Expr>& roots);

} // namespace trivial_fusion_detail
} // namespace pir
} // namespace framework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ struct IterSpaceGetter {
: load_(load), loops_(loops), indices_vars_(load->indices.size()) {
for (int i = 0; i < load_->indices.size(); ++i) {
ir::ir_utils::CollectIRNodes(load_->indices[i], [&](const ir::Expr* x) {
if (x->is_var() && !x->as_var()->is_symbolic_constant) {
if (x->is_var() && !x->as_var()->is_symbolic_constant &&
x->as_var()->name.find(analyzer::kLoopVar) != std::string::npos) {
indices_vars_[i].insert(x->as_var_ref());
}
return false;
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ void RunReturnInstr(const std::shared_ptr<ReturnInstr>& instr,
}
// Inline global vars
InlineGlobalVarCompute(result, interpreter->global_var_names);
SplitComplexIndexExpr(result);
interpreter->ret_expr = result;
}

Expand Down
14 changes: 14 additions & 0 deletions test/ir/pir/cinn/test_anchor_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,20 @@ def init():

self.check_accuracy_and_kernel_num(init, func)

def test_concat_gather(self):
def func(x, y, z):
u = paddle.concat([y, z], axis=0) + 32
v = paddle.gather(x, u, axis=0)
return v * 3, paddle.sum(v, axis=0)

def init():
x = paddle.rand((128, 256))
y = paddle.randint(0, 64, [32], dtype="int64")
z = paddle.randint(0, 64, [32], dtype="int64")
return (x, y, z)

self.check_accuracy_and_kernel_num(init, func)


if __name__ == "__main__":
unittest.main()
Loading