feat: add xAttention support for Qwen3 generative recommendation. #586
base: main
Conversation
I think you should inherit from BatchInputBuilder rather than create a new standalone class.
std::vector<SequenceOutput>& outputs,
const Tokenizer& tokenizer,
const Sequence& base) {
  VLOG(1) << "[debug] generate_multi_round_output, bw: "
What is VLOG(1)? Why not use LOG?
#include "acl/acl.h"
#include "aclnn_beam_search_group.h"
#include "acltensor_utils.h"
#include "util/tensor_helper.h"
There are many headers included repetitively across the .h and .cpp files in the xllm_ops dir.
#include "aclnn_select_unshared_kv.h"
#include "cache_select.h"

#define CHECK_ACL_SUCCESS(expr, msg) \
Move this macro to acltensor_utils.h.
#include <tuple>

#include "kernels/npu/xllm_ops/beam_search_group.h"
#include "kernels/npu/xllm_ops/cache_select.h"
You need to add an xllm_ops_api.h, like mlu_ops_api.h and cuda_ops_api.h.
Force-pushed from cc0f151 to b0f3758 (compare)
/gemini review
Code Review
This pull request introduces support for xAttention in Qwen3 models for generative recommendation scenarios. The changes are extensive, adding new scheduling logic, KV cache management, and adapting the ACL Graph workflow for multi-step decoding. The introduction of QwenRecWorkerImpl and AclGraphRecExecutorImpl effectively isolates the new functionality. However, I have identified two significant issues. There's a high-severity issue concerning potentially excessive memory allocation for the shared KV cache, and a critical issue where a KV cache capacity check is bypassed, which could lead to memory corruption. These should be addressed to ensure the stability and efficiency of the new feature.
void KVCacheState::incr_kv_cache_tokens_num(size_t num) {
  if (FLAGS_max_decode_rounds > 0) {
    kv_cache_tokens_num_ += num;
    return;
  }
  CHECK(kv_cache_tokens_num_ + num <= current_max_tokens_capacity());
  kv_cache_tokens_num_ += num;
}
Bypassing the KV cache capacity check when FLAGS_max_decode_rounds > 0 is dangerous. This removes a critical safety measure that prevents out-of-bounds writes to the KV cache. Even if a different memory management scheme is used for multi-round decoding, there is still a capacity limit for any allocated buffer. Without this check, it's possible for kv_cache_tokens_num_ to exceed the allocated capacity, which could lead to memory corruption and unpredictable behavior. This check should be restored or replaced with an equivalent check appropriate for the new memory scheme.
void KVCacheState::incr_kv_cache_tokens_num(size_t num) {
  CHECK(kv_cache_tokens_num_ + num <= current_max_tokens_capacity());
  kv_cache_tokens_num_ += num;
}

raw_forward_input.shared_kv_shape = {
    batch_size * FLAGS_max_token_per_req, n_kv_heads, head_dim};
The calculation for shared_kv_shape uses FLAGS_max_token_per_req as a multiplier for the batch size. This can lead to significant over-allocation of memory for the shared KV cache, especially if FLAGS_max_token_per_req is set to a large value to accommodate a few long requests, while the average request length is much shorter. This could result in unnecessary memory pressure and potential out-of-memory errors. As noted in the TODO comment on lines 339-342, it would be much safer and more memory-efficient to calculate the required size based on the actual token needs of the sequences in the current batch.