[CK_BUILDER] Refactor convolution signature to provide data type/layout/elementwise op per tensor #3331
Open: vpietila-amd wants to merge 44 commits into develop from vpietila/ckb-improve-elementwise-ops
+1,726
−612
Commits
bbe5426 Separate layouts into separate entities for input, weight, and output…
6d7fdf6 Add test for handling bias tensor layouts.
626610f Use instance string in builder tests.
edd180a Add handling of output bias data types and layouts.
0bd50a5 Generalize handling of the elementwise ops.
c7c814f Test fix.
809c8b4 Create builder for layouts.
f1cff54 Layout builder improvements.
3173d94 Improve layout builder.
d358895 Simplify bias layout handling.
69bfe64 Code clean-up.
aba3eea Move layout utils into separate file.
f00ac4e Remove hard-coded layout combinations.
74bc17a Small code clean-up.
54c58f1 Move data type utils into a separate file.
d6fc6c8 Add data types, layouts, and elementwise ops per conv tensor.
05a4067 Builder bug fixes after refactoring.
c25eb65 Working baseline.
a4252c4 Make signature definition look nice in the test code.
7ed339f Move TensorConfig into test implementations.
beac0e8 Fix all fwd conv builder tests.
05c46ba Fix conv traits and descriptors tests.
67a607d Merge remote-tracking branch 'origin/develop' into vpietila/ckb-impro…
bab26ee More factory assets under a separate directory.
8a9e22c Fix building conv traits.
979a851 Fix clang-format.
e995fcf Add Readme doc to describe the design.
a468fce Add link to main Readme. Fix links in the builder design doc.
95fbe2d Clean-up data type/layout/elementwise op conversions.
6850b9b Switch from dimension and tensor type specific layouts to a flat list…
c27ea34 Fix clang-formatting.
ea080e6 Fix clang-format for test code.
b82f502 Simplify fwd conv signature definitions in the test code.
915e6ca Remove accidental edits.
41d1bfd Fix comment string.
b5bec0d Merge remote-tracking branch 'origin/develop' into vpietila/ckb-impro…
91a10e4 Fix instance factory after rebase.
5777d8b Fix tests after rebase.
3949dc3 Unify layout handling.
3a9bac5 Add more conv layout unit tests.
b847ca5 Clang-format.
ba74e1a Merge remote-tracking branch 'origin/develop' into vpietila/ckb-impro…
aed8730 Fix merge conflicts.
cd37f4e Improve elementwise op handling.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,244 @@ | ||
| # Composable Kernel Builder Design Documentation | ||
|
|
||
| This directory contains the builder framework for Composable Kernel, which provides a compile-time, type-safe interface for constructing convolution operations with various configurations. | ||
|
|
||
| ## Table of Contents | ||
|
|
||
| - [Convolution Signature Design](#convolution-signature-design) | ||
| - [Overview](#overview) | ||
| - [Architecture](#architecture) | ||
| - [Core Components](#core-components) | ||
| - [Concepts and Validation](#concepts-and-validation) | ||
| --- | ||
|
|
||
| ## Convolution Signature Design | ||
|
|
||
| ### Overview | ||
|
|
||
| The convolution signature system provides a **compile-time description** of grouped convolution operations. A signature is a collection of properties that fully characterize a convolution kernel's mathematical and operational behavior, enabling: | ||
|
|
||
| - **Compile-time validation**: Ensures type safety and correctness before kernel instantiation | ||
| - **Kernel selection**: Matches user requirements to optimized implementations | ||
| - **Specialization**: Enables optimized code paths for specific configurations | ||
| - **Composability**: Supports building complex operations from simpler components | ||
|
|
||
| The signature leverages modern C++20 features, particularly **concepts**, to provide expressive, self-documenting interfaces with compile-time guarantees. | ||
|
|
||
| ### Architecture | ||
|
|
||
| The signature system is organized into a hierarchical structure: | ||
|
|
||
| ``` | ||
| ┌─────────────────────────────────────────────────────────┐ | ||
| │ ConvSignature │ | ||
| ├─────────────────────────────────────────────────────────┤ | ||
| │ Properties: │ | ||
| │ • spatial_dim: int (1D, 2D, or 3D) │ | ||
| │ • direction: ConvDirection (Fwd/BwdData/BwdWeight) │ | ||
| │ • data_type: DataType (default data type) │ | ||
| │ • accumulation_data_type: DataType │ | ||
| │ • input: ConvTensor ──┐ │ | ||
| │ • weight: ConvTensor ──│ │ | ||
| │ • output: ConvTensor ──│ │ | ||
| └──────────────────────────────────┼──────────────────────┘ | ||
| │ | ||
| ▼ | ||
| ┌─────────────────────────────────────────┐ | ||
| │ ConvTensor │ | ||
| ├─────────────────────────────────────────┤ | ||
| │ ╔═════════════════════════════════════╗ │ | ||
| │ ║ TensorConfig (required) ║ │ | ||
|
Collaborator (review comment): It's not clear to me why we need the TensorConfig wrapper instead of just directly having a layout, datatype, and compute_type for the tensor. |
||
| │ ╠═════════════════════════════════════╣ │ | ||
| │ ║ • layout: ConvLayout ║ │ | ||
| │ ║ • data_type: DataType (optional) ║ │ | ||
| │ ║ • compute_type: DataType (optional)║ │ | ||
| │ ╚═════════════════════════════════════╝ │ | ||
| │ │ | ||
| │ ┌─────────────────────────────────────┐ │ | ||
| │ │ TensorOperation (optional) │ │ | ||
| │ ├─────────────────────────────────────┤ │ | ||
| │ │ • elementwise_operation │ │ | ||
| │ │ • auxiliary_operand_configs[] │ │ | ||
| │ │ (each is also ConvTensor) ◄───────┼─┐ | ||
| │ └─────────────────────────────────────┘ │ │ | ||
| └─────────────────────────────────────────┘ │ | ||
| │ | ||
| Recursive ───────────────┘ | ||
| ``` | ||
| Key Design Points: | ||
| - ConvSignature contains three ConvTensor instances (input, weight, output) | ||
| - All tensors share the same ConvTensor structure | ||
| - Each ConvTensor has: | ||
| - TensorConfig (required): Defines layout as well as optional data and compute type overrides | ||
| - TensorOperation (optional): Defines fused elementwise operations | ||
| - Auxiliary operands (e.g., bias) in TensorOperation also use the ConvTensor type | ||
|
|
||
| ### Core Components | ||
|
|
||
| #### 1. Signature Level | ||
|
|
||
| The top-level signature contains global properties that apply to the entire convolution operation: | ||
|
|
||
| ```cpp | ||
| template <typename T> | ||
| concept ConvSignatureDescriptor = requires(T t) { | ||
| { t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3 | ||
| { t.data_type } -> std::convertible_to<DataType>; // Default data type | ||
| { t.input } -> ConvTensorDescriptor; | ||
| { t.weight } -> ConvTensorDescriptor; | ||
| { t.output } -> ConvTensorDescriptor; | ||
| requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction | ||
| }; | ||
| ``` | ||
|
|
||
| **Properties:** | ||
| - **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D) | ||
| - **`direction`**: Operation type (optional, defaults to FORWARD) | ||
| - `FORWARD`: Standard forward convolution | ||
| - `BACKWARD_DATA`: Gradient computation w.r.t. input | ||
| - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights | ||
| - **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8) | ||
| - **`accumulation_data_type`**: Type used for internal accumulation | ||
|
|
||
| #### 2. Tensor Level | ||
|
|
||
| Each tensor (input, weight, output) has its own descriptor: | ||
|
|
||
| ```cpp | ||
| template <typename T> | ||
| concept ConvTensorDescriptor = requires(T t) { | ||
| { t.config } -> TensorConfigDescriptor; | ||
| requires ElementwiseOpWellDefinedIfProvided<T>; | ||
| }; | ||
| ``` | ||
|
|
||
| A tensor descriptor encapsulates: | ||
| - **Configuration**: Layout and data type information | ||
| - **Operation** (optional): Fused elementwise operations on this tensor | ||
|
|
||
| #### 3. Tensor Configuration | ||
|
|
||
| Describes the memory layout and data types: | ||
|
|
||
| ```cpp | ||
| template <typename T> | ||
| concept TensorConfigDescriptor = requires(T t) { | ||
| { t.layout } -> std::convertible_to<ConvLayout>; | ||
| { t.data_type } -> std::convertible_to<DataType>; // Optional override | ||
| }; | ||
| ``` | ||
|
|
||
| **Layout Types** (dimension-specific): | ||
| - **1D Convolution**: | ||
| - Input: `GNCW`, `GNWC`, `NWGC`, `NGCW`, `G_NW_C_strided` | ||
| - Weight: `GKXC`, `GKCX`, `KXGC`, `G_K_X_C_strided` | ||
| - Output: `GNKW`, `GNWK`, `NWGK`, `NGKW`, `G_NW_K_strided` | ||
|
|
||
| - **2D Convolution**: | ||
| - Input: `GNCHW`, `GNHWC`, `NHWGC`, `NGCHW`, `G_NHW_C_strided` | ||
| - Weight: `GKYXC`, `GKCYX`, `KYXGC`, `G_K_YX_C_strided` | ||
| - Output: `GNKHW`, `GNHWK`, `NHWGK`, `NGKHW`, `G_NHW_K_strided` | ||
|
|
||
| - **3D Convolution**: | ||
| - Input: `GNCDHW`, `GNDHWC`, `NDHWGC`, `NGCDHW`, `G_NDHW_C_strided` | ||
| - Weight: `GKZYXC`, `GKCZYX`, `KZYXGC`, `G_K_ZYX_C_strided` | ||
| - Output: `GNKDHW`, `GNDHWK`, `NDHWGK`, `NGKDHW`, `G_NDHW_K_strided` | ||
|
|
||
| Where: | ||
| - `G` = Groups | ||
| - `N` = Batch size | ||
| - `C` = Input channels | ||
| - `K` = Output channels (filters) | ||
| - `W`, `H`, `D` = Width, Height, Depth (spatial dimensions) | ||
| - `X`, `Y`, `Z` = Filter width, height, depth (filter spatial dimensions) | ||
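
A layout name can be read as the dimension order in memory, from slowest-varying to fastest-varying. The sketch below works this out for two 2D input layouts, assuming the packed (non-strided) variants. `Dims2d`, `offset_nhwgc`, and `offset_gnhwc` are hypothetical helpers written for this example and are not part of the builder.

```cpp
#include <cstddef>

// Hypothetical helper type: extents of a 2D grouped-convolution input tensor.
struct Dims2d
{
    std::size_t G, N, C, H, W;
};

// Linear offset of element (n, h, w, g, c) in a packed NHWGC tensor.
// C is the fastest-varying dimension, N the slowest.
constexpr std::size_t offset_nhwgc(const Dims2d& d, std::size_t n, std::size_t h,
                                   std::size_t w, std::size_t g, std::size_t c)
{
    return (((n * d.H + h) * d.W + w) * d.G + g) * d.C + c;
}

// Same element in a packed GNHWC tensor: G moves to the slowest-varying position.
constexpr std::size_t offset_gnhwc(const Dims2d& d, std::size_t n, std::size_t h,
                                   std::size_t w, std::size_t g, std::size_t c)
{
    return (((g * d.N + n) * d.H + h) * d.W + w) * d.C + c;
}

constexpr Dims2d dims{.G = 2, .N = 1, .C = 4, .H = 3, .W = 3};

// The same logical element lands at different offsets in the two layouts.
static_assert(offset_nhwgc(dims, 0, 1, 2, 1, 3) == 47);
static_assert(offset_gnhwc(dims, 0, 1, 2, 1, 3) == 59);
```

The `*_strided` layouts cannot be described this way because their offsets depend on explicit strides rather than on the extents alone.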
|
|
||
| #### 4. Tensor Operations | ||
|
|
||
| Describes fused elementwise operations applied to a tensor: | ||
|
|
||
| ```cpp | ||
| template <typename T> | ||
| concept TensorOperatorDescriptor = requires(T t) { | ||
| { t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>; | ||
| requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>; | ||
| }; | ||
| ``` | ||
|
|
||
| **Supported Operations:** | ||
| - `PASS_THROUGH`: No operation (identity) | ||
| - `SCALE`: Multiply by a scalar | ||
| - `CLAMP`: Clamp values to a range | ||
| - `BIAS_BNORM_CLAMP`: Bias addition + batch normalization + clamp | ||
| - `SCALEADD_SCALEADD_RELU`: Fused scale-add operations + ReLU activation | ||
|
|
||
| **Auxiliary Operands:** | ||
| Some operations require additional tensor inputs (e.g., bias tensors, scaling factors). These are specified through `auxiliary_operand_configs`, which is an array of `TensorConfigDescriptor` objects describing the layout and data type of each auxiliary input. | ||
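
As a rough illustration of what the simpler operations mean per element, the sketch below models `PASS_THROUGH`, `SCALE`, and `CLAMP` as a compile-time dispatch. The reduced enum, the `OpParams` struct, and `apply_op` are assumptions made for this example; the actual kernels implement these operations differently, and the fused operations that take auxiliary operands are not modeled here.

```cpp
#include <algorithm>

// Reduced, hypothetical copy of the operation enum (only three entries modeled).
enum class ElementwiseOperation { PASS_THROUGH, SCALE, CLAMP };

// Hypothetical parameter bundle for the example.
struct OpParams
{
    float scale = 1.0f; // used by SCALE
    float lo    = 0.0f; // lower bound for CLAMP
    float hi    = 0.0f; // upper bound for CLAMP
};

// Applies the selected operation to one element; the branch is chosen at compile time.
template <ElementwiseOperation Op>
constexpr float apply_op(float x, const OpParams& p)
{
    if constexpr(Op == ElementwiseOperation::PASS_THROUGH)
        return x;
    else if constexpr(Op == ElementwiseOperation::SCALE)
        return p.scale * x;
    else
        return std::clamp(x, p.lo, p.hi);
}

static_assert(apply_op<ElementwiseOperation::PASS_THROUGH>(2.0f, OpParams{}) == 2.0f);
static_assert(apply_op<ElementwiseOperation::SCALE>(2.0f, OpParams{.scale = 3.0f}) == 6.0f);
static_assert(apply_op<ElementwiseOperation::CLAMP>(5.0f, OpParams{.lo = 0.0f, .hi = 1.0f}) == 1.0f);
```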
|
|
||
| ### Concepts and Validation | ||
|
|
||
| The signature system uses C++20 concepts for compile-time validation at multiple levels: | ||
|
|
||
| #### Constraint Concepts | ||
|
|
||
| ```cpp | ||
| // Spatial dimension must be 1, 2, or 3 | ||
| template <auto N> | ||
| concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3); | ||
|
|
||
| // Valid data types for convolution | ||
| template <DataType T> | ||
| concept ValidConvDataType = | ||
| (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) || | ||
| (T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8); | ||
| ``` | ||
|
|
||
| #### Validation Concept | ||
|
|
||
| ```cpp | ||
| // Validates a complete signature | ||
| template <auto Sig> | ||
| concept ValidConvSignature = requires { | ||
| requires ConvSpatialDim<Sig.spatial_dim>; | ||
| requires ValidConvDataType<Sig.data_type>; | ||
| }; | ||
| ``` | ||
|
|
||
| #### Tensor Descriptors | ||
|
|
||
| The layout, data type, and elementwise operation are described per tensor. This multi-level hierarchy allows: | ||
| - **Flexibility**: Each tensor can have independent layout and data type | ||
| - **Reusability**: Common configurations can be shared across different signatures | ||
| - **Extensibility**: New properties can be added to specific levels without affecting others | ||
| - **Clarity**: Separates concerns (global properties vs. tensor-specific properties) | ||
|
|
||
| #### Optional Signature Fields | ||
|
|
||
| Several fields in the signature are optional: | ||
| - **`direction`**: Defaults to `FORWARD` if not specified, reducing boilerplate for the common case | ||
| - **Tensor `data_type`**: Falls back to the signature's default `data_type`, allowing mixed precision with minimal specification | ||
| - **Tensor `operation`**: Defaults to `PASS_THROUGH`, supporting both fused and non-fused operations with the same interface | ||
|
|
||
| This design follows the principle of "make the common case simple, the complex case possible." | ||
|
|
||
| #### Union-Based Layout Representation | ||
|
|
||
| The `ConvLayout` type uses unions to support dimension-agnostic code: | ||
|
|
||
| ```cpp | ||
| struct ConvLayout { | ||
| union { | ||
| ConvInputLayout _input_layout; | ||
| ConvWeightLayout _weight_layout; | ||
| ConvOutputLayout _output_layout; | ||
| ConvAuxiliaryTensorLayout _aux_tensor_layout; | ||
| }; | ||
| // ... constructors for each type | ||
| }; | ||
| ``` | ||
| This allows: | ||
| - Single type to represent all layout variants | ||
| - Type-safe construction through overloaded constructors | ||
| - Compile-time enforcement of valid combinations through concepts | ||
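
The constructor-per-type idea can be sketched as follows. This is a simplified model under stated assumptions: the enums list only two layouts each, the type is named `ConvLayoutSketch` to avoid confusion with the real `ConvLayout`, and the `Kind` discriminator is an addition made for this example (the text does not say whether the real type tracks the active member).

```cpp
// Reduced, hypothetical layout enums (the real lists are longer).
enum class ConvInputLayout { GNHWC, NHWGC };
enum class ConvWeightLayout { GKYXC, KYXGC };
enum class ConvOutputLayout { GNHWK, NHWGK };

struct ConvLayoutSketch
{
    // Hypothetical discriminator recording which union member is active.
    enum class Kind { Input, Weight, Output };

    Kind kind;
    union
    {
        ConvInputLayout input_layout;
        ConvWeightLayout weight_layout;
        ConvOutputLayout output_layout;
    };

    // One constexpr constructor per layout family; overload resolution picks
    // the union member, so a mismatched enum cannot be stored by accident.
    constexpr ConvLayoutSketch(ConvInputLayout l) : kind(Kind::Input), input_layout(l) {}
    constexpr ConvLayoutSketch(ConvWeightLayout l) : kind(Kind::Weight), weight_layout(l) {}
    constexpr ConvLayoutSketch(ConvOutputLayout l) : kind(Kind::Output), output_layout(l) {}
};

static_assert(ConvLayoutSketch{ConvInputLayout::NHWGC}.kind == ConvLayoutSketch::Kind::Input);
static_assert(ConvLayoutSketch{ConvWeightLayout::GKYXC}.weight_layout == ConvWeightLayout::GKYXC);
static_assert(ConvLayoutSketch{ConvOutputLayout::NHWGK}.kind == ConvLayoutSketch::Kind::Output);
```

Reading a union member other than the one that was initialized is not allowed in a constant expression, so compile-time misuse of the wrong member is diagnosed by the compiler.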
| --- | ||
This is really nice documentation! Thanks for taking time to create this!