[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor #3331
base: develop
…ve-elementwise-ops
| @@ -0,0 +1,244 @@ | |||
| # Composable Kernel Builder Design Documentation | |||
This is really nice documentation! Thanks for taking time to create this!
| enum class DataType | ||
| { | ||
| UNDEFINDED = 0, |
Nit: UNDEFINED
Only fix if you need to make other changes.
| // G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width | ||
| // Enum defines Input, Weight, and Output tensor layouts respectively. | ||
| enum class GroupConvLayout1D | ||
| enum class TensorLayout |
This is pretty clean. Another strategy I sometimes see with flat lists of enums like this is to add extra enum values for start and end of ranges of grouped elements to enable classification checks with math. We can add that later if it helps logic, but concepts and lists with constexpr std::array<TensorLayout> can also be used if we need more structure.
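Both strategies mentioned in this comment can be sketched briefly. The enum below is a hypothetical, abbreviated stand-in (`Layout`, the sentinel names, and the helper functions are illustrative assumptions, not the PR's actual `TensorLayout`):

```cpp
#include <array>

// Hypothetical, abbreviated layout list. The Begin/End sentinels are
// illustrative names, not values from the PR.
enum class Layout
{
    InputBegin,
    GNHWC = InputBegin,
    NHWGC,
    InputEnd, // one past the last input layout

    WeightBegin = InputEnd,
    GKYXC       = WeightBegin,
    KYXGC,
    WeightEnd // one past the last weight layout
};

// Strategy 1: classify by range, relying on the sentinel values.
constexpr bool is_input_layout(Layout l)
{
    return l >= Layout::InputBegin && l < Layout::InputEnd;
}

// Strategy 2: classify by membership in an explicit constexpr list,
// which does not depend on the declaration order of the enum.
constexpr std::array<Layout, 2> kWeightLayouts{Layout::GKYXC, Layout::KYXGC};

constexpr bool is_weight_layout(Layout l)
{
    for(Layout w : kWeightLayouts)
    {
        if(w == l)
        {
            return true;
        }
    }
    return false;
}

static_assert(is_input_layout(Layout::GNHWC));
static_assert(!is_input_layout(Layout::GKYXC));
static_assert(is_weight_layout(Layout::KYXGC));
```

The range check is cheap but breaks silently if someone adds a value in the wrong group; the explicit list is more verbose but keeps the enum itself free of sentinel values.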
| inline std::ostream& | ||
| operator<<(std::ostream& os, | ||
| const std::variant<GroupConvLayout1D, GroupConvLayout2D, GroupConvLayout3D>& layout) | ||
| inline std::ostream& operator<<(std::ostream& os, TensorLayout layout) |
It might be better to have a "toString" function with the switch case, and then use that to define ostream printing.
| │ ConvTensor │ | ||
| ├─────────────────────────────────────────┤ | ||
| │ ╔═════════════════════════════════════╗ │ | ||
| │ ║ TensorConfig (required) ║ │ |
It's not clear to me why we need the TensorConfig wrapper instead of just directly having a layout, datatype, and compute_type for the tensor.
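To make the two alternatives concrete, here is a rough sketch. All type names, field names, and enum values are assumptions for illustration; only `ConvTensor`, `TensorConfig`, `TensorLayout`, and `DataType` appear in the diff, and their real definitions may differ:

```cpp
// Hypothetical, abbreviated enums.
enum class DataType { FP16, FP32 };
enum class TensorLayout { GNHWC, GKYXC, GNHWK };
enum class ElementwiseOp { PASS_THROUGH, SCALE };

// Nested form: required properties grouped in a wrapper.
struct TensorConfig
{
    TensorLayout layout;
    DataType data_type;
    DataType compute_type;
};

struct ConvTensor
{
    TensorConfig config;                                   // required
    ElementwiseOp operation = ElementwiseOp::PASS_THROUGH; // optional
};

// Flat form: the same fields placed directly on the tensor.
struct FlatConvTensor
{
    TensorLayout layout;
    DataType data_type;
    DataType compute_type;
    ElementwiseOp operation = ElementwiseOp::PASS_THROUGH;
};

constexpr ConvTensor input{{TensorLayout::GNHWC, DataType::FP16, DataType::FP32}};
constexpr FlatConvTensor flat_input{TensorLayout::GNHWC, DataType::FP16, DataType::FP32};

static_assert(input.config.layout == flat_input.layout);
```

The nested form costs one extra level of access (`input.config.layout`) in exchange for a visible split between required and optional properties; the flat form is shorter to use but does not mark that split.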
| EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); | ||
| EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); | ||
| EXPECT_THAT(Traits::layout, | ||
| ::testing::ElementsAre(ck_tile::builder::TensorLayout::GNHWC, |
Note: generally add a using declaration at the top of the test file:
using ::testing::ElementsAre;
There are a lot of advantages to that, including shorter code and documenting what helpers you are importing.
(The same goes for test utils, prefer using declarations instead of importing entire namespaces. Once you import an entire namespace, it can quickly become difficult to deduce which functions you are using or even which function is being called.)
| using ::ck_tile::builder::factory::internal::GetTensorLayout; | ||
| using ::ck_tile::builder::factory::internal::LayoutToCK; | ||
| using namespace ::ck_tile::builder::test; |
It's much better to have specific using declarations for symbols in a namespace instead of importing everything from a namespace.
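As a small illustration, using a made-up helper namespace (`demo::test_utils` and its functions are hypothetical and stand in for the real test utilities):

```cpp
#include <string>

// Hypothetical helper namespace standing in for the test utilities.
namespace demo::test_utils {

inline std::string describe(int channels)
{
    return "channels=" + std::to_string(channels);
}

inline int unrelated_helper() { return 0; }

} // namespace demo::test_utils

// Import only the symbol that is used; the declaration also documents
// where `describe` comes from.
using ::demo::test_utils::describe;

inline std::string summary()
{
    return describe(64); // unambiguous: exactly one `describe` is visible
}
```

With `using namespace ::demo::test_utils;` instead, `unrelated_helper` and anything added to the namespace later would also become visible at file scope.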
| using ::ck_tile::builder::factory::internal::LayoutToCK; | ||
| using namespace ::ck_tile::builder::test; | ||
| using enum ::ck_tile::builder::ConvDirection; |
It's much cleaner to only have using enum in a function scope (even if it's repeated many times) rather than have it at file scope.
| TEST(AuxiliaryTensorLayout, AssignsLayoutForG_C_strided) | ||
| { | ||
| using CKLayout = LayoutToCK<TensorLayout::G_C_strided>::type; | ||
| EXPECT_TRUE((std::is_same_v<CKLayout, ck::tensor_layout::convolution::G_C>)); |
Google Test has a compile-time assertion for type equality (StaticAssertTypeEq) that could be used here instead of EXPECT_TRUE with std::is_same_v.
https://google.github.io/googletest/gmock_cook_book.html#restricting-the-type-of-an-argument-or-parameter-in-an-action
shumway left a comment
This looks great! I added a lot of comments, but you don't have to implement them unless you end up needing to do more edits (no reason to run this through testing again). The feedback is more for incremental improvement and to help us converge on common best practices.
Proposed changes
Refactored the CK Builder convolution signature so that the data type, layout, and elementwise operation are defined per tensor. This refactoring allows us to build instances for complex fused operations such as scale-add-scale-add-relu. I added a forward convolution builder test that demonstrates how the complex ops benefit from the new signature design pattern. At a high level, the signature is composed as follows
Design Points:
The convolution traits and descriptors do not fully utilize the new structure, but all tests are passing. They could be refactored separately.