//! Shared helpers for PqoCache and TqCache implementations.

use candle_core::{DType, Result, Tensor};
use mistralrs_kv_cache::DequantResult;

use super::cache_err;
use super::config::{CacheConfig, QuantNormMode};
use super::precomputed::GpuPrecomputed;
use super::quantize_tensor::{polar_dequantize, polar_quantize, QuantConfig};
use super::storage::CompressedStorage;

/// Dequantize the full compressed cache for a layer.
///
/// Shared implementation used by both `PqoCache` and `TqCache`.
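///
/// A hypothetical call site, sketched for illustration only (the `self`
/// fields here are assumed names, not part of this crate):
///
/// ```ignore
/// let qc = make_quant_config(&self.precomputed, &self.config)?;
/// let (k, v) = dequantize_full_impl(&self.storage, &qc, layer, self.orig_dtype)?;
/// ```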
// qual:allow(TQ-003) — tested via cache_pqo_tests + cache_storage_tests integration tests
pub(crate) fn dequantize_full_impl(
    storage: &CompressedStorage,
    config: &QuantConfig<'_>,
    layer: usize,
    orig_dtype: DType,
) -> Result<(Tensor, Tensor)> {
    let total_seq = storage.seq_len(layer);
    let head_dim = storage.head_dim;
    let num_kv_heads = storage.num_kv_heads;
    let packed_dim = storage.packed_dim();
    let num_blocks = storage.num_blocks();

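    // Fetch the quantized index/scale buffers for this layer; they are only
    // present once the layer has been written to.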
    let ki = storage
        .k_indices(layer)
        .ok_or_else(|| cache_err("k_indices not initialized"))?;
    let ks = storage
        .k_scales(layer)
        .ok_or_else(|| cache_err("k_scales not initialized"))?;
    let vi = storage
        .v_indices(layer)
        .ok_or_else(|| cache_err("v_indices not initialized"))?;
    let vs = storage
        .v_scales(layer)
        .ok_or_else(|| cache_err("v_scales not initialized"))?;

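    // Trim each buffer to the live `total_seq` positions along the sequence
    // axis, then flatten `[heads, seq, ...]` to `[heads * seq, ...]` so the
    // dequantizer sees one row per (head, position) pair.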
    let all_ki = ki
        .narrow(1, 0, total_seq)?
        .reshape((num_kv_heads * total_seq, packed_dim))?;
    let all_ks = ks
        .narrow(1, 0, total_seq)?
        .reshape((num_kv_heads * total_seq, num_blocks))?;
    let all_vi = vi
        .narrow(1, 0, total_seq)?
        .reshape((num_kv_heads * total_seq, packed_dim))?;
    let all_vs = vs
        .narrow(1, 0, total_seq)?
        .reshape((num_kv_heads * total_seq, num_blocks))?;

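    // Dequantize each flat buffer, restore the `[1, heads, seq, dim]` layout
    // expected by attention, and cast back to the caller's original dtype.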
    let full_k = polar_dequantize(&all_ki, &all_ks, config)?
        .reshape((1, num_kv_heads, total_seq, head_dim))?
        .to_dtype(orig_dtype)?;
    let full_v = polar_dequantize(&all_vi, &all_vs, config)?
        .reshape((1, num_kv_heads, total_seq, head_dim))?
        .to_dtype(orig_dtype)?;

    Ok((full_k, full_v))
}

/// Build a [`QuantConfig`] from precomputed tensors and cache configuration.
pub(crate) fn make_quant_config<'a>(
    precomputed: &'a Option<GpuPrecomputed>,
    config: &CacheConfig,
) -> Result<QuantConfig<'a>> {
    let pre = precomputed
        .as_ref()
        .ok_or_else(|| cache_err("precomputed not initialized"))?;
    Ok(QuantConfig {
        head_dim: config.head_dim,
        bits: config.bits,
        outlier_blocks: config.outlier_blocks,
        pre,
    })
}

/// Flatten K/V tensors from `[1, heads, seq, dim]` to `[heads*seq, dim]` as f32.
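///
/// For example, with `num_kv_heads = 8`, `head_dim = 64`, and a 16-token
/// chunk, a `[1, 8, 16, 64]` input flattens to a `[128, 64]` f32 tensor.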
pub(crate) fn flatten_kv(
    k: &Tensor,
    v: &Tensor,
    num_kv_heads: usize,
    head_dim: usize,
) -> Result<(Tensor, Tensor)> {
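    // The incoming chunk length is dim 2 of the `[1, heads, seq, dim]` layout.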
    let new_seq_len = k.dims()[2];
    let k_flat = k
        .squeeze(0)?
        .to_dtype(DType::F32)?
        .reshape((num_kv_heads * new_seq_len, head_dim))?;
    let v_flat = v
        .squeeze(0)?
        .to_dtype(DType::F32)?
        .reshape((num_kv_heads * new_seq_len, head_dim))?;
    Ok((k_flat, v_flat))
}

/// Quantize a K/V pair using polar quantization.
///
/// Returns `(k_indices, k_scales, v_indices, v_scales)` in flat format.
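///
/// A hypothetical append path, for illustration (the surrounding variables
/// are assumed names, not part of this crate):
///
/// ```ignore
/// let (k_flat, v_flat) = flatten_kv(&k, &v, num_kv_heads, head_dim)?;
/// let qc = make_quant_config(&precomputed, &config)?;
/// let (ki, ks, vi, vs) = quantize_kv_pair(&k_flat, &v_flat, norm_mode, &qc)?;
/// ```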
pub(crate) fn quantize_kv_pair(
    k_flat: &Tensor,
    v_flat: &Tensor,
    norm_mode: QuantNormMode,
    qc: &QuantConfig<'_>,
) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
    let (k_idx, k_sc) = polar_quantize(k_flat, norm_mode, qc)?;
    let (v_idx, v_sc) = polar_quantize(v_flat, norm_mode, qc)?;
    Ok((k_idx, k_sc, v_idx, v_sc))
}

/// Create a `DequantResult` with no logit bias (PQO mode).
// qual:allow(TQ-003) — trivial constructor, tested through PqoCache integration tests
pub(crate) fn dequant_result(k: Tensor, v: Tensor) -> DequantResult {
    DequantResult {
        k,
        v,
        logit_bias: None,
    }
}