[CINN] CINN supports arange for fusion #72598
base: develop
Conversation
Your PR was submitted successfully. Thank you for your contribution to this open source project!
self-reviewed, conflicts resolved.
@@ -1,3 +1,11 @@
- op : arange
  args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
- Why is the ScalarType parameter needed?
@@ -1,3 +1,11 @@
- op : arange
  args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
  output : Tensor(out)x
x?
 * SourceOpT and TargetOpT should be the derived class of pir::Op
 */
template <typename TargetOpT, typename SourceOpT>
bool IsDefinedBy(const SourceOpT& op, const size_t idx) {
- SourceOp does not appear to need a template type.
- The parameter name idx does not make it clear that this is the op's input index; as a public utility function, its semantics should be clearer. Alternatively, passing a Value directly would also work.
      GET_SIZE_GIVEN_TYPE(double)
    case DataType::INT32:
      GET_SIZE_GIVEN_TYPE(int)
    default:
Does the fp16 type need to be considered here?
rewriter.Build<cinn::dialect::ArangeOp>(op->operand_source(0),
                                        op->operand_source(1),
                                        op->operand_source(2),
                                        input_list[0],
                                        input_list[1],
                                        input_list[2],
                                        dtype);
Could CINN's ArangeOp be handled by passing in only Tensors?
PR Category
CINN
PR Types
Improvements
Description
Introduction
(This is the re-opened version of #72447; #72447 was closed due to a bad git rebase and merge, caused by the removal of the ABSL dependency introduced in #72546.)
This PR adds support for `paddle.arange`, so that it can participate in backend op fusion (instead of using phi kernels). However, a participating arange op must satisfy the following constraint: all of its inputs should be static, i.e., `start`, `end` and `step` should be statically known. The example sketched below illustrates this constraint: `dynamic_arange` won't be converted to a CINN func, since its second param would introduce a new symbolic shape, while `static_arange_full` and `static_arange_with_defaults` will result in `cinn_op.arange`, since their inputs can be implicitly converted to a `FullOp`. The matching condition for `pd_op.arange` to be rewritten as `cinn_op.arange` is that all three input params have `defining_op` as `FullOp`.
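A minimal sketch of the three cases described above (the function bodies are hypothetical, since the actual test code is not included in this description; `dynamic_end` is an assumed runtime-valued input):

```python
import paddle

def dynamic_arange(dynamic_end):
    # `dynamic_end` is a tensor whose value is only known at runtime,
    # so this arange introduces a new symbolic shape and is not fused.
    return paddle.arange(0, dynamic_end, 1, dtype='int64')

def static_arange_full():
    # All three params are compile-time constants (lowered to FullOp),
    # so this can be rewritten as cinn_op.arange.
    return paddle.arange(0, 128, 2, dtype='int64')

def static_arange_with_defaults():
    # Only `end` is given; the default start=0 and step=1 are also static.
    return paddle.arange(128)
```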
Other modifications:
In `ops.yaml` for CINN, a new type `ScalarType` is introduced. Originally, a `Scalar` written in the config file would be converted to `float`, so the generated op dialects had a pre-determined type in their `Build` function, which might not be desirable. `ScalarType` is instead converted to `phi::Scalar`: it retains the dtype info while not being explicitly typed. The `ConvertAttribute` function automatically converts the `phi::Scalar` according to its underlying dtype (to a variant). Therefore, the frontend (`StrategyForXXXSymbolic`) can extract the type info it needs to instantiate a proper `Expr`.
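As a rough user-level illustration of why the scalar dtype matters (a hypothetical example, not code from this PR): the `start`/`end`/`step` scalars may be integral or floating-point, and coercing them all to `float` early would lose that distinction.

```python
import paddle

# The scalar start/end/step can carry different underlying dtypes;
# the generated dialect should preserve them rather than forcing float.
xi = paddle.arange(0, 10, 2, dtype='int64')           # [0, 2, 4, 6, 8]
xf = paddle.arange(0.0, 1.0, 0.25, dtype='float32')   # [0.00, 0.25, 0.50, 0.75]
```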
Notes on Implementation
In `ops.yaml`, there might be something confusing like the following:
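The definition in question looks roughly like this (reconstructed from the `ops.yaml` diff quoted in the review thread above):

```yaml
- op : arange
  args : (Tensor start_, Tensor end_, Tensor step_, ScalarType start, ScalarType end, ScalarType step, DataType dtype=DataType::INT64)
  output : Tensor(out)
```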
One might wonder why we need both `Tensor start_` and `ScalarType start`. The reasons are:

- A `Tensor` typed input must be present in `cinn_op.arange`. Otherwise, during symbolic shape inference, since the current code relies on `op->operand_source(index)`, an op with no `Tensor` typed input (converted to `pir::Value`) would be considered to have no input at all, resulting in out-of-range indexing of the operand source.
- The `ScalarType` typed inputs are needed to calculate the static tensor shape in the `elementwise.cc` CINN symbolic strategy. They must be statically knowable, otherwise they risk introducing new symbolic shapes and getting the op denied for CINN.
- For `cinn_op.arange` in `ops.yaml`, the three `ScalarType` typed inputs are actually converted to function attributes and extracted via node attributes on the CINN end.

Pcard-89620