[CK Tile] Grouped GEMM aquant mode and non-persistent kernel #3337
base: develop
ThomasNing
left a comment
Thanks for the contribution. LGTM overall except the above comments.
static constexpr ck_tile::index_t K_Warp_Tile =
    get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();

static constexpr bool TransposeC = false;
We do not need to add this; the base TransposeC is already false.
Removed it
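The point about the base default can be illustrated with a small sketch. The structs below are hypothetical stand-ins, not the actual CK Tile definitions: a derived config that does not redeclare `TransposeC` simply picks up the value from its base.

```cpp
// Hypothetical stand-in for the shared base config (assumption, not CK Tile code).
struct GemmConfigBaseSketch
{
    static constexpr bool TransposeC = false; // default shared by all derived configs
};

// A derived config only needs to list what differs from the base.
struct GemmConfigDerivedSketch : public GemmConfigBaseSketch
{
    static constexpr int M_Warp_Tile = 32;
    // No TransposeC member here: lookup finds GemmConfigBaseSketch::TransposeC.
};

static_assert(GemmConfigDerivedSketch::TransposeC == false,
              "derived config inherits the base default");
```

Redeclaring the member in the derived struct would compile, but it only repeats the default and adds a second place to update if the default changes.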
};

template <typename PrecType>
struct GemmConfig_Aquant : public GemmConfigBase
We do not need to add a dedicated GemmConfig_Aquant. It can use GemmConfigComputeV3_2 directly.
It used to have different settings for Aquant, but indeed not needed anymore. Removed it.
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool PreshuffleB = false;
static constexpr bool Persistent = false;
Should we add a Persistent Gemm Config?
I made it into a template parameter and added it to the command-line parameters so that both can easily be used.
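A minimal sketch of the approach described in the reply above, assuming a config template with a boolean parameter and a runtime flag that selects between the two instantiations. All names here (`GemmConfigSketch`, `KernelMode`, `run_grouped_gemm`, `dispatch`) are hypothetical and do not come from the PR.

```cpp
// Hypothetical result type so the chosen variant can be inspected.
enum class KernelMode { NonPersistent, Persistent };

// Persistence is a compile-time template parameter of the config (assumption).
template <bool Persistent_>
struct GemmConfigSketch
{
    static constexpr bool Persistent = Persistent_;
};

// Stand-in for launching the kernel; it only reports which variant was built.
template <typename Config>
constexpr KernelMode run_grouped_gemm()
{
    return Config::Persistent ? KernelMode::Persistent : KernelMode::NonPersistent;
}

// A runtime flag (for example, parsed from the command line) picks one of the
// two compiled instantiations.
constexpr KernelMode dispatch(bool persistent)
{
    return persistent ? run_grouped_gemm<GemmConfigSketch<true>>()
                      : run_grouped_gemm<GemmConfigSketch<false>>();
}
```

Both variants are compiled, so the command-line flag only selects a branch at run time; the kernel itself still sees `Persistent` as a compile-time constant.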
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 32, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
We do not need the 32x32x32 variants for the 8-bit warp GEMM scenario.
These were instantiated with the previous config/pipeline. But indeed don't seem necessary anymore. I removed them.
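For readers unfamiliar with the dispatcher pattern in the quoted diff, here is a simplified sketch using hypothetical tag and warp-GEMM types (none of these names are CK Tile's). Each explicit specialization maps one combination of data types, tile shape, and transpose flag to a concrete type; a combination without a specialization fails to compile, which is why unused entries can be removed safely once nothing instantiates them.

```cpp
#include <type_traits>

// Hypothetical element-type tags and warp-GEMM types (assumptions for illustration).
struct fp8_tag {};
struct bf8_tag {};
struct WarpGemmSketch_32x32x16_bf8_fp8 {};
struct WarpGemmSketch_32x32x16_bf8_fp8_CTransposed {};

// Primary template is declared but not defined: unsupported combinations
// produce a compile error at the point of use.
template <typename A, typename B, typename Acc, int M, int N, int K, bool TransposeC>
struct DispatcherSketch;

template <>
struct DispatcherSketch<bf8_tag, fp8_tag, float, 32, 32, 16, false>
{
    using Type = WarpGemmSketch_32x32x16_bf8_fp8;
};

template <>
struct DispatcherSketch<bf8_tag, fp8_tag, float, 32, 32, 16, true>
{
    using Type = WarpGemmSketch_32x32x16_bf8_fp8_CTransposed;
};

// Convenience alias for looking up the selected type.
template <typename A, typename B, typename Acc, int M, int N, int K, bool TransposeC>
using WarpGemmFor = typename DispatcherSketch<A, B, Acc, M, N, K, TransposeC>::Type;
```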
struct GroupedGemKernelParam_Mfma
{
    // HACK: There's a bug in the AQuant pipeline that causes MRepeat > 1 to be incorrect
We already solved this problem. Please sync with develop :)
Thanks, I removed the workaround.
…ed gemm quant, and add support code to example
…line selection logic
…spose C) and non-persistent kernel
…=32 variants" This reverts commit b3fd4d3.
… add persistency as runtime parameter
6d8f465 to 8985b70
@ErwinTerpstra Thanks for the change. Please merge with develop so that I can kick off the official CI run.
Done!
Proposed changes
This closes internal ticket LWPCK-4126.
The grouped GEMM quantized example already uses a persistent kernel, but this is hard-coded and should be added to the `GemmConfig` options. Non-persistent kernel support should be added to the kernel, which currently contains a `static_assert` that requires the kernel to be configured as persistent.

The example also only does grouped B quantization. Support for A quantization should be added. Note that there currently seems to be no pipeline for A quantization with B preshuffle; if needed, that should be a follow-up issue.
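The difference between the persistent and non-persistent modes mentioned above can be sketched on the host side. This is a generic illustration of the two scheduling schemes, not the CK Tile implementation: it assumes a non-persistent launch uses one workgroup per output tile, while a persistent launch caps the grid at the number of resident workgroups and lets each one stride over the tiles. All function names are hypothetical.

```cpp
// Integer ceiling division.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Number of output tiles for one GEMM of size M x N with the given tile shape.
constexpr int tiles_in_group(int M, int N, int tile_m, int tile_n)
{
    return ceil_div(M, tile_m) * ceil_div(N, tile_n);
}

// Grid size under the assumed schemes:
//  - non-persistent: one workgroup per tile
//  - persistent: at most max_resident workgroups, each looping over tiles
constexpr int grid_size(int total_tiles, int max_resident, bool persistent)
{
    return persistent ? (total_tiles < max_resident ? total_tiles : max_resident)
                      : total_tiles;
}

// Tiles handled by workgroup block_id when it visits
// block_id, block_id + grid, block_id + 2 * grid, ... while below total_tiles.
constexpr int tiles_for_block(int block_id, int total_tiles, int grid)
{
    return block_id >= total_tiles ? 0 : ceil_div(total_tiles - block_id, grid);
}
```

With 10 tiles and room for 4 resident workgroups, the persistent scheme launches 4 workgroups that process 3, 3, 2, and 2 tiles, whereas the non-persistent scheme launches 10 workgroups with one tile each.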
Changes:
Note that there's still a problem in the AQuant pipeline with `MRepeat > 1` and `TransposeC == false` (there seems to be a bug using `ds_bpermute`). The example and tests now conditionally only use `MRepeat == 1` in those cases.

Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

`clang-format` on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered