Skip to content

graph: utils: pm: support commutative select #3166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 15, 2025
Merged

Conversation

ShanSimu
Copy link

@ShanSimu ShanSimu commented Apr 25, 2025

Description

The original setting for "commutative" was limited to operators with only two inputs. However, the select operator has three inputs, and only two of them are commutative.
A new member variable, std::set<size_t> commutative_inputs, has been added to op_schema to store the indices of commutative inputs. The function set_commutative_inputs(std::set<size_t> inputs = {0, 1}) uses default parameters to maintain support for the original commutative operators.
In nested_matcher.cpp, the functions has_commutative_inputs, match_commutative_inputs, and match_node_inputs have been updated to include checks for whether a port belongs to the commutative inputs.

This PR is for fix MFDNN-13383. Another method is also implemented here, which is achieved by swapping the mapping method of ports inside and outside the subgraph.

@ShanSimu ShanSimu added the component:graph-api Codeowner: @oneapi-src/onednn-graph label Apr 25, 2025
@ShanSimu ShanSimu self-assigned this Apr 25, 2025
@ShanSimu ShanSimu requested a review from a team as a code owner April 25, 2025 02:16
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 7ffd645 to b5fc799 Compare April 25, 2025 06:24
@ElaineBao
Copy link
Contributor

please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ...

@ShanSimu
Copy link
Author

ShanSimu commented Apr 25, 2025

please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ...

Thanks for letting me know. I split the commit into three parts based on file paths, using the directory structure as the distinction: graph: interface: ..., graph: utils: pm: ... and gtests: unit: ....

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from b5fc799 to 26cd4b3 Compare April 25, 2025 08:17
@github-actions github-actions bot added the component:tests Codeowner: @oneapi-src/onednn-arch label Apr 25, 2025
@ShanSimu ShanSimu requested a review from ElaineBao April 25, 2025 08:18
@ShanSimu ShanSimu changed the title graph: support commutative select graph: pm: support commutative select Apr 25, 2025
@ShanSimu ShanSimu changed the title graph: pm: support commutative select graph: utils: pm: support commutative select Apr 25, 2025
@ShanSimu ShanSimu requested a review from a team April 25, 2025 09:01
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 26cd4b3 to 7e9349a Compare April 27, 2025 02:08
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 7e9349a to 37c9cd8 Compare April 27, 2025 08:01
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch 2 times, most recently from 7932537 to df94010 Compare April 29, 2025 03:25
@ShanSimu
Copy link
Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph


/*! @brief Get whether the commutative inputs option is enabled or not */
bool get_commutative_inputs() const;
bool get_commutative_inputs(const std::set<size_t> &inputs = {}) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a brief description for this function, check if the inputs set are all commutative inputs, if inputs is an empty set, check if the op has commutative inputs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default argument here is confusing.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default argument here is to maintain compatibility with zero-argument usage(in nested_matcher.cpp#L73), where we simply check if the op has commutative inputs. It might make sense to split this into two separate functions:

  • is_commutative_inputs(inputs)
  • is_commutative_op()

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from df94010 to 91fcec4 Compare April 30, 2025 05:20
@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 91fcec4 to 1dcd72c Compare April 30, 2025 06:40
@ShanSimu
Copy link
Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 1dcd72c to 49ad763 Compare May 6, 2025 02:16
@ShanSimu
Copy link
Author

ShanSimu commented May 6, 2025

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@ShanSimu
Copy link
Author

ShanSimu commented May 7, 2025

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

commutative_inputs_enabled_ = true;
commutative_inputs_ = inputs;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need both commutative_inputs_enabled_ and commutative_inputs_?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with this, we plan to support only 1 pair of commutative inputs, correct?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need both commutative_inputs_enabled_ and commutative_inputs_?

The commutative_inputs_enabled_ flag can indeed be removed—checking whether commutative_inputs_ is empty is sufficient to determine if commutative behavior is enabled.

with this, we plan to support only 1 pair of commutative inputs, correct?

We support a set of commutative inputs, not just a single pair.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does a set of commutative inputs mean? More than two inputs can be commutative? Do you have a test case for that?
I was thinking about an extreme case that input 0 and 1 are commutative and input 2 and 3 are commutative.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a set of commutative inputs means more than two inputs can be commutative. We currently don't have such an op or test case.
input 0 and 1 are commutative and input 2 and 3 are commutative is another corner case I hadn't considered, and currently there are no corresponding test cases for it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, I would suggest to stick on what's supported and tested for now. That's only 1 pair of commutative inputs is supported. Update the code and comment to reflect this design and restriction.
For the cases with more than 2 commutative inputs, or two pairs of commutative inputs, we can add support once the request pops up.

Copy link
Author

@ShanSimu ShanSimu May 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commutative_inputs_ has been changed to std::array<size_t, 2>, strictly limiting it to two inputs. This change has been documented in the comments to clarify the restriction.

commutative_inputs_enabled_ has been removed. Since std::array<size_t, 2> has a fixed size of 2, we use a sentinel value—SIZE_MAX—to represent an "empty array" or an uninitialized state of commutative_inputs_.


return commutative_inputs_enabled_
&& std::includes(commutative_inputs_.begin(),
commutative_inputs_.end(), inputs.begin(), inputs.end());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how this function is used. But from it's implementation, i think the function name can be improved:

bool op_schema_t::is_commutative_inputs(size_t input1, size_t input2) {
      // make sure both {input1, input2} and {input2, input1} are checked.
      const bool ret = .....;
      return ret;
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In actual usage, one way this function is called is with no parameters, solely to check whether the current op supports commutativity. If we change the parameters to size_t input1, size_t input2, it would make this zero-argument usage less straightforward or even impossible without overloading or adding a separate method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point is not the arguments of the function. It's more about the function name. You can still use std::set to pass the inputs.

  • is_commutative_inputs(inputs): check if the given inputs are commutative, return true or false accordingly.
  • get_commutative_inputs(): return the set of commutative inputs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function name has been changed to is_commutative_inputs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the zero-argument usage is separated, the arguments have been changed to const size_t input0, const size_t input1

op_schema_t &set_commutative_inputs();
/*! @brief Enable commutative inputs for specific inputs*/
op_schema_t &set_commutative_inputs(
const std::set<size_t> &inputs = {0, 1});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How many places is this function called? If possible, let's remove the default argument and add it explicitly at where it's needed.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are only four usages relying on the default argument. Removing it would indeed make the function interface clearer and more explicit.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated


/*! @brief Get whether the commutative inputs option is enabled or not */
bool get_commutative_inputs() const;
bool get_commutative_inputs(const std::set<size_t> &inputs = {}) const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default argument here is confusing.

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch 3 times, most recently from a1aaec7 to 655aaaf Compare May 9, 2025 11:17
@ShanSimu
Copy link
Author

ShanSimu commented May 9, 2025

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph


bool op_schema_t::is_commutative_op() const {
return commutative_inputs_[0] != SIZE_MAX
&& commutative_inputs_[1] != SIZE_MAX;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to just make it empty to represent that there are no commutative ops?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::array<T, N> is a fixed-size container, so empty() always returns false. I also tried using std::pair, but it likewise always contains two values—and it doesn't even have an empty() function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explain. I would suggest to not use SIZE_MAX as magical numbers in this case.
If we don't consider std::optional which requires C++17, two options I can imagine:

  1. Still use std::set as your previous solution, but check the size of the set when needed. It should be either empty or 2 elements.
  2. Use a vector/set/pair/etc of std::pair. Check empty if it's not set. Once set, access the elements in the pair.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Casting @uxlfoundation/onednn-graph to see if other suggestions here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I adopted Method 1, performing a size check on the inputs in set_commutative_inputs

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 655aaaf to 86bbc2b Compare May 14, 2025 07:27
Copy link
Contributor

@TaoLv TaoLv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rebase and test the PR.

@ShanSimu ShanSimu force-pushed the shaojiec/select_commutative branch from 86bbc2b to b9aa1d6 Compare May 15, 2025 02:41
@ShanSimu
Copy link
Author

make test
set test_scope=NIGHTLY
disable benchdnn_all
enable benchdnn_graph

@TaoLv TaoLv merged commit 0626699 into main May 15, 2025
23 of 24 checks passed
@TaoLv TaoLv deleted the shaojiec/select_commutative branch May 15, 2025 12:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants