-
Notifications
You must be signed in to change notification settings - Fork 1k
graph: utils: pm: support commutative select #3166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@@ -252,7 +254,8 @@ bool match_node_inputs(const binding_t &b, match_context_t *ctx, | |||
if (node_inputs_matcher.get_node()->get_inputs().size() | |||
== VARIADIC_INPUT_NUM) { | |||
matching_status = node_inputs_matcher.match_variadic_inputs(); | |||
} else if (!has_commutative_inputs(node_inputs_matcher.get_op())) { | |||
} else if (!has_commutative_inputs(node_inputs_matcher.get_op(), | |||
{b.bind_op_port, b.bind_port})) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering, for select, whose first input is non-commutative, and second & third inputs are commutative, how can we make sure all the inputs are correctly visited and matched?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When bind_node
is select
, bind_port
is 2 and bind_op_port
is 1 which conform to commutative inputs = {1, 2}
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand that. I mean how is the 0-th input matched? As it's non-commutative.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean the first or the second input will swap with the first one? I doubt whether that's a valid graph, since the data type of the 0-th input is boolean after all. As for other cases, like when all three input data types are the same but only two are commutative, I don't think we have such a requirement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added gtest for verifying match_node_inputs()
7ffd645
to
b5fc799
Compare
please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ... |
Thanks for letting me know. I split the commit into three parts based on file paths, using the directory structure as the distinction: |
b5fc799
to
26cd4b3
Compare
26cd4b3
to
7e9349a
Compare
7e9349a
to
37c9cd8
Compare
// MatMul Add | ||
// \ / | ||
// Select | ||
auto pmatmul = graphp->append_op(MatMul); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add another input for select's 0-th input (e.g. GreaterEqual), we want to know if the three inputs (both commutative and non-commutative) can all be matched
agraph.finalize(); | ||
|
||
std::vector<op_t *> fusion_ops; | ||
EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And here we need to match from the select op, to verify our assumption.
Description
The original setting for "commutative" was limited to operators with only two inputs. However, the select operator has three inputs, and only two of them are commutative.
A new member variable,
std::set<size_t> commutative_inputs
, has been added toop_schema
to store the indices of commutative inputs. The functionset_commutative_inputs(std::set<size_t> inputs = {0, 1})
uses default parameters to maintain support for the original commutative operators.In
nested_matcher.cpp
, the functionshas_commutative_inputs
,match_commutative_inputs
, andmatch_node_inputs
have been updated to include checks for whether a port belongs to the commutative inputs.This PR is for fix MFDNN-13383. Another method is also implemented here, which is achieved by swapping the mapping method of ports inside and outside the subgraph.