Describe the bug
See `test_larger_comparison` in `test/attention/test_attn_weights.py`, specifically the `q_len=1` cases.
FlashInfer is currently skipped for these cases. If the skip is removed (in `get_variants`), the tests fail for FlashInfer, while all other variants agree with each other.
FlashInfer passes for all other `q_len` values. A likely cause is that the `q_len=1` case is implemented differently.
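
To help narrow this down, here is a minimal, dependency-free sketch of the invariant the `q_len=1` cases rely on: the attention weights for a single query must equal the last row of a causal multi-query computation over the same keys. The helper names (`attn_weights_single_query`, `attn_weights_causal`, `max_abs_diff`) are hypothetical and are not taken from this repository or from FlashInfer. The sketch assumes plain softmax attention for one head with the default `1/sqrt(head_dim)` scale, which may differ from what the test actually compares.

```python
import math


def attn_weights_single_query(query, keys, scale=None):
    """Softmax attention weights of one query vector over all keys (q_len=1)."""
    if scale is None:
        scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attn_weights_causal(queries, keys, scale=None):
    """Causal attention weights for q_len queries aligned to the end of keys.

    Query i sees the first kv_len - q_len + i + 1 keys; masked positions
    get weight 0.0.
    """
    kv_len, q_len = len(keys), len(queries)
    rows = []
    for i, query in enumerate(queries):
        visible = kv_len - q_len + i + 1
        weights = attn_weights_single_query(query, keys[:visible], scale)
        rows.append(weights + [0.0] * (kv_len - visible))
    return rows


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two equal-length lists."""
    return max(abs(x - y) for x, y in zip(a, b))


# Hypothetical toy data: 4 keys, 3 queries, head_dim = 2.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
queries = [[0.2, 0.1], [-0.3, 0.4], [0.7, -0.2]]

multi = attn_weights_causal(queries, keys)                # q_len = 3
single = attn_weights_single_query(queries[-1], keys)     # q_len = 1
diff = max_abs_diff(multi[-1], single)
```

If a backend's `q_len=1` output disagrees with the last row of its own `q_len>1` output on the same inputs, that would point at the single-query code path rather than at the comparison logic in the test.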