Skip to content

graph: utils: pm: support commutative select #3166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ShanSimu
Copy link
Contributor

@ShanSimu ShanSimu commented Apr 25, 2025

Description

The original setting for "commutative" was limited to operators with only two inputs. However, the select operator has three inputs, and only two of them are commutative.
A new member variable, std::set<size_t> commutative_inputs, has been added to op_schema to store the indices of commutative inputs. The function set_commutative_inputs(std::set<size_t> inputs = {0, 1}) uses default parameters to maintain support for the original commutative operators.
In nested_matcher.cpp, the functions has_commutative_inputs, match_commutative_inputs, and match_node_inputs have been updated to include checks for whether a port belongs to the commutative inputs.

This PR is for fix MFDNN-13383. Another method is also implemented here, which is achieved by swapping the mapping method of ports inside and outside the subgraph.

@ShanSimu ShanSimu added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Apr 25, 2025
@ShanSimu ShanSimu self-assigned this Apr 25, 2025
@ShanSimu ShanSimu requested a review from a team as a code owner April 25, 2025 02:16
@@ -252,7 +254,8 @@ bool match_node_inputs(const binding_t &b, match_context_t *ctx,
if (node_inputs_matcher.get_node()->get_inputs().size()
== VARIADIC_INPUT_NUM) {
matching_status = node_inputs_matcher.match_variadic_inputs();
} else if (!has_commutative_inputs(node_inputs_matcher.get_op())) {
} else if (!has_commutative_inputs(node_inputs_matcher.get_op(),
{b.bind_op_port, b.bind_port})) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering, for select, whose first input is non-commutative, and second & third inputs are commutative, how can we make sure all the inputs are correctly visited and matched?

Copy link
Contributor Author

@ShanSimu ShanSimu Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When bind_node is select, bind_port is 2 and bind_op_port is 1 which conform to commutative inputs = {1, 2}.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that. I mean how is the 0-th input matched? As it's non-commutative.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean the first or the second input will swap with the first one? I doubt whether that's a valid graph, since the data type of the 0-th input is boolean after all. As for other cases, like when all three input data types are the same but only two are commutative, I don't think we have such a requirement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added gtest for verifying match_node_inputs()

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 7ffd645 to b5fc799 Compare April 25, 2025 06:24
@ElaineBao
Copy link
Contributor

please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ...

@ShanSimu
Copy link
Contributor Author

ShanSimu commented Apr 25, 2025

please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ...

Thanks for letting me know. I split the commit into three parts based on file paths, using the directory structure as the distinction: graph: interface: ..., graph: utils: pm: ... and gtests: unit: ....

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from b5fc799 to 26cd4b3 Compare April 25, 2025 08:17
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Apr 25, 2025
@ShanSimu ShanSimu requested a review from ElaineBao April 25, 2025 08:18
@ShanSimu ShanSimu changed the title graph: support commutative select graph: pm: support commutative select Apr 25, 2025
@ShanSimu ShanSimu changed the title graph: pm: support commutative select graph: utils: pm: support commutative select Apr 25, 2025
@ShanSimu ShanSimu requested a review from a team April 25, 2025 09:01
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 26cd4b3 to 7e9349a Compare April 27, 2025 02:08
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 7e9349a to 37c9cd8 Compare April 27, 2025 08:01
// MatMul Add
// \ /
// Select
auto pmatmul = graphp->append_op(MatMul);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add another input for select's 0-th input (e.g. GreaterEqual), we want to know if the three inputs (both commutative and non-commutative) can all be matched

agraph.finalize();

std::vector<op_t *> fusion_ops;
EXPECT_TRUE(match_pattern(agraph.get_ops()[0].get(), graphp, fusion_ops));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here we need to match from the select op, to verify our assumption.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants