[CINN] CINN supports arange for fusion #72598


Merged: 11 commits into PaddlePaddle:develop, May 9, 2025

Conversation

@Enigmatisms Enigmatisms commented May 7, 2025

PR Category

CINN

PR Types

Improvements

Description

Introduction

(This is the re-opened version of #72447, which was closed after a bad git rebase and merge caused by the removal of the ABSL dependency in #72546.)
This PR adds support for paddle.arange so that it can participate in backend op fusion (instead of falling back to phi kernels). A participating arange op must satisfy the following constraint: all of its inputs must be static, i.e., start, end and step must be statically known. The following example illustrates this constraint:

import paddle

def dynamic_arange():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func(x):
        return paddle.arange(0, x, 1).astype("float32") * 0.2 + 1

    x = (paddle.rand([1]) + 1) * 20
    func(x)

def static_arange_full():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func():
        return paddle.arange(0, 32, 1).astype("float32") * 0.2 + 1

    func()

def static_arange_with_defaults():
    @paddle.jit.to_static(backend='CINN', full_graph=True)
    def func():
        return paddle.arange(32) * 2 + 1

    func()
  • dynamic_arange won't be converted to a CINN function, since its second parameter introduces a new symbolic shape.
  • static_arange_full and static_arange_with_defaults will be lowered to cinn_op.arange, since their inputs can be implicitly converted to a FullOp. The matching condition for rewriting pd_op.arange as cinn_op.arange is that all three input parameters have a FullOp as their defining_op (see the sketch below).
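
A minimal sketch of this matching condition (illustrative only, not the merged pass code; the helper name is made up):

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

// pd_op.arange qualifies for rewriting only if start, end and step are each
// produced by a FullOp, i.e. are statically known.
bool AllArangeInputsStatic(paddle::dialect::ArangeOp op) {
  for (size_t idx = 0; idx < 3; ++idx) {  // operands: start, end, step
    pir::Operation* def_op = op->operand_source(idx).defining_op();
    if (def_op == nullptr || !def_op->isa<paddle::dialect::FullOp>()) {
      return false;  // a dynamic input would introduce a new symbolic shape
    }
  }
  return true;
}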

Other modifications:

  • In ops.yaml for CINN, a new type ScalarType is introduced. Originally, a Scalar written in the config file was converted to float, so the generated op dialects had pre-determined types in their Build functions, which might not be desirable. ScalarType is instead converted to phi::Scalar, which retains dtype information without being explicitly typed. The ConvertAttribute function automatically converts the phi::Scalar according to its underlying dtype (to a variant), so the frontend (StrategyForXXXSymbolic) can extract the type information it needs to instantiate a proper Expr. A sketch follows below.
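
A rough sketch of that dtype-aware conversion (the variant alternatives and the function name are assumptions for illustration, not the PR's actual ConvertAttribute):

#include <variant>

#include "paddle/phi/common/scalar.h"

// Instead of collapsing every scalar to float, dispatch on the dtype that
// phi::Scalar retains and hand the frontend a typed variant.
using ScalarVariant = std::variant<bool, int64_t, double>;

ScalarVariant ScalarToVariant(const phi::Scalar& scalar) {
  switch (scalar.dtype()) {
    case phi::DataType::BOOL:
      return scalar.to<bool>();
    case phi::DataType::INT32:
    case phi::DataType::INT64:
      return scalar.to<int64_t>();
    default:  // treat the remaining (floating-point) dtypes as double
      return scalar.to<double>();
  }
}

With the dtype preserved this way, the symbolic strategy can instantiate an Expr of the matching type rather than a pre-determined float.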
Notes on Implementation

In ops.yaml, there might be something confusing like the following:

args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)

One might wonder why we need both Tensor start_ and ScalarType start. The reasons are:

  • A Tensor-typed input must be present on cinn_op.arange. During symbolic shape inference the current code relies on op->operand_source(index); if there is no Tensor-typed input (converted to a pir::Value), the CINN API is treated as having no inputs, resulting in out-of-range indexing of the operand sources.
  • The ScalarType-typed inputs are needed to compute the static tensor shape in the elementwise.cc CINN symbolic strategy. These inputs must be statically knowable; otherwise they risk introducing a new symbolic shape and being denied by CINN.
  • Although cinn_op.arange appears to have six inputs in ops.yaml, the three ScalarType-typed ones are actually converted to function attributes and extracted from the node attributes on the CINN side, as sketched after this list.
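
A hedged sketch of the CINN-side consumption described above (the container type, key names, and helper are illustrative assumptions): the three ScalarType entries arrive as attributes, and the strategy reads them back to compute the static output extent.

#include <cmath>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <variant>

using AttrVariant = std::variant<bool, int64_t, double>;
using AttrStore = std::unordered_map<std::string, AttrVariant>;

// Convert whichever alternative is stored to double.
double AsDouble(const AttrVariant& v) {
  return std::visit([](auto x) { return static_cast<double>(x); }, v);
}

int64_t ArangeStaticExtent(const AttrStore& attrs) {
  const double start = AsDouble(attrs.at("start"));
  const double end = AsDouble(attrs.at("end"));
  const double step = AsDouble(attrs.at("step"));
  // Static element count of arange: ceil((end - start) / step).
  return static_cast<int64_t>(std::ceil((end - start) / step));
}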

Pcard-89620


paddle-bot bot commented May 7, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@Enigmatisms Enigmatisms marked this pull request as ready for review May 7, 2025 06:36
@Enigmatisms Enigmatisms (Contributor, Author) left a comment:
self-reviewed, conflicts resolved.

lshpku previously approved these changes May 8, 2025
@@ -1,3 +1,11 @@
- op : arange
args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
Contributor commented:
  1. Why is the ScalarType parameter needed?
Contributor Author replied:
Both parameters, with similar meanings, coexist for the following reasons:

  • The Tensor parameters are needed for symbolic shape inference (and are used only there). For CINN, arange currently only supports static configurations, so we do not try to read data out of the Tensor.
  • ScalarType makes op auto-generation produce phi::Scalar-typed input parameters. phi::Scalar makes it convenient to extract the internal value (to compute the static shape size) and to convert between multiple types. The original implementation used float directly; using float or any other fixed type could introduce correctness problems (e.g., extreme input values or precision issues). In effect, ScalarType acts as a type-erased input. Since op gen did not provide an interface that directly produces phi::Scalar, one is added here.

* SourceOpT and TargetOpT should be derived classes of pir::Op
*/
template <typename TargetOpT, typename SourceOpT>
bool IsDefinedBy(const SourceOpT& op, const size_t idx) {
Contributor commented:
  1. SourceOp does not seem to need a template type.
  2. The parameter name idx does not make it clear that it is the op's input index; as a public function, the semantics should be clearer. Alternatively, just pass a Value in directly.

Contributor Author replied:
ok

Contributor Author added:
On the first point, I checked: since the actual input type is paddle::dialect::XXXOp (the concrete op types are defined in pd_op.h), handling it as SourceOpT or as a pointer to the base class of paddle::dialect::XXXOp (OpBase) should both work. The IsDefinedBy function was lifted directly from pd_to_cinn_pass.cc, so I kept its original form.
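
For reference, a minimal sketch of the reviewer's alternative (illustrative, not the merged code): taking the pir::Value directly removes the SourceOpT template parameter and makes the operand explicit at the call site.

#include "paddle/pir/include/core/value.h"

// TargetOpT should be a pir::Op-derived class, e.g. paddle::dialect::FullOp.
template <typename TargetOpT>
bool IsDefinedBy(pir::Value value) {
  pir::Operation* def_op = value.defining_op();
  return def_op != nullptr && def_op->isa<TargetOpT>();
}

// Hypothetical call site:
//   IsDefinedBy<paddle::dialect::FullOp>(op->operand_source(0));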

Comment on lines +564 to +570
rewriter.Build<cinn::dialect::ArangeOp>(op->operand_source(0),
op->operand_source(1),
op->operand_source(2),
input_list[0],
input_list[1],
input_list[2],
dtype);
Contributor commented:

Could CINN's ArangeOp be handled by passing in only the Tensors?

Contributor Author replied:

I considered that before, but found it more complicated to change than the current approach: we would need to verify that a Tensor meets the input requirements and then extract a value that CINN can use directly (I am not familiar with how to do the latter; the Tensor looks fairly meta, storing only the tensor's semantic information). If there were a way to extract the static value directly from a Tensor, this code could perhaps be simplified, but it would still need testing. Previously, passing the Tensor in caused the op to be denied by CINN, because the dynamic tensor introduced a new symbolic shape.

@lshpku lshpku merged commit bc5df86 into PaddlePaddle:develop May 9, 2025
48 of 49 checks passed
@Enigmatisms Enigmatisms deleted the arange branch May 9, 2025 03:21
@codecov-commenter

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload a report for BASE (develop@a890738); the BASE report is missing, so no coverage comparison is shown.

Additional details and impacted files
@@             Coverage Diff             @@
##             develop    #72598   +/-   ##
===========================================
  Coverage           ?   100.00%           
===========================================
  Files              ?         2           
  Lines              ?        24           
  Branches           ?         0           
===========================================
  Hits               ?        24           
  Misses             ?         0           
  Partials           ?         0           
