
[CINN] CINN supports arange for fusion #72598

Open
wants to merge 10 commits into develop

Conversation

Contributor

@Enigmatisms Enigmatisms commented May 7, 2025

PR Category

CINN

PR Types

Improvements

Description

Introduction

(This is the re-opened version of #72447, which was closed due to a bad git rebase and merge caused by the removal of the ABSL dependency in #72546.)
This PR adds support for paddle.arange so that it can participate in backend op fusion (instead of falling back to phi kernels). However, a participating arange op must satisfy the following constraint: all of its inputs must be static, i.e., start, end and step must be statically known. The following examples illustrate this constraint:

def dynamic_arange():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func(x):
        # `end` depends on a runtime tensor, so it is not statically known.
        return paddle.arange(0, x, 1).astype("float32") * 0.2 + 1

    x = (paddle.rand([1]) + 1) * 20
    func(x)

def static_arange_full():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func():
        # start, end and step are all compile-time constants.
        return paddle.arange(0, 32, 1).astype("float32") * 0.2 + 1
    func()

def static_arange_with_defaults():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func():
        return paddle.arange(32) * 2 + 1
    func()
  • dynamic_arange won't be converted to a CINN function, since its runtime-dependent second parameter would introduce a new symbolic shape.
  • static_arange_full and static_arange_with_defaults will result in cinn_op.arange, since their inputs can be implicitly converted to a FullOp. The matching condition for pd_op.arange to be rewritten as cinn_op.arange is that all three input parameters have a FullOp as their defining_op (see the sketch below).
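
As a rough illustration, the matching condition could be checked as in the following C++ sketch. This is not the PR's exact code: CanRewriteArange is an illustrative name, and the exact dialect types are assumptions based on the description above.

// Hedged sketch: reject the rewrite unless start, end and step are all
// produced by a FullOp, i.e. statically known.
bool CanRewriteArange(pir::Operation* op) {
  for (size_t idx = 0; idx < 3; ++idx) {  // operands: start_, end_, step_
    pir::Value operand = op->operand_source(idx);
    pir::Operation* def_op = operand.defining_op();
    if (!def_op || !def_op->isa<paddle::dialect::FullOp>()) {
      return false;  // runtime-dependent input: keep pd_op.arange as-is
    }
  }
  return true;
}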

Other modifications:

  • In ops.yaml for CINN, a new type ScalarType is introduced. Originally, a Scalar written in the config file was converted to float, so the generated op dialects had pre-determined types in their Build functions, which is undesirable. ScalarType is instead converted to phi::Scalar, which retains dtype information without being explicitly typed. The ConvertAttribute function automatically converts the phi::Scalar to a variant according to its underlying dtype, so the frontend (StrategyForXXXSymbolic) can extract the type information needed to instantiate a properly typed Expr (see the sketch below).
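
For illustration, the dtype-dispatching conversion might look like the sketch below. The variant alias and function name are assumptions, not the PR's actual API; phi::Scalar's dtype() and to<T>() accessors are standard phi.

#include <variant>

// Hedged sketch: lower a phi::Scalar to a typed variant so the symbolic
// strategy can instantiate an Expr of the matching type.
using ScalarVariant = std::variant<int32_t, int64_t, float, double>;

ScalarVariant ConvertScalarAttribute(const phi::Scalar& s) {
  switch (s.dtype()) {
    case phi::DataType::INT32:   return s.to<int32_t>();
    case phi::DataType::INT64:   return s.to<int64_t>();
    case phi::DataType::FLOAT32: return s.to<float>();
    case phi::DataType::FLOAT64: return s.to<double>();
    default:
      PADDLE_THROW(phi::errors::Unimplemented("unsupported scalar dtype"));
  }
}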
Notes on Implementation

In ops.yaml, the following entry might look confusing:

args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)

One might wonder why we need both Tensor start_ and ScalarType start. The reasons are:

  • The Tensor-typed inputs must be present on cinn_op.arange. Otherwise, since symbolic shape inference currently relies on op->operand_source(index), an op with no Tensor-typed input (converted to pir::Value) would be considered to have no inputs at all, resulting in out-of-range indexing of the operand sources.
  • The ScalarType-typed inputs are needed to compute the static tensor shape for the CINN symbolic strategy in elementwise.cc (see the size sketch after this list). These inputs must be statically knowable; otherwise they risk introducing a new symbolic shape and being rejected by CINN.
  • Although cinn_op.arange appears to have six inputs in ops.yaml, the three ScalarType-typed ones are actually converted to function attributes and extracted as node attributes on the CINN end.
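
As a concrete example of the static shape computation mentioned above, the output length of arange can be derived from the three extracted scalar values. GetArangeSize is an illustrative name, not the PR's actual helper; the formula mirrors the usual arange length semantics.

#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <type_traits>

// Hedged sketch: number of elements of arange(start, end, step), given
// that all three values are statically known.
template <typename T>
int64_t GetArangeSize(T start, T end, T step) {
  if constexpr (std::is_integral<T>::value) {
    // Rounded-up integer division.
    return (std::abs(end - start) + std::abs(step) - 1) / std::abs(step);
  } else {
    return static_cast<int64_t>(std::ceil(std::abs((end - start) / step)));
  }
}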

Pcard-89620


paddle-bot bot commented May 7, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Enigmatisms Enigmatisms marked this pull request as ready for review May 7, 2025 06:36
Contributor Author

@Enigmatisms Enigmatisms left a comment


self-reviewed, conflicts resolved.

@@ -1,3 +1,11 @@
- op : arange
args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
Contributor


  1. Why is the ScalarType parameter needed?

@@ -1,3 +1,11 @@
- op : arange
args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
output : Tensor(out)x
Contributor


x?

* SourceOpT and TargetOpT should be derived classes of pir::Op
*/
template <typename TargetOpT, typename SourceOpT>
bool IsDefinedBy(const SourceOpT& op, const size_t idx) {
Contributor


  1. SourceOp does not seem to need a template type.
  2. The parameter name idx does not make clear that it is the op's input index; as a public helper it should be named more descriptively. Alternatively, a Value could be passed in directly.

GET_SIZE_GIVEN_TYPE(double)
case DataType::INT32:
GET_SIZE_GIVEN_TYPE(int)
default:
Contributor


Do we need to handle fp16 types here?

Comment on lines +564 to +570
rewriter.Build<cinn::dialect::ArangeOp>(op->operand_source(0),
op->operand_source(1),
op->operand_source(2),
input_list[0],
input_list[1],
input_list[2],
dtype);
Contributor


Could CINN's ArangeOp be handled by passing in only the Tensor inputs?
