[CK_TILE][FMHA] Integrate FAv2 & FAv3 (WIP) in the single fmha_fwd() API #3153
base: develop
Conversation
…enc/composable_kernel into poyenc/integrate-fmha-fwd-v2-v3-apis
    if short_circuit:
        for rule in rules:
            if not rule(problem_ctx, kernel_ctx):
                return False
        return True
    return all(rule(problem_ctx, kernel_ctx) for rule in rules)
Is there any real difference between the short_circuit path and the all(rule(...)) path?
There should be no difference, since I didn't create a new list as a function argument. I'll remove the short_circuit path.
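To illustrate the point discussed above, here is a minimal standalone sketch (all names are hypothetical, not taken from the actual script) showing that the two paths behave identically: all() over a generator expression stops at the first falsy rule, exactly like the explicit loop, so no extra work is done either way.

```python
def passes_loop(rules, problem_ctx, kernel_ctx):
    # Explicit short-circuit path: stop at the first failing rule.
    for rule in rules:
        if not rule(problem_ctx, kernel_ctx):
            return False
    return True

def passes_all(rules, problem_ctx, kernel_ctx):
    # Equivalent one-liner: the generator argument makes all() lazy,
    # so rules after the first failure are never evaluated.
    return all(rule(problem_ctx, kernel_ctx) for rule in rules)

calls = []

def failing_rule(p, k):
    calls.append("fail")
    return False

def later_rule(p, k):
    calls.append("later")
    return True

rules = [failing_rule, later_rule]
assert passes_loop(rules, None, None) == passes_all(rules, None, None) == False
assert calls == ["fail", "fail"]  # later_rule was never evaluated by either path
```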
    is_v3_dedicated_tile = (
        kernel_ctx.tile.F_bm0 == 256
        and (kernel_ctx.tile.F_rm0 * kernel_ctx.tile.F_rn0 * kernel_ctx.tile.F_rk0) == 8
        and (kernel_ctx.tile.F_rm1 * kernel_ctx.tile.F_rn1 * kernel_ctx.tile.F_rk1) == 8
    )  # fmt: skip
    is_v3_pipeline = kernel_ctx.pipeline.tag == "qr_async_trload_v3"
    return is_v3_dedicated_tile == is_v3_pipeline
This is not a rule that restricts the problem_ctx/kernel_ctx pairing. Could it instead be handled by adding restrictions when constructing the kernel_ctx space?
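The suggestion above can be sketched as follows. This is a hypothetical illustration (the dict-based tiles, is_v3_tile, and kernel_space names are assumptions, not the script's actual structures): instead of a post-hoc rule that cross-checks tile shape against pipeline tag, the constraint is enforced while enumerating kernel contexts, so mismatched (tile, pipeline) pairs never enter the search space.

```python
from itertools import product

def is_v3_tile(tile):
    # v3-dedicated tiles per the rule above: bm0 == 256 and an 8-warp
    # layout in both gemm stages.
    return (tile["bm0"] == 256
            and tile["rm0"] * tile["rn0"] * tile["rk0"] == 8
            and tile["rm1"] * tile["rn1"] * tile["rk1"] == 8)

# Toy tile/pipeline spaces for illustration only.
tiles = [
    {"bm0": 256, "rm0": 2, "rn0": 2, "rk0": 2, "rm1": 2, "rn1": 2, "rk1": 2},
    {"bm0": 128, "rm0": 4, "rn0": 1, "rk0": 1, "rm1": 4, "rn1": 1, "rk1": 1},
]
pipelines = ["qr_async_trload_v3", "qr_async_trload"]

# Pair each tile only with its matching pipeline while building the space,
# so the cross-check rule becomes unnecessary.
kernel_space = [
    (tile, pipe)
    for tile, pipe in product(tiles, pipelines)
    if is_v3_tile(tile) == (pipe == "qr_async_trload_v3")
]
```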
    (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
    and kernel_ctx.tile.F_bm0 != 128
This restriction makes no sense! (bm0=64 should be usable with hdim values other than 128.)
This is pre-existing check logic. If you'd like to remove it, we can create another PR for that purpose.
    if (problem_ctx.hdim, problem_ctx.hdim_v) == (192, 128):
        if (
            kernel_ctx.pipeline.F_bias != "no"
            or kernel_ctx.pipeline.F_dropout == "t"
        ):
            return False
    return True
This rule makes no sense! Whether a pipeline can use bias or dropout should have nothing to do with the hdim sizes.
This is pre-existing check logic for the qr_async_trload pipeline. If you'd like to remove it, we can create another PR for that purpose.
    if not (
        (
            kernel_ctx.pipeline.F_logits == "t"
            and kernel_ctx.pipeline.F_bias == "no"
        )
        or kernel_ctx.pipeline.F_logits == "f"
    ):
Can this rule be solved inside the kernel_ctx space, since it does not involve problem_ctx?
That would be another type of check that only considers the kernel_ctx attributes. We can separate it if we later encounter more checks like this.
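The separation mentioned above could look roughly like this: kernel-only checks get their own rule list and run once per kernel_ctx, independent of any problem_ctx. This is a hypothetical sketch; the dict-based contexts and all function names are assumptions for illustration.

```python
def logits_rule(kernel_ctx):
    # Mirrors the condition above: logits may be enabled ("t") only
    # when bias is disabled; otherwise logits must be off ("f").
    return (kernel_ctx["F_logits"] == "t" and kernel_ctx["F_bias"] == "no") \
        or kernel_ctx["F_logits"] == "f"

# Rules in this list never receive a problem_ctx.
kernel_only_rules = [logits_rule]

def kernel_ctx_is_valid(kernel_ctx):
    return all(rule(kernel_ctx) for rule in kernel_only_rules)

candidates = [
    {"F_logits": "t", "F_bias": "no"},     # valid: logits without bias
    {"F_logits": "t", "F_bias": "alibi"},  # invalid: logits with bias
    {"F_logits": "f", "F_bias": "alibi"},  # valid: logits disabled
]
valid = [c for c in candidates if kernel_ctx_is_valid(c)]
```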
    template<>
    float fmha_fwd_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
    float fmha_fwd_<trait, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_args a)
Do we really need trait as a template parameter of fmha_fwd_?
Yes. In the current design, we use the trait as an instance key to differentiate the template instantiations.
Need to resolve the conflicts.
Proposed changes
This PR merges the two APIs—fmha_fwd() and fmha_fwd_v3()—into a single unified interface. The same script, fmha_fwd.py, is now used to generate two underlying implementation functions: fmha_fwd_v2() and fmha_fwd_v3(). The public API fmha_fwd() conditionally dispatches to fmha_fwd_v3(), although the fmha_fwd_v3() path is temporarily disabled (the full implementation is not ready to merge due to compiler issues).
In addition, I redesigned the code-generation logic to allow users to generate multiple dispatcher functions and organize pipelines using appropriate filters.
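The dispatch described above can be pictured with a rough Python analogy (the real API is C++; the gating predicate, argument shape, and ENABLE_V3 flag here are assumptions for illustration, not the PR's actual code): a single fmha_fwd() entry point routes to the v2 or v3 implementation, with the v3 path switched off for now.

```python
ENABLE_V3 = False  # the v3 path is temporarily disabled in this PR

def fmha_fwd_v2(args):
    return "v2"  # placeholder for the generated FAv2 implementation

def fmha_fwd_v3(args):
    return "v3"  # placeholder for the generated FAv3 implementation

def fmha_fwd(args):
    # Dispatch to v3 only when it is enabled and the problem matches a
    # v3-supported configuration (placeholder predicate).
    if ENABLE_V3 and args.get("hdim") == 128:
        return fmha_fwd_v3(args)
    return fmha_fwd_v2(args)

assert fmha_fwd({"hdim": 128}) == "v2"  # v3 disabled, so v2 handles everything
```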
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have run clang-format on all changed files

Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered