-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[CINN] CINN supports arange for fusion #72598
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self-reviewed, conflicts resolved.
@@ -1,3 +1,11 @@ | |||
- op : arange | |||
args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 为什么需要 ScalarType 参数?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
两种类似意义参数同时存在是出于如下考虑:
- Tensor 参数需要用于 symbolic shape 推导(且仅仅使用在 symbolic shape 推导中),对于 CINN 而言 arange 目前只支持静态的配置,所以不会考虑从 Tensor 中取 data
- ScalarType 是为了让 op 自动生成时生成 phi::Scalar 类型输入参数,phi::Scalar 可以方便取出内部的值(以计算静态的shape大小)以及进行多种类型的转换。原来的实现就直接用了 float,觉得直接用 float 或者某个固定的类型会引入正确性的问题(比如极端值输入或者精度问题)。相当于以 ScalarType 作为一个擦除了类型的输入。而op gen没有提供可以直接产生 phi::Scalar 的接口,所以这里加了一个。
* SourceOpT and TargetOpT should be the derived class of pir::Op | ||
*/ | ||
template <typename TargetOpT, typename SourceOpT> | ||
bool IsDefinedBy(const SourceOpT& op, const size_t idx) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- SourceOp 看起来不需要使用模板类型
- idx的参数命名看不出来是op的input index,作为公共函数最好写的语义清晰一些。或者直接传Value进来也行
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方第一点我查了一下,由于输入的实际类型是 paddle::dialect::XXXOp(具体的 Op类型在 pd_op.h 中定义的),所以处理成 SourceOpT 和 paddle::dialect::XXXOp 的基类指针(OpBase)应该都可以。IsDefinedBy 这个函数实际上是从 pd_to_cinn_pass.cc 中直接摘过来的,所以没有修改其原来的写法。
rewriter.Build<cinn::dialect::ArangeOp>(op->operand_source(0), | ||
op->operand_source(1), | ||
op->operand_source(2), | ||
input_list[0], | ||
input_list[1], | ||
input_list[2], | ||
dtype); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CINN的ArangeOp是否可以只传入Tensor处理?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前考虑过,但发现改动起来比现在更复杂:需要能判断 Tensor 符合输入要求并且可以取出 CINN 可直接使用的值(后面这一步我不熟悉怎么做,Tensor 看起来比较 meta,就只保存了tensor的语义信息),如果有办法可以直接从 Tensor 中取出静态的值就可能可以化简这里的写法,但还是需要测试的,此前用Tensor传入后出现过算子被 CINN denied 的现象,原因就是动态的 tensor 导致了 new symbolic shape。
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #72598 +/- ##
===========================================
Coverage ? 100.00%
===========================================
Files ? 2
Lines ? 24
Branches ? 0
===========================================
Hits ? 24
Misses ? 0
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
PR Category
CINN
PR Types
Improvements
Description
Introduction
(This is the re-opened version of #72447, #72447 is closed due to bad git rebase and merge, caused by the removal of ABSL dependency introduced in #72546)
This PR adds support for
paddle.arange
, so that it can participate in backend op fusion (instead of using phi kernels). However, the participating arange op should satisfy the following constraint: all the inputs should be static, i.e.,start
,end
andstep
should be statically known. The following example illustrates this constraint:dynamic_arange
won't be converted to a CINN func, since the second param will cause new symbolic shape.static_arange_full
andstatic_arange_with_defaults
will result incinn_op.arange
, since the inputs can be implicitly converted to aFullOp
. The matching condition forpd_op.arange
to be rewritten ascinn_op.arange
is that all three input params havedefining_op
asFullOp
.Other modifications:
ops.yaml
for CINN, a new typeScalarType
is introduced. Originally,Scalar
written in the config file will be converted tofloat
, making the generated op dialects have pre-determined types forBuild
function, which might not be good.ScalarType
will be converted tophi::Scalar
, it retains dtype info, while it is not explicitly typed.ConvertAttribute
function will automatically convert thephi::Scalar
according to its underlying dtype (to a variant). Therefore, frontend (StrategyForXXXSymbolic
) can extract useful type info to instantiate properExpr
.Notes on Implementation
In
ops.yaml
, there might be something confusing like the following:One might wonder why we need
Tensor start_
andScalarType start
. This is due to the reason that:Tensor
typed input must be present in thecinn_op.arange
, otherwise, during symbolic shape inference, since the current code relies onop->operand_source(index)
, if there is noTensor
typed input (converted topir::Value
), the CINN API will be considered to have no input, thus resulting in out-of-range indexing for operand source.ScalarType
typed input should be given, in order to calculate the static tensor shape forelementwise.cc
CINN symbolic strategy. The input should be statically know-able, otherwise it has risks for introducing new symbolic shape and getting denied for CINN.cinn_op.arange
in theops.yaml
, theScalarType
typed three will actually be converted to function attribute and gets extracted by node attributes in the CINN end.Pcard-89620