[GPU] xattention_block_size 256 support. #33485
base: master
Conversation
…ed with float precision to avoid an overflow of zp.
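A hedged sketch of the idea behind that commit note (not the kernel's actual code, and the min/max-based formula is an assumption): keeping the zero-point computation in float so the intermediate value cannot overflow a narrow integer type before the final clamp.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative asymmetric-quantization zero point computed in float.
int8_t compute_zero_point(float min_val, float max_val) {
    const float qmin = -128.0f, qmax = 127.0f;
    const float scale = (max_val - min_val) / (qmax - qmin);
    if (scale == 0.0f) return 0;  // degenerate range, nothing to quantize
    // Intermediate stays in float; rounding and clamping happen only at the end.
    const float zp = qmin - min_val / scale;
    return static_cast<int8_t>(std::lround(std::min(std::max(zp, qmin), qmax)));
}
```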
riverlijunjie left a comment
Some minor comments; overall LGTM.
svmptr_t sparse_mask_base [[type("svmptr_t")]],
svmptr_t wg_sparse_mask_base [[type("svmptr_t")]],
bool validate,
int SPARSE_BLOCK_SIZE,
Do we really need this parameter if it is a macro?
Yes, we do need it. SPARSE_BLOCK_SIZE is now a runtime parameter for the PA kernel, instead of a compile-time JIT constant.
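An illustrative sketch of that distinction (the GPU plugin has its own kernel-argument abstraction; this uses the raw OpenCL host API purely to contrast the two approaches):

```cpp
#include <CL/cl.h>
#include <string>

// Compile-time JIT constant: the value is baked into the program source, so a
// different block size would require building a separate program.
std::string jit_options_for(int block_size) {
    return "-DSPARSE_BLOCK_SIZE=" + std::to_string(block_size);
}

// Runtime parameter: the same compiled kernel takes the block size as a scalar
// argument, so it can change between enqueues without a rebuild.
void set_runtime_block_size(cl_kernel kernel, cl_uint arg_index, int block_size) {
    clSetKernelArg(kernel, arg_index, sizeof(int), &block_size);
}
```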
res_event = {execute_stage(res_event, instance, xattn_estimate_find_block)};
res_event = {execute_stage(res_event, instance, xattn_estimate_post_proc)};
if (!bypass_xattn(params)) {
    if (rt_params->xattn_block_size == 128) {
If xattn_block_size is a fixed value, we don't need to add_stage for both 128 and 256.
Unfortunately, xattn_block_size is a compile-time JIT constant for the xattention kernels, while it is also a runtime parameter of a model with a PA node. This means users can switch it dynamically during inference, so this PR has to create two stages (one for 128, the other for 256) and switch between them on the fly, as sketched below.
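A hedged sketch of that dispatch (stage identifiers such as xattn_estimate_find_block_256 are illustrative, not necessarily the PR's exact names): both block-size variants are compiled as separate stages up front, and the runtime value selects which one to execute.

```cpp
// Stand-in types for the plugin's event/instance objects.
struct Event {};
struct Instance {};

enum StageId { xattn_estimate_find_block_128, xattn_estimate_find_block_256 };

struct RuntimeParams { int xattn_block_size = 128; };

Event execute_stage(Event ev, Instance&, StageId) { return ev; }  // placeholder

Event run_find_block(Event ev, Instance& inst, const RuntimeParams& rt) {
    // The block size is a JIT constant inside the kernel, so each value needs
    // its own pre-built stage; the choice between them happens here at run time.
    const StageId stage = (rt.xattn_block_size == 256) ? xattn_estimate_find_block_256
                                                       : xattn_estimate_find_block_128;
    return execute_stage(ev, inst, stage);
}
```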
Details:
Tickets: