@danielclough (Contributor)

Summary

Adds KV cache management and fixes a critical causal mask bug affecting Qwen2 multi-turn inference. Also includes numerical precision improvements for RoPE and attention.

Changes

  • Causal mask bug fix: Corrects the mask shape for cached decoding (was [tgt, tgt], now [tgt, total]) - critical for multi-turn conversations (see the sketch after this list)
  • Precision improvements: RoPE and softmax now use F32 intermediates to match PyTorch behavior
  • KV cache API: Adds extract_kv_cache/restore_kv_cache methods for cache manipulation and inspection
  • Selective attention: New prepare_4d_causal_attention_mask_with_cache_position for non-contiguous cache positions
  • Embedding injection: forward_from_embeds methods enable custom embedding workflows (e.g., multimodal)
  • Stability fix: Replaces NEG_INFINITY with f32::MIN to avoid NaN propagation when combining masks
  • Cache manipulation: Adds shift_kv_cache_first_to_last for advanced patterns (e.g., negative prompt refresh)
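
To make the mask shape fix concrete, here is a minimal sketch in plain Rust (not the candle Tensor API; `causal_mask` is just an illustrative helper, and the real code builds a tensor rather than nested Vecs). It shows how the corrected mask covers both the cached positions and the causal prefix of the new tokens, using f32::MIN for masked entries as in the stability fix:

```rust
// Sketch only: the mask for cached decoding has shape [tgt_len, cache_len + tgt_len].
// Row i (the i-th new token) may attend to every cached position and to new
// positions up to and including itself; everything else is masked with f32::MIN.
fn causal_mask(tgt_len: usize, cache_len: usize) -> Vec<Vec<f32>> {
    let total_len = cache_len + tgt_len;
    (0..tgt_len)
        .map(|i| {
            (0..total_len)
                .map(|j| if j <= cache_len + i { 0.0 } else { f32::MIN })
                .collect()
        })
        .collect()
}

fn main() {
    // Prefill: no cache yet, 4 new tokens -> the familiar [4, 4] lower-triangular mask.
    let prefill = causal_mask(4, 0);
    assert_eq!((prefill.len(), prefill[0].len()), (4, 4));

    // Cached decode: 4 tokens already in the cache, 1 new token.
    // The buggy path built a [1, 1] mask here; the fix builds [1, 5] so the
    // new token attends to the whole cache plus itself.
    let decode = causal_mask(1, 4);
    assert_eq!((decode.len(), decode[0].len()), (1, 5));
    assert!(decode[0].iter().all(|&v| v == 0.0)); // last row masks nothing
}
```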

Motivation

The causal mask bug prevented proper multi-turn decoding with KV cache. The new cache management APIs enable advanced inference patterns like streaming audio generation (VibeVoice) and speculative decoding while maintaining precision for F16/BF16 inference.
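
As a rough illustration of why extract/restore matters for a pattern like speculative decoding, the toy sketch below shows snapshot-and-rollback of a KV cache. The `TinyKvCache`/`KvSnapshot` types and their methods are hypothetical stand-ins, not the API added by this PR; the real methods live on the Qwen2 model structs and operate on candle Tensors.

```rust
// Toy stand-in types to illustrate the snapshot/restore pattern only.
#[derive(Clone)]
struct KvSnapshot {
    keys: Vec<Vec<f32>>,   // one entry per layer
    values: Vec<Vec<f32>>,
}

struct TinyKvCache {
    keys: Vec<Vec<f32>>,
    values: Vec<Vec<f32>>,
}

impl TinyKvCache {
    fn extract(&self) -> KvSnapshot {
        KvSnapshot { keys: self.keys.clone(), values: self.values.clone() }
    }
    fn restore(&mut self, snap: KvSnapshot) {
        self.keys = snap.keys;
        self.values = snap.values;
    }
    fn append(&mut self, layer: usize, k: f32, v: f32) {
        self.keys[layer].push(k);
        self.values[layer].push(v);
    }
}

fn main() {
    let mut cache = TinyKvCache { keys: vec![vec![]; 2], values: vec![vec![]; 2] };
    cache.append(0, 1.0, 1.0); // a verified token

    // Speculative decoding: snapshot before drafting, roll back on rejection.
    let checkpoint = cache.extract();
    cache.append(0, 9.0, 9.0); // draft token that the verifier rejects
    cache.restore(checkpoint);
    assert_eq!(cache.keys[0].len(), 1); // cache is back to the verified state
}
```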

Breaking Changes

None. All changes are backward-compatible additions or bug fixes.

✅ Validation

Routine

cargo fmt --all
cargo test -p candle-transformers
cargo clippy -p candle-transformers

Test Qwen2 Example

Simple Query

cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. ." --model "2-1.5b"

Short Prompt

Test with a very short prompt to ensure single-token decode works:

cargo run --example qwen --features metal --release -- --prompt "Hi" --sample-len 10 --model "2-1.5b"
