-
Notifications
You must be signed in to change notification settings - Fork 1.1k
graph: utils: pm: support commutative select #3166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
7ffd645
to
b5fc799
Compare
please format the commit message with correct component, graph: interface: ..., graph: utils: pm: ... |
Thanks for letting me know. I split the commit into three parts based on file paths, using the directory structure as the distinction: |
b5fc799
to
26cd4b3
Compare
26cd4b3
to
7e9349a
Compare
7e9349a
to
37c9cd8
Compare
7932537
to
df94010
Compare
make test |
src/graph/interface/op_schema.hpp
Outdated
|
||
/*! @brief Get whether the commutative inputs option is enabled or not */ | ||
bool get_commutative_inputs() const; | ||
bool get_commutative_inputs(const std::set<size_t> &inputs = {}) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a brief description for this function, check if the inputs set are all commutative inputs, if inputs is an empty set, check if the op has commutative inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default argument here is confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default argument here is to maintain compatibility with zero-argument usage(in nested_matcher.cpp#L73), where we simply check if the op has commutative inputs. It might make sense to split this into two separate functions:
- is_commutative_inputs(inputs)
- is_commutative_op()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
df94010
to
91fcec4
Compare
91fcec4
to
1dcd72c
Compare
make test |
1dcd72c
to
49ad763
Compare
make test |
make test |
commutative_inputs_enabled_ = true; | ||
commutative_inputs_ = inputs; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need both commutative_inputs_enabled_
and commutative_inputs_
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
with this, we plan to support only 1 pair of commutative inputs, correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need both
commutative_inputs_enabled_
andcommutative_inputs_
?
The commutative_inputs_enabled_ flag can indeed be removed—checking whether commutative_inputs_ is empty is sufficient to determine if commutative behavior is enabled.
with this, we plan to support only 1 pair of commutative inputs, correct?
We support a set of commutative inputs, not just a single pair.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does a set of commutative inputs
mean? More than two inputs can be commutative? Do you have a test case for that?
I was thinking about an extreme case that input 0 and 1 are commutative and input 2 and 3 are commutative.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, a set of commutative inputs
means more than two inputs can be commutative. We currently don't have such an op or test case.
input 0 and 1 are commutative and input 2 and 3 are commutative
is another corner case I hadn't considered, and currently there are no corresponding test cases for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, I would suggest to stick on what's supported and tested for now. That's only 1 pair of commutative inputs is supported. Update the code and comment to reflect this design and restriction.
For the cases with more than 2 commutative inputs, or two pairs of commutative inputs, we can add support once the request pops up.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
commutative_inputs_
has been changed to std::array<size_t, 2>
, strictly limiting it to two inputs. This change has been documented in the comments to clarify the restriction.
commutative_inputs_enabled_
has been removed. Since std::array<size_t, 2> has a fixed size of 2, we use a sentinel value—SIZE_MAX—to represent an "empty array" or an uninitialized state of commutative_inputs_
.
src/graph/interface/op_schema.cpp
Outdated
|
||
return commutative_inputs_enabled_ | ||
&& std::includes(commutative_inputs_.begin(), | ||
commutative_inputs_.end(), inputs.begin(), inputs.end()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how this function is used. But from it's implementation, i think the function name can be improved:
bool op_schema_t::is_commutative_inputs(size_t input1, size_t input2) {
// make sure both {input1, input2} and {input2, input1} are checked.
const bool ret = .....;
return ret;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In actual usage, one way this function is called is with no parameters, solely to check whether the current op supports commutativity. If we change the parameters to size_t input1, size_t input2
, it would make this zero-argument usage less straightforward or even impossible without overloading or adding a separate method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My point is not the arguments of the function. It's more about the function name. You can still use std::set
to pass the inputs.
- is_commutative_inputs(inputs): check if the given inputs are commutative, return true or false accordingly.
- get_commutative_inputs(): return the set of commutative inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The function name has been changed to is_commutative_inputs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the zero-argument usage is separated, the arguments have been changed to const size_t input0, const size_t input1
src/graph/interface/op_schema.hpp
Outdated
op_schema_t &set_commutative_inputs(); | ||
/*! @brief Enable commutative inputs for specific inputs*/ | ||
op_schema_t &set_commutative_inputs( | ||
const std::set<size_t> &inputs = {0, 1}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How many places is this function called? If possible, let's remove the default argument and add it explicitly at where it's needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are only four usages relying on the default argument. Removing it would indeed make the function interface clearer and more explicit.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated
src/graph/interface/op_schema.hpp
Outdated
|
||
/*! @brief Get whether the commutative inputs option is enabled or not */ | ||
bool get_commutative_inputs() const; | ||
bool get_commutative_inputs(const std::set<size_t> &inputs = {}) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default argument here is confusing.
a1aaec7
to
655aaaf
Compare
make test |
src/graph/interface/op_schema.cpp
Outdated
|
||
bool op_schema_t::is_commutative_op() const { | ||
return commutative_inputs_[0] != SIZE_MAX | ||
&& commutative_inputs_[1] != SIZE_MAX; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to just make it empty to represent that there are no commutative ops?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::array<T, N>
is a fixed-size container, so empty()
always returns false. I also tried using std::pair
, but it likewise always contains two values—and it doesn't even have an empty()
function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explain. I would suggest to not use SIZE_MAX
as magical numbers in this case.
If we don't consider std::optional
which requires C++17, two options I can imagine:
- Still use std::set as your previous solution, but check the size of the set when needed. It should be either empty or 2 elements.
- Use a vector/set/pair/etc of std::pair. Check empty if it's not set. Once set, access the elements in the pair.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Casting @uxlfoundation/onednn-graph to see if other suggestions here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I adopted Method 1, performing a size check on the inputs in set_commutative_inputs
655aaaf
to
86bbc2b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please rebase and test the PR.
86bbc2b
to
b9aa1d6
Compare
make test |
Description
The original setting for "commutative" was limited to operators with only two inputs. However, the select operator has three inputs, and only two of them are commutative.
A new member variable,
std::set<size_t> commutative_inputs
, has been added toop_schema
to store the indices of commutative inputs. The functionset_commutative_inputs(std::set<size_t> inputs = {0, 1})
uses default parameters to maintain support for the original commutative operators.In
nested_matcher.cpp
, the functionshas_commutative_inputs
,match_commutative_inputs
, andmatch_node_inputs
have been updated to include checks for whether a port belongs to the commutative inputs.This PR is for fix MFDNN-13383. Another method is also implemented here, which is achieved by swapping the mapping method of ports inside and outside the subgraph.