diff --git a/paddle/cinn/optim/simplify_util.cc b/paddle/cinn/optim/simplify_util.cc index a9ae0b5391045c..0c02ff5ce9bb89 100644 --- a/paddle/cinn/optim/simplify_util.cc +++ b/paddle/cinn/optim/simplify_util.cc @@ -237,8 +237,14 @@ ir::IndexExpr::IndexType VerifyIndex(const ir::Expr &expr) { : ir::IndexExpr::IndexType::kInvalid; } case ir::IrNodeTy::Load: { - return expr.type().is_index_type() ? ir::IndexExpr::IndexType::kLoad - : ir::IndexExpr::IndexType::kInvalid; + if (!expr.type().is_index_type()) + return ir::IndexExpr::IndexType::kInvalid; + auto load = expr.As(); + for (const auto &indices : load->indices) { + if (VerifyIndex(indices) == ir::IndexExpr::IndexType::kInvalid) + return ir::IndexExpr::IndexType::kInvalid; + } + return ir::IndexExpr::IndexType::kLoad; } case ir::IrNodeTy::Cast: { ir::IndexExpr::IndexType result = VerifyIndex(expr->operand(0));