diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 057641bdf..79007a38b 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -362,8 +362,8 @@ def __init__(self, *args, **kwargs): "--static_kv_dtype", default=None, type=str, - choices=["fp8", "float8_e4m3fn"], - help="Data type for static quantize key and value. ", + help="KV cache backend for calibration-time quantize-dequant. " + "Supported values: 'fp8', 'float8_e4m3fn', 'turboquant', 'turboquant:2', 'turboquant:3', 'turboquant:4'.", ) scheme.add_argument( diff --git a/auto_round/experimental/kv_cache.py b/auto_round/experimental/kv_cache.py index 1774315d7..1a3de1157 100644 --- a/auto_round/experimental/kv_cache.py +++ b/auto_round/experimental/kv_cache.py @@ -17,13 +17,23 @@ import contextlib +from dataclasses import dataclass from enum import Enum from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import torch -from transformers.cache_utils import DynamicCache - +from transformers.cache_utils import Cache, CacheLayerMixin, DynamicCache + +from auto_round.experimental.turboquant import ( + QJLResidualConfig, + TurboQuantConfig, + TurboQuantPackedTensor, + build_turboquant_state, + turboquant_pack, + turboquant_qdq, + turboquant_unpack, +) from auto_round.experimental.utils import ( is_attention_module, normalize_static_kv_dtype, @@ -34,12 +44,172 @@ __all__ = [ "initialize_quantized_kv_cache", + "normalize_kv_cache_backend_config", "prep_attention_module_for_calibration", "freeze_module_quantization_", "kvcache_quant_context", + "TurboQuantPackedKVCache", + "build_turboquant_runtime_cache", ] +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + +@dataclass(frozen=True) +class KVCacheBackendConfig: + backend: str + dtype: torch.dtype | None = None + bits: int | None = None + seed: int = 42 + packed: bool = False + residual_length: int = 128 + qjl_residual: bool = False + + +def _normalize_backend_name(name: str) -> str: + 
return name.strip().lower().replace("_", "-") + + +def normalize_kv_cache_backend_config(spec: Union[str, torch.dtype, Dict[str, Any], KVCacheBackendConfig]): + if isinstance(spec, KVCacheBackendConfig): + return spec + + if isinstance(spec, dict): + backend = _normalize_backend_name(spec.get("backend") or spec.get("name") or spec.get("dtype") or "") + if backend in ("fp8", "float8-e4m3fn", "float8_e4m3fn"): + dtype = normalize_static_kv_dtype(spec.get("dtype", "fp8")) + return KVCacheBackendConfig(backend="fp8", dtype=dtype, seed=int(spec.get("seed", 42))) + if backend == "turboquant": + return KVCacheBackendConfig( + backend="turboquant", + bits=int(spec.get("bits", 4)), + seed=int(spec.get("seed", 42)), + packed=bool(spec.get("packed", False)), + residual_length=int(spec.get("residual_length", 128)), + qjl_residual=bool(spec.get("qjl_residual", False)), + ) + raise ValueError(f"Unsupported kv cache backend config: {spec}") + + if isinstance(spec, torch.dtype): + return KVCacheBackendConfig(backend="fp8", dtype=normalize_static_kv_dtype(spec)) + + if not isinstance(spec, str): + raise TypeError(f"Unsupported kv cache backend spec type: {type(spec)}") + + normalized = spec.strip().lower() + if normalized in ("fp8", "float8_e4m3fn"): + return KVCacheBackendConfig(backend="fp8", dtype=normalize_static_kv_dtype(normalized)) + + if normalized.startswith("turboquant"): + bits = 4 + packed = False + residual_length = 128 + qjl_residual = False + tokens = normalized.split(":") + for token in tokens[1:]: + if token.isdigit(): + bits = int(token) + elif token == "packed": + packed = True + elif token.startswith("residual="): + residual_length = int(token.split("=", 1)[1]) + elif token in ("qjl",): + qjl_residual = True + return KVCacheBackendConfig( + backend="turboquant", + bits=bits, + seed=42, + packed=packed, + residual_length=residual_length, + qjl_residual=qjl_residual, + ) + + raise ValueError( + "Invalid static kv dtype/backend: %s. 
Supported values include 'fp8', 'float8_e4m3fn', " + "'turboquant', and 'turboquant:2|3|4'." % spec + ) + + +class KVCacheBackend: + name = "base" + + def __init__(self, config: KVCacheBackendConfig): + self.config = config + + def init_module_parameters(self, module: torch.nn.Module): + init_scale = torch.tensor([0.0], device=next(module.parameters()).device) + update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value) + update_parameter_data(module, init_scale.clone(), KVCacheScaleType.VALUE.value) + + def reset(self): + return None + + def quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): + raise NotImplementedError + + +class FP8KVCacheBackend(KVCacheBackend): + name = "fp8" + + def __init__(self, config: KVCacheBackendConfig): + super().__init__(config) + if config.dtype != torch.float8_e4m3fn: + raise ValueError(f"Only fp8_e4m3fn KV cache is supported, but got {config.dtype}.") + + def quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): + del kv_type, layer_idx + return per_tensor_fp8_qdq(tensor) + + +class TurboQuantKVCacheBackend(KVCacheBackend): + name = "turboquant" + + def __init__(self, config: KVCacheBackendConfig): + super().__init__(config) + self.turboquant_config = TurboQuantConfig(bits=config.bits or 4, seed=config.seed) + self._state_cache: Dict[tuple[int, str, int, str], Any] = {} + + def reset(self): + self._state_cache = {} + + def _get_state(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): + head_dim = tensor.shape[-1] + state_key = (layer_idx, kv_type.value, head_dim, str(tensor.device)) + if state_key not in self._state_cache: + state_seed = self.turboquant_config.seed + layer_idx * 17 + (0 if kv_type == KVCacheScaleType.KEY else 1) + self._state_cache[state_key] = build_turboquant_state( + head_dim=head_dim, + bits=self.turboquant_config.bits, + seed=state_seed, + device=tensor.device, + ) + return self._state_cache[state_key] + 
+ def quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): + state = self._get_state(tensor, kv_type, layer_idx) + return turboquant_qdq(tensor, state, eps=self.turboquant_config.eps) + + +def build_kv_cache_backend(config: KVCacheBackendConfig) -> KVCacheBackend: + if config.backend == "fp8": + return FP8KVCacheBackend(config) + if config.backend == "turboquant": + return TurboQuantKVCacheBackend(config) + raise ValueError(f"Unsupported kv cache backend {config.backend}") + + +def _cleanup_kv_cache_hooks(module: torch.nn.Module): + hook_handles = getattr(module, "_kv_cache_hook_handles", None) + if hook_handles is None: + return + for handle in hook_handles: + handle.remove() + delattr(module, "_kv_cache_hook_handles") + + def freeze_module_quantization_(module: torch.nn.Module): """ deletes observers when calibration is complete. @@ -49,21 +219,16 @@ def freeze_module_quantization_(module: torch.nn.Module): :param module: module to freeze quantization for """ - # remove observers if needed for name in ("input", "weight", "output"): obs_name = f"{name}_observer" if hasattr(module, obs_name): delattr(module, obs_name) - # remove quantized kv_cache kv_cache = getattr(module, "kv_cache", None) if isinstance(kv_cache, QuantizedKVParameterCache): delattr(module, "kv_cache") - -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" + _cleanup_kv_cache_hooks(module) # NOTE: Using _ suffix to denote l is modified in place @@ -94,28 +259,23 @@ class QuantizedKVParameterCache(DynamicCache): Each time forward is called, .update() is called, and ._quant_dequant() gets called appropriately. The size of tensor is `[batch_size, num_heads, seq_len - residual_length, head_dim]`. 
- """ _instance = None _initialized = False def __new__(cls, *args, **kwargs): - """Singleton""" if cls._instance is None: cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls) return cls._instance - def __init__(self, dtype: torch.dtype = torch.float8_e4m3fn): - - assert dtype == torch.float8_e4m3fn, "Only fp8_e4m3fn is supported for now." + def __init__(self, config: KVCacheBackendConfig): if not self._initialized: super().__init__() - - # each index corresponds to layer_idx of the attention layer - self.k_scales: List[torch.Tensor] = [] - self.v_scales: List[torch.Tensor] = [] self._initialized = True + self.backend_config = config + self.backend = build_kv_cache_backend(config) + self.reset_states() def update( self, @@ -124,61 +284,447 @@ def update( layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Get the k_scale and v_scale and output the quant-dequant key_states and value_states - """ + del cache_kwargs qdq_key_states = self._quant_dequant(key_states.contiguous(), KVCacheScaleType.KEY, layer_idx) qdq_value_states = self._quant_dequant(value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx) - - keys_to_return, values_to_return = qdq_key_states, qdq_value_states - - return keys_to_return, values_to_return + return qdq_key_states, qdq_value_states def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """ - Returns the sequence length of the cached states. - A layer index can be optionally passed. 
- """ if len(self.key_cache) <= layer_idx: return 0 - # since we cannot get the seq_length of each layer directly and - # rely on `_seen_tokens` which is updated every "layer_idx" == 0, - # this is a hack to get the actual seq_length for the given layer_idx - # this part of code otherwise fails when used to - # verify attn_weight shape in some models return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1 def reset_states(self): - """reset the kv states (used in calibration)""" self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = 0 self._quantized_key_cache: List[torch.Tensor] = [] self._quantized_value_cache: List[torch.Tensor] = [] + self.k_scales: List[torch.Tensor] = [] + self.v_scales: List[torch.Tensor] = [] + if hasattr(self, "backend"): + self.backend.reset() def reset(self): - """ - Reset the instantiation, create new instance on init - """ QuantizedKVParameterCache._instance = None QuantizedKVParameterCache._initialized = False def _quant_dequant(self, tensor: torch.Tensor, kv_type: KVCacheScaleType, layer_idx: int): - """Quantizes a key/value using a defined quantization method.""" - if kv_type == KVCacheScaleType.KEY: # key type - scales = self.k_scales - else: - assert kv_type == KVCacheScaleType.VALUE - scales = self.v_scales - - qdq_tensor, scale = per_tensor_fp8_qdq(tensor) - # Detach scale to prevent holding computation graph references + scales = self.k_scales if kv_type == KVCacheScaleType.KEY else self.v_scales + qdq_tensor, scale = self.backend.quant_dequant(tensor, kv_type, layer_idx) _pad_and_append_at_idx_(scales, layer_idx, scale.squeeze(0).detach()) return qdq_tensor -def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4m3fn): +class TurboQuantPackedKVCacheLayer(CacheLayerMixin): + is_compilable = False + + def __init__( + self, + bits: int = 4, + residual_length: int = 128, + 
seed: int = 42, + qjl_residual: bool = False, + ): + super().__init__() + self.bits = bits + self.residual_length = residual_length + self.seed = seed + self.qjl_residual = qjl_residual + self.cumulative_length = 0 + self._packed_key_segments: list[TurboQuantPackedTensor] = [] + self._packed_value_segments: list[TurboQuantPackedTensor] = [] + self._key_state = None + self._value_state = None + self._residual_config = None + self.dtype = None + self.device = None + self._batch_size = None + self._num_heads = None + self._k_head_dim = None + self._v_head_dim = None + + def lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor) -> None: + self.dtype = key_states.dtype + self.device = key_states.device + self._batch_size, self._num_heads = key_states.shape[:2] + self._k_head_dim = key_states.shape[-1] + self._v_head_dim = value_states.shape[-1] + self.keys = torch.empty( + (self._batch_size, self._num_heads, 0, self._k_head_dim), dtype=self.dtype, device=self.device + ) + self.values = torch.empty( + (self._batch_size, self._num_heads, 0, self._v_head_dim), dtype=self.dtype, device=self.device + ) + self._residual_config = QJLResidualConfig( + enabled=self.qjl_residual, + seed=self.seed + 7919, + ) + self._key_state = build_turboquant_state( + self._k_head_dim, + self.bits, + self.seed, + self.device, + qjl_config=self._residual_config, + ) + self._value_state = build_turboquant_state( + self._v_head_dim, + self.bits, + self.seed + 1, + self.device, + qjl_config=self._residual_config, + ) + self.is_initialized = True + + def _empty_like_keys(self): + return torch.empty( + (self._batch_size, self._num_heads, 0, self._k_head_dim), dtype=self.dtype, device=self.device + ) + + def _empty_like_values(self): + return torch.empty( + (self._batch_size, self._num_heads, 0, self._v_head_dim), dtype=self.dtype, device=self.device + ) + + def _spill_residual_to_packed(self): + if self.residual_length < 0: + raise ValueError(f"residual_length must be >= 0, 
but got {self.residual_length}.") + + buf_len = self.keys.shape[-2] + # Only spill when buffer reaches 2x residual_length (or residual_length + # if it's 0). This ensures each packed segment contains at least + # residual_length tokens, avoiding 1-token segments during token-by-token + # decode that would cause O(n²) unpack overhead. + spill_threshold = max(2 * self.residual_length, 1) + if buf_len < spill_threshold: + return + + spill = buf_len - self.residual_length + + prefix_keys = self.keys[..., :spill, :].contiguous() + prefix_values = self.values[..., :spill, :].contiguous() + if prefix_keys.numel() > 0: + self._packed_key_segments.append( + turboquant_pack(prefix_keys, self._key_state, residual_config=self._residual_config) + ) + self._packed_value_segments.append( + turboquant_pack(prefix_values, self._value_state, residual_config=self._residual_config) + ) + + self.keys = self.keys[..., spill:, :].contiguous() if self.residual_length > 0 else self._empty_like_keys() + self.values = ( + self.values[..., spill:, :].contiguous() if self.residual_length > 0 else self._empty_like_values() + ) + # NOTE: we intentionally do NOT merge packed segments by dequantize→requantize, + # because each round of requantization compounds quantization noise. + # With token-by-token decode and small residual_length, the oldest tokens + # could be requantized 100+ times, destroying quality completely. + # Instead, we keep segments as-is and pay O(n_segments) unpack cost per step. + + def _dequantize_segments(self, packed_segments, state, empty_tensor): + if len(packed_segments) == 0: + return empty_tensor + if len(packed_segments) == 1: + return turboquant_unpack(packed_segments[0], state, dtype=self.dtype, residual_config=self._residual_config) + + # Merge all packed segments into one and do a single unpack call. + # This turns O(n_segments) kernel launches into O(1). 
+ # Safe because each segment's packed bytes are byte-aligned + # (n_values_per_token * bits is always divisible by 8). + base = packed_segments[0] + total_seq = sum(s.original_shape[-2] for s in packed_segments) + merged_shape = base.original_shape[:-2] + (total_seq,) + base.original_shape[-1:] + + merged_codes = torch.cat([s.packed_codes for s in packed_segments]) + merged_norms = torch.cat([s.norms for s in packed_segments], dim=-2) + + qjl_signs = None + qjl_norms = None + if base.qjl_packed_signs is not None: + qjl_signs = torch.cat([s.qjl_packed_signs for s in packed_segments]) + qjl_norms = torch.cat([s.qjl_norms for s in packed_segments], dim=-1) + + merged = TurboQuantPackedTensor( + packed_codes=merged_codes, + norms=merged_norms, + original_shape=merged_shape, + bits=base.bits, + qjl_packed_signs=qjl_signs, + qjl_norms=qjl_norms, + ) + return turboquant_unpack(merged, state, dtype=self.dtype, residual_config=self._residual_config) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + del cache_kwargs + if not self.is_initialized: + self.lazy_initialization(key_states, value_states) + + self.cumulative_length += key_states.shape[-2] + self.keys = torch.cat([self.keys, key_states], dim=-2) + self.values = torch.cat([self.values, value_states], dim=-2) + self._spill_residual_to_packed() + + dequantized_keys = self._dequantize_segments( + self._packed_key_segments, self._key_state, self._empty_like_keys() + ) + dequantized_values = self._dequantize_segments( + self._packed_value_segments, self._value_state, self._empty_like_values() + ) + keys_to_return = torch.cat([dequantized_keys, self.keys], dim=-2) + values_to_return = torch.cat([dequantized_values, self.values], dim=-2) + return keys_to_return, values_to_return + + def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: + kv_offset = 0 + kv_length = 
self.get_seq_length() + cache_position.shape[0] + return kv_length, kv_offset + + def get_seq_length(self) -> int: + return self.cumulative_length + + def get_max_cache_shape(self) -> int: + return -1 + + def _move_packed_segments(self, segments, device: str): + moved = [] + for segment in segments: + moved.append( + TurboQuantPackedTensor( + packed_codes=segment.packed_codes.to(device, non_blocking=True), + norms=segment.norms.to(device, non_blocking=True), + original_shape=segment.original_shape, + bits=segment.bits, + qjl_packed_signs=( + None + if segment.qjl_packed_signs is None + else segment.qjl_packed_signs.to(device, non_blocking=True) + ), + qjl_norms=(None if segment.qjl_norms is None else segment.qjl_norms.to(device, non_blocking=True)), + ) + ) + return moved + + def offload(self): + super().offload() + self._packed_key_segments = self._move_packed_segments(self._packed_key_segments, "cpu") + self._packed_value_segments = self._move_packed_segments(self._packed_value_segments, "cpu") + + def prefetch(self): + super().prefetch() + if ( + self.is_initialized + and len(self._packed_key_segments) > 0 + and self._packed_key_segments[0].packed_codes.device != self.device + ): + self._packed_key_segments = self._move_packed_segments(self._packed_key_segments, self.device) + self._packed_value_segments = self._move_packed_segments(self._packed_value_segments, self.device) + + def reset(self) -> None: + if self.is_initialized: + self.keys = self._empty_like_keys() + self.values = self._empty_like_values() + self._packed_key_segments = [] + self._packed_value_segments = [] + self.cumulative_length = 0 + + def _repack_segments_with_batch_indices(self, segments, state, indices: torch.Tensor): + repacked = [] + for segment in segments: + unpacked = turboquant_unpack(segment, state, dtype=self.dtype, residual_config=self._residual_config) + repacked.append( + turboquant_pack( + unpacked.index_select(0, indices.to(unpacked.device)), state, 
residual_config=self._residual_config + ) + ) + return repacked + + def reorder_cache(self, beam_idx: torch.LongTensor) -> None: + if self.keys is not None and self.keys.numel() > 0: + self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device)) + self.values = self.values.index_select(0, beam_idx.to(self.values.device)) + self._packed_key_segments = self._repack_segments_with_batch_indices( + self._packed_key_segments, self._key_state, beam_idx + ) + self._packed_value_segments = self._repack_segments_with_batch_indices( + self._packed_value_segments, self._value_state, beam_idx + ) + + def batch_repeat_interleave(self, repeats: int) -> None: + if self.keys is not None and self.keys.numel() > 0: + self.keys = self.keys.repeat_interleave(repeats, dim=0) + self.values = self.values.repeat_interleave(repeats, dim=0) + indices = torch.arange(self._batch_size, device=self.device).repeat_interleave(repeats) + self._packed_key_segments = self._repack_segments_with_batch_indices( + self._packed_key_segments, self._key_state, indices + ) + self._packed_value_segments = self._repack_segments_with_batch_indices( + self._packed_value_segments, self._value_state, indices + ) + self._batch_size *= repeats + + def batch_select_indices(self, indices: torch.Tensor) -> None: + if self.keys is not None and self.keys.numel() > 0: + self.keys = self.keys[indices, ...] + self.values = self.values[indices, ...] 
+ self._packed_key_segments = self._repack_segments_with_batch_indices( + self._packed_key_segments, self._key_state, indices + ) + self._packed_value_segments = self._repack_segments_with_batch_indices( + self._packed_value_segments, self._value_state, indices + ) + self._batch_size = indices.numel() + + def packed_memory_bytes(self) -> int: + return sum(segment.memory_bytes() for segment in self._packed_key_segments + self._packed_value_segments) + + def residual_memory_bytes(self) -> int: + return self.keys.numel() * self.keys.element_size() + self.values.numel() * self.values.element_size() + + def raw_memory_bytes(self) -> int: + return ( + self.cumulative_length + * self._batch_size + * self._num_heads + * (self._k_head_dim + self._v_head_dim) + * torch.tensor([], dtype=self.dtype).element_size() + ) + + +class TurboQuantPackedKVCache(Cache): + def __init__( + self, + bits: int = 4, + residual_length: int = 128, + seed: int = 42, + qjl_residual: bool = False, + offloading: bool = False, + offload_only_non_sliding: bool = False, + ): + super().__init__(layers=[], offloading=offloading, offload_only_non_sliding=offload_only_non_sliding) + self.bits = bits + self.residual_length = residual_length + self.seed = seed + self.qjl_residual = qjl_residual + + def _new_layer(self, layer_idx: int) -> TurboQuantPackedKVCacheLayer: + return TurboQuantPackedKVCacheLayer( + bits=self.bits, + residual_length=self.residual_length, + seed=self.seed + layer_idx * 17, + qjl_residual=self.qjl_residual, + ) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + while len(self.layers) <= layer_idx: + self.layers.append(self._new_layer(len(self.layers))) + return super().update(key_states, value_states, layer_idx, cache_kwargs) + + def packed_memory_bytes(self) -> int: + return sum(layer.packed_memory_bytes() for layer in self.layers) + + def 
residual_memory_bytes(self) -> int: + return sum(layer.residual_memory_bytes() for layer in self.layers) + + def total_memory_bytes(self) -> int: + return self.packed_memory_bytes() + self.residual_memory_bytes() + + def raw_memory_bytes(self) -> int: + return sum(layer.raw_memory_bytes() for layer in self.layers) + + def compression_ratio(self) -> float: + packed_bytes = max(self.total_memory_bytes(), 1) + return self.raw_memory_bytes() / packed_bytes + + +class TurboQuantPreDequantCache(DynamicCache): + """quantize→dequantize K/V, store bf16 in standard cache. + + At write time: K,V → encode → decode → store dequantized bf16 + At read time: read bf16, zero overhead (standard DynamicCache) + + This correctly simulates the quantization error without the decode overhead + at attention time. No actual memory compression. + """ + + def __init__(self, bits: int = 4, seed: int = 42, qjl_residual: bool = False): + super().__init__() + self.bits = bits + self.seed = seed + self.qjl_residual = qjl_residual + self._states: dict[int, object] = {} + + def _get_state(self, layer_idx: int, head_dim: int, device: torch.device): + if layer_idx not in self._states: + qjl_config = QJLResidualConfig(enabled=self.qjl_residual, seed=1729) if self.qjl_residual else None + self._states[layer_idx] = build_turboquant_state( + head_dim=head_dim, + bits=self.bits, + seed=self.seed + layer_idx * 17, + device=device, + qjl_config=qjl_config, + ) + return self._states[layer_idx] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + head_dim = key_states.shape[-1] + state = self._get_state(layer_idx, head_dim, key_states.device) + residual_config = QJLResidualConfig(enabled=self.qjl_residual, seed=1729) if self.qjl_residual else None + # Pre-dequant: quant→dequant before storing + key_dq, _ = turboquant_qdq(key_states, state, residual_config=residual_config) + 
value_dq, _ = turboquant_qdq(value_states, state, residual_config=residual_config) + return super().update(key_dq, value_dq, layer_idx, cache_kwargs) + + +def build_turboquant_runtime_cache( + bits: int = 4, + residual_length: int = 128, + seed: int = 42, + qjl_residual: bool = False, + offloading: bool = False, + offload_only_non_sliding: bool = False, + mode: str = "packed", +) -> Cache: + """Build a TurboQuant KV cache. + + Args: + mode: "packed" stores bit-packed codes (real compression, higher latency), + "pre_dequant" stores bf16 after quant→dequant (zero read overhead, + no compression — matches vLLM Phase 1 approach). + """ + if mode == "pre_dequant": + return TurboQuantPreDequantCache( + bits=bits, + seed=seed, + qjl_residual=qjl_residual, + ) + return TurboQuantPackedKVCache( + bits=bits, + residual_length=residual_length, + seed=seed, + qjl_residual=qjl_residual, + offloading=offloading, + offload_only_non_sliding=offload_only_non_sliding, + ) + + +def initialize_quantized_kv_cache(module: torch.nn.Module, config: KVCacheBackendConfig): """ Initialize a quantized kv_cache on a module (analogous to initializing an observer) """ @@ -186,28 +732,22 @@ def initialize_quantized_kv_cache(module: torch.nn.Module, dtype=torch.float8_e4 return existing_kv_cache = getattr(module, "kv_cache", None) - if isinstance(existing_kv_cache, QuantizedKVParameterCache): + if isinstance(existing_kv_cache, QuantizedKVParameterCache) and existing_kv_cache.backend_config == config: return - quantized_kv_cache = QuantizedKVParameterCache(dtype=dtype) + quantized_kv_cache = QuantizedKVParameterCache(config=config) setattr(module, "kv_cache", quantized_kv_cache) logger.debug(f"Initialized quantized kv_cache for {module.__class__.__name__} {getattr(module, 'layer_idx', None)}") - init_scale = torch.tensor([0.0], device=next(module.parameters()).device) - update_parameter_data(module, init_scale.clone(), KVCacheScaleType.KEY.value) - update_parameter_data(module, init_scale.clone(), 
KVCacheScaleType.VALUE.value) + quantized_kv_cache.backend.init_module_parameters(module) def calibrate_kv_cache_input_hook( module: torch.nn.Module, args: Any, kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: """ - Hook to update inputs to attention layers when running - kv_cache quantization. Will update the passed in - kv_cache to singleton QuantizedKVParameterCache. + Hook to update inputs to attention layers when running kv_cache quantization. """ kv_cache = getattr(module, "kv_cache") - # Start from transformers 4.55.2, the `past_key_value` was renamed to `past_key_values`. - # https://github.com/huggingface/transformers/blob/52c6c1bb6e27ca87c4faede34a4c2a7404c17c4d/src/transformers/models/llama/modeling_llama.py#L279-L280 if "past_key_values" in kwargs: kwargs["past_key_values"] = kv_cache else: @@ -229,26 +769,40 @@ def calibrate_kv_cache_output_hook(module: torch.nn.Module, _args: Any, _output: def prep_attention_module_for_calibration(module: torch.nn.Module): if is_attention_module(module): - module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True) - module.register_forward_hook(calibrate_kv_cache_output_hook) + if hasattr(module, "_kv_cache_hook_handles"): + return + pre_handle = module.register_forward_pre_hook(calibrate_kv_cache_input_hook, with_kwargs=True) + post_handle = module.register_forward_hook(calibrate_kv_cache_output_hook) + module._kv_cache_hook_handles = (pre_handle, post_handle) @contextlib.contextmanager def kvcache_quant_context(model: torch.nn.Module, static_kv_dtype=torch.float8_e4m3fn): - """Context manager for FP8 KV cache quantization operations.""" + """Context manager for KV cache quantization operations.""" try: - # Setup phase: Initialize KV cache for quantization - static_kv_dtype = normalize_static_kv_dtype(static_kv_dtype) - if static_kv_dtype != torch.float8_e4m3fn: - logger.warning(f"Ignoring static kv dtype {static_kv_dtype}, only fp8_e4m3fn is supported.") + backend_config = 
normalize_kv_cache_backend_config(static_kv_dtype) + attention_module_count = sum(1 for module in model.modules() if is_attention_module(module)) + if backend_config.backend == "turboquant": + logger.info( + "Enable KV cache backend turboquant (bits=%s, seed=%s) for %s attention modules.", + backend_config.bits, + backend_config.seed, + attention_module_count, + ) else: - initialize_fn = partial(initialize_quantized_kv_cache, dtype=static_kv_dtype) - model.apply(initialize_fn) - model.apply(prep_attention_module_for_calibration) - - # Provide the model to the with block + logger.info( + "Enable KV cache backend %s (dtype=%s) for %s attention modules.", + backend_config.backend, + backend_config.dtype, + attention_module_count, + ) + initialize_fn = partial(initialize_quantized_kv_cache, config=backend_config) + model.apply(initialize_fn) + model.apply(prep_attention_module_for_calibration) yield model finally: - # Cleanup phase: Freeze quantization parameters + model.apply(_cleanup_kv_cache_hooks) model.apply(freeze_module_quantization_) + QuantizedKVParameterCache._instance = None + QuantizedKVParameterCache._initialized = False diff --git a/auto_round/experimental/turboquant.py b/auto_round/experimental/turboquant.py new file mode 100644 index 000000000..70685c79f --- /dev/null +++ b/auto_round/experimental/turboquant.py @@ -0,0 +1,358 @@ +# Copyright (c) 2026 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math
from dataclasses import dataclass
from functools import lru_cache
from math import prod, sqrt
from typing import Optional

import numpy as np
import torch


@dataclass(frozen=True)
class TurboQuantConfig:
    """Hyper-parameters for TurboQuant KV-cache quantization.

    These knobs drive the Lloyd-Max codebook solver and the numerical floor
    applied to per-vector norms during encoding.
    """

    bits: int = 4  # codebook resolution; 2-4 bits supported at runtime
    seed: int = 42  # seed for the random orthogonal rotation
    codebook_samples: int = 65536  # Gaussian samples fed to the Lloyd-Max solver
    codebook_iters: int = 24  # Lloyd-Max refinement iterations
    eps: float = 1e-6  # lower bound for per-vector L2 norms


@dataclass(frozen=True)
class QJLResidualConfig:
    """1-bit QJL residual correction for unbiased inner products (paper §4).

    Stores sign(residual @ S^T) as ±1 int8 + residual norm.
    Reconstruction: sqrt(π/2) / head_dim * r_norm * (signs @ S).
    """

    enabled: bool = False
    seed: int = 1729


@dataclass
class TurboQuantState:
    """Precomputed quantization state, unique per (head_dim, bits, seed)."""

    head_dim: int
    bits: int
    seed: int
    rotation: torch.Tensor  # (head_dim, head_dim) orthonormal
    inverse_rotation: torch.Tensor  # rotation^T, (head_dim, head_dim)
    codebook: torch.Tensor  # (2**bits,) Lloyd-Max centroids, sorted
    boundaries: torch.Tensor  # (2**bits - 1,) bucketize thresholds
    qjl_matrix: Optional[torch.Tensor] = None  # (head_dim, head_dim) random projection


@dataclass
class TurboQuantPackedTensor:
    """Bit-packed representation of a quantized tensor plus metadata to invert it."""

    packed_codes: torch.Tensor  # uint8, bit-packed quantization indices
    norms: torch.Tensor  # per-vector L2 norms
    original_shape: tuple[int, ...]
    bits: int
    qjl_packed_signs: Optional[torch.Tensor] = None  # uint8, bit-packed (1 bit/sign)
    qjl_norms: Optional[torch.Tensor] = None  # float16, residual norms, shape = original_shape[:-1]

    def memory_bytes(self) -> int:
        """Return the total storage footprint of all packed buffers in bytes."""
        size = self.packed_codes.numel() * self.packed_codes.element_size()
        size += self.norms.numel() * self.norms.element_size()
        if self.qjl_packed_signs is not None:
            size += self.qjl_packed_signs.numel() * self.qjl_packed_signs.element_size()
        if self.qjl_norms is not None:
            size += self.qjl_norms.numel() * self.qjl_norms.element_size()
        return size


def _make_generator(seed: int) -> torch.Generator:
    """Return a CPU RNG seeded with ``seed`` (CPU so results are device-agnostic)."""
    generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)
    return generator


@lru_cache(maxsize=None)
def _lloyd_max_codebook(bits: int, samples: int, iters: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Solve a 1-D Lloyd-Max quantizer for the standard normal distribution.

    Returns ``(centroids, boundaries)``: ``2**bits`` sorted centroids and the
    ``2**bits - 1`` midpoints consumed by ``torch.bucketize``. Cached since the
    solver is deterministic per ``(bits, samples, iters)``.
    """
    if bits < 1 or bits > 4:
        raise ValueError(f"TurboQuant only supports 1-4 bits in the codebook solver, but got {bits}.")

    levels = 1 << bits
    generator = _make_generator(1000 + bits)
    data = torch.randn(samples, generator=generator, dtype=torch.float32)
    centroids = torch.linspace(-2.5, 2.5, levels, dtype=torch.float32)

    for _ in range(iters):
        boundaries = (centroids[:-1] + centroids[1:]) * 0.5
        bucket_ids = torch.bucketize(data, boundaries)
        new_centroids = []
        for idx in range(levels):
            mask = bucket_ids == idx
            if mask.any():
                new_centroids.append(data[mask].mean())
            else:
                # Empty bucket: keep the previous centroid instead of collapsing it.
                new_centroids.append(centroids[idx])
        centroids = torch.stack(new_centroids)

    if levels % 2 == 0:
        # Symmetrize even-sized codebooks so q(-x) == -q(x).
        half = levels // 2
        positive = 0.5 * (centroids[half:] - centroids[:half].flip(0))
        positive = positive.abs()
        centroids = torch.cat((-positive.flip(0), positive))

    centroids = torch.sort(centroids).values
    boundaries = (centroids[:-1] + centroids[1:]) * 0.5
    return centroids, boundaries


@lru_cache(maxsize=None)
def _rotation_matrix(head_dim: int, seed: int) -> torch.Tensor:
    """Return a Haar-distributed random orthogonal matrix of shape (head_dim, head_dim).

    Uses QR of a Gaussian matrix with the sign of R's diagonal folded into Q so
    the distribution is uniform over the orthogonal group.
    """
    if head_dim <= 0:
        raise ValueError(f"head_dim must be positive, but got {head_dim}.")

    generator = _make_generator(seed)
    gaussian = torch.randn((head_dim, head_dim), generator=generator, dtype=torch.float32)
    q_mat, r_mat = torch.linalg.qr(gaussian)
    diag = torch.sign(torch.diag(r_mat))
    diag[diag == 0] = 1
    q_mat = q_mat * diag.unsqueeze(0)
    return q_mat.contiguous()


@lru_cache(maxsize=None)
def _qjl_random_matrix(head_dim: int, seed: int) -> torch.Tensor:
    """Generate a random Gaussian matrix S for QJL 1-bit projection.

    S is (head_dim, head_dim). The QJL transform stores sign(residual @ S^T).
    """
    generator = _make_generator(seed)
    return torch.randn((head_dim, head_dim), generator=generator, dtype=torch.float32)


def _pack_codes(codes: torch.Tensor, bits: int) -> torch.Tensor:
    """Bit-pack a flat tensor of small integer codes into bytes (CPU/numpy path).

    Codes are laid out LSB-first in a continuous bit stream; a code may
    straddle a byte boundary.
    """
    if bits <= 0 or bits > 8:
        raise ValueError(f"bits must be in [1, 8], but got {bits}.")

    codes_np = codes.reshape(-1).to(torch.uint8).cpu().numpy().astype(np.uint64)
    n = len(codes_np)
    total_bits = n * bits
    n_bytes = (total_bits + 7) // 8
    result = np.zeros(n_bytes, dtype=np.uint64)

    positions = np.arange(n, dtype=np.uint64) * bits
    for b in range(bits):
        bit_pos = positions + b
        byte_idx = bit_pos >> 3
        bit_off = bit_pos & 7
        bit_val = (codes_np >> b) & 1
        # Each (byte, bit) slot is written exactly once, so add acts as OR here.
        np.add.at(result, byte_idx, bit_val << bit_off)

    return torch.tensor(result.astype(np.uint8), dtype=torch.uint8, device=codes.device)


def _unpack_codes(packed_codes: torch.Tensor, num_values: int, bits: int, device: torch.device) -> torch.Tensor:
    """Inverse of :func:`_pack_codes`: recover ``num_values`` codes as int64."""
    if bits <= 0 or bits > 8:
        raise ValueError(f"bits must be in [1, 8], but got {bits}.")

    packed_np = packed_codes.cpu().numpy().astype(np.uint64)
    positions = np.arange(num_values, dtype=np.uint64) * bits
    result = np.zeros(num_values, dtype=np.int64)

    for b in range(bits):
        bit_pos = positions + b
        byte_idx = bit_pos >> 3
        bit_off = bit_pos & 7
        bit_val = (packed_np[byte_idx] >> bit_off) & 1
        result |= bit_val.astype(np.int64) << b

    return torch.tensor(result, dtype=torch.long, device=device)


def build_turboquant_state(
    head_dim: int,
    bits: int,
    seed: int,
    device: torch.device,
    qjl_config: Optional[QJLResidualConfig] = None,
    config: Optional[TurboQuantConfig] = None,
) -> TurboQuantState:
    """Build (and cache the heavy parts of) the TurboQuant state for one head size.

    Args:
        head_dim: size of the last (per-head) dimension to be quantized.
        bits: 2, 3 or 4 bits per channel.
        seed: seed for the random rotation matrix.
        device: device to place the state tensors on.
        qjl_config: optional QJL residual setup; when enabled, a projection
            matrix is materialized for the 1-bit residual correction.
        config: optional solver configuration. Previously this dataclass was
            declared but never consumed; when provided, its
            ``codebook_samples``/``codebook_iters`` drive the Lloyd-Max solver.
            Defaults preserve the historical (65536, 24) behavior.

    Raises:
        ValueError: if ``bits`` is not in (2, 3, 4).
    """
    if bits not in (2, 3, 4):
        raise ValueError(f"TurboQuant only supports 2/3/4-bit KV cache quantization, but got {bits}.")

    samples = config.codebook_samples if config is not None else 65536
    iters = config.codebook_iters if config is not None else 24

    rotation = _rotation_matrix(head_dim, seed).to(device=device)
    codebook, boundaries = _lloyd_max_codebook(bits, samples, iters)

    qjl_matrix = None
    if qjl_config is not None and qjl_config.enabled:
        qjl_matrix = _qjl_random_matrix(head_dim, qjl_config.seed).to(device=device)

    return TurboQuantState(
        head_dim=head_dim,
        bits=bits,
        seed=seed,
        rotation=rotation,
        inverse_rotation=rotation.transpose(0, 1).contiguous(),
        codebook=codebook.to(device=device),
        boundaries=boundaries.to(device=device),
        qjl_matrix=qjl_matrix,
    )


def turboquant_pack(
    tensor: torch.Tensor,
    state: TurboQuantState,
    eps: float = 1e-6,
    residual_config: Optional[QJLResidualConfig] = None,
) -> TurboQuantPackedTensor:
    """Quantize ``tensor`` (..., head_dim) into a bit-packed representation.

    Pipeline: L2-normalize each head vector, apply the random rotation, scale
    by sqrt(head_dim) and bucketize against the Lloyd-Max boundaries. When QJL
    is enabled, additionally store 1-bit signs + norm of the residual so the
    decoder can apply an unbiased correction.
    """
    if tensor.numel() == 0:
        # Empty input: nothing to quantize; preserve shape/dtype metadata only.
        return TurboQuantPackedTensor(
            packed_codes=torch.empty(0, dtype=torch.uint8, device=tensor.device),
            norms=torch.empty(0, dtype=tensor.dtype, device=tensor.device),
            original_shape=tuple(tensor.shape),
            bits=state.bits,
        )

    use_triton_bitpack = _HAS_TRITON and tensor.is_cuda

    # Encode: normalize → rotate → scale → bucketize (cuBLAS matmul is faster than Triton loop)
    work_tensor = tensor.to(torch.float32)
    norms = torch.linalg.vector_norm(work_tensor, dim=-1, keepdim=True).clamp_min(eps)
    normalized = work_tensor / norms
    rotated = torch.matmul(normalized, state.rotation)
    scale = sqrt(state.head_dim)
    scaled = rotated * scale
    bucket_ids = torch.bucketize(scaled.reshape(-1), state.boundaries)

    # Bit-pack: Triton on CUDA, numpy on CPU
    if use_triton_bitpack:
        packed_codes = _triton_pack(bucket_ids, state.bits)
    else:
        packed_codes = _pack_codes(bucket_ids, state.bits)

    qjl_packed_signs = None
    qjl_norms = None
    if residual_config is not None and residual_config.enabled and state.qjl_matrix is not None:
        # Residual = normalized input minus its dequantized approximation.
        quantized = state.codebook[bucket_ids].view(tensor.shape) / scale
        reconstructed = torch.matmul(quantized, state.inverse_rotation)
        residual = normalized - reconstructed
        r_norm = torch.linalg.vector_norm(residual, dim=-1)
        projected = torch.matmul(residual, state.qjl_matrix.T)
        sign_bits = (projected >= 0).to(torch.uint8)  # 0/1
        # Bit-pack signs: head_dim bits → head_dim/8 bytes per vector
        if use_triton_bitpack:
            qjl_packed_signs = _triton_pack(sign_bits.reshape(-1), 1)
        else:
            qjl_packed_signs = _pack_codes(sign_bits.reshape(-1), 1)
        qjl_norms = r_norm.to(torch.float16)

    return TurboQuantPackedTensor(
        packed_codes=packed_codes,
        norms=norms.to(tensor.dtype),
        original_shape=tuple(tensor.shape),
        bits=state.bits,
        qjl_packed_signs=qjl_packed_signs,
        qjl_norms=qjl_norms,
    )


def turboquant_unpack(
    packed: TurboQuantPackedTensor,
    state: TurboQuantState,
    dtype: torch.dtype = torch.float32,
    residual_config: Optional[QJLResidualConfig] = None,
) -> torch.Tensor:
    """Reconstruct an approximation of the original tensor from ``packed``.

    Raises:
        ValueError: if QJL sign data is present but the matching residual
            config/state is missing (signals a pack/unpack config mismatch).
    """
    if len(packed.original_shape) == 0 or prod(packed.original_shape) == 0:
        return torch.empty(packed.original_shape, dtype=dtype, device=state.rotation.device)

    num_values = prod(packed.original_shape)
    use_triton_bitpack = _HAS_TRITON and state.rotation.is_cuda

    # Bit-unpack: Triton on CUDA, numpy on CPU
    if use_triton_bitpack:
        bucket_ids = _triton_unpack(packed.packed_codes, num_values, packed.bits)
    else:
        bucket_ids = _unpack_codes(packed.packed_codes, num_values, packed.bits, state.rotation.device)

    # Decode: codebook gather → inverse-rotate → scale (cuBLAS matmul always)
    scale = sqrt(state.head_dim)
    quantized = state.codebook[bucket_ids].view(packed.original_shape) / scale
    reconstructed = torch.matmul(quantized, state.inverse_rotation)

    if packed.qjl_packed_signs is not None and packed.qjl_norms is not None:
        if residual_config is None or not residual_config.enabled or state.qjl_matrix is None:
            raise ValueError("QJL signs are present, but residual_config is missing/disabled or qjl_matrix is None.")
        # Unpack 1-bit signs → ±1 float
        num_sign_values = prod(packed.original_shape)
        if use_triton_bitpack:
            sign_bits = _triton_unpack(packed.qjl_packed_signs, num_sign_values, 1)
        else:
            sign_bits = _unpack_codes(packed.qjl_packed_signs, num_sign_values, 1, state.rotation.device)
        signs_f = (sign_bits.to(torch.float32) * 2 - 1).view(packed.original_shape)
        d = state.head_dim
        # Unbiased estimator scale sqrt(π/2)/d for 1-bit Gaussian sketches.
        qjl_scale = sqrt(math.pi / 2.0) / d
        r_norms = packed.qjl_norms.to(torch.float32).unsqueeze(-1)
        qjl_correction = qjl_scale * r_norms * torch.matmul(signs_f, state.qjl_matrix)
        reconstructed = reconstructed + qjl_correction

    return (reconstructed * packed.norms.to(torch.float32)).to(dtype)


def turboquant_qdq(
    tensor: torch.Tensor,
    state: TurboQuantState,
    eps: float = 1e-6,
    residual_config: Optional[QJLResidualConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize then dequantize ``tensor``; returns (reconstruction, mean norm).

    For empty inputs the input tensor itself is returned unchanged together
    with a zero placeholder norm.
    """
    if tensor.numel() == 0:
        return tensor, torch.zeros(1, device=tensor.device, dtype=tensor.dtype)

    packed = turboquant_pack(tensor, state, eps=eps, residual_config=residual_config)
    reconstructed = turboquant_unpack(packed, state, dtype=tensor.dtype, residual_config=residual_config)
    avg_norm = packed.norms.mean().to(tensor.dtype).reshape(1)
    return reconstructed, avg_norm


# ---------------------------------------------------------------------------
# Triton-accelerated encode/decode (optional, CUDA only)
# ---------------------------------------------------------------------------

_HAS_TRITON = False
try:
    from auto_round_extension.triton.turboquant import triton_pack_codes as _triton_pack
    from auto_round_extension.triton.turboquant import triton_unpack_codes as _triton_unpack
    from auto_round_extension.triton.turboquant import turboquant_decode as _triton_decode
    from auto_round_extension.triton.turboquant import turboquant_encode as _triton_encode

    _HAS_TRITON = True
except ImportError:
    pass


def has_triton_turboquant() -> bool:
    """Check if Triton TurboQuant kernels are available."""
    return _HAS_TRITON and torch.cuda.is_available()


def turboquant_qdq_triton(
    tensor: torch.Tensor,
    state: TurboQuantState,
) -> torch.Tensor:
    """Fast quantize→dequantize using cuBLAS matmul + Triton bitpack (no QJL).

    Input: (..., head_dim) on CUDA.
    Output: same shape, same dtype.
    """
    if not has_triton_turboquant():
        raise RuntimeError("Triton TurboQuant kernels not available.")

    packed = turboquant_pack(tensor, state)
    return turboquant_unpack(packed, state, dtype=tensor.dtype)
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Triton kernels for TurboQuant KV cache encode/decode.

Encode: normalize → rotate → scale → scalar quantize → store indices + norm
Decode: load indices → codebook lookup → scale⁻¹ → unrotate → scale by norm

Each kernel processes one (token, head) pair as a single Triton program.
The rotation matrix multiply is done column-by-column inside the kernel
to keep the full head vector in SRAM.
"""

from __future__ import annotations

import math

import torch
import triton
import triton.language as tl


def _next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (Triton block sizes must be powers of two);
    # non-positive inputs map to 1.
    if n <= 0:
        return 1
    return 1 << (n - 1).bit_length()


# ---------------------------------------------------------------------------
# Encode kernel
# ---------------------------------------------------------------------------


@triton.jit
def _turboquant_encode_kernel(
    # Input: [num_tokens, num_kv_heads, head_size]
    x_ptr,
    # Rotation matrix PiT: [head_size, head_size], row-major
    pit_ptr,
    # Boundaries: [num_centroids - 1]
    boundaries_ptr,
    # Output indices: [num_tokens, num_kv_heads, head_size] as uint8
    indices_ptr,
    # Output norms: [num_tokens, num_kv_heads] as float32
    norms_ptr,
    # Shapes
    head_size: tl.constexpr,
    num_boundaries: tl.constexpr,
    # Scale factor: sqrt(head_size)
    scale: tl.constexpr,
    # Strides
    x_stride_token: tl.int64,
    x_stride_head: tl.int64,
    idx_stride_token: tl.int64,
    idx_stride_head: tl.int64,
    norm_stride_token: tl.int64,
    # Padded dim (power of 2)
    BLOCK_D: tl.constexpr,
):
    """Encode one (token, head): normalize → rotate → quantize.

    The rotation is an explicit O(head_size^2) column-by-column dot product so
    the whole input vector stays resident in registers/SRAM.
    """
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    dim_offs = tl.arange(0, BLOCK_D)
    mask = dim_offs < head_size

    # Load input vector
    x_base = token_idx * x_stride_token + head_idx * x_stride_head
    x_vec = tl.load(x_ptr + x_base + dim_offs, mask=mask, other=0.0).to(tl.float32)

    # L2 norm
    # NOTE(review): additive 1e-12 floor here vs. clamp_min(eps=1e-6) in the
    # torch fallback path — near-zero vectors can encode differently between
    # the two paths; confirm this divergence is acceptable.
    norm_sq = tl.sum(x_vec * x_vec, axis=0)
    norm = tl.sqrt(norm_sq + 1e-12)
    x_normed = x_vec / norm

    idx_base = token_idx * idx_stride_token + head_idx * idx_stride_head

    # For each output dim j: compute rotated[j] = dot(x_normed, PiT[:, j])
    for j in range(head_size):
        pit_col = tl.load(pit_ptr + dim_offs * head_size + j, mask=mask, other=0.0)
        y_j = tl.sum(x_normed * pit_col, axis=0) * scale

        # Scalar quantize: count how many boundaries y_j exceeds
        idx = tl.zeros([], dtype=tl.int32)
        for b in range(num_boundaries):
            bnd = tl.load(boundaries_ptr + b)
            idx = idx + (y_j > bnd).to(tl.int32)

        tl.store(indices_ptr + idx_base + j, idx.to(tl.uint8))

    # Store norm
    tl.store(norms_ptr + token_idx * norm_stride_token + head_idx, norm)


# ---------------------------------------------------------------------------
# Decode kernel
# ---------------------------------------------------------------------------


@triton.jit
def _turboquant_decode_kernel(
    # Input indices: [num_tokens, num_kv_heads, head_size] as uint8
    indices_ptr,
    # Input norms: [num_tokens, num_kv_heads] as float32
    norms_ptr,
    # Rotation matrix Pi: [head_size, head_size], row-major
    pi_ptr,
    # Codebook: [num_centroids]
    codebook_ptr,
    # Output: [num_tokens, num_kv_heads, head_size]
    out_ptr,
    # Shapes
    head_size: tl.constexpr,
    # Scale factor: sqrt(head_size)
    scale: tl.constexpr,
    # Strides
    idx_stride_token: tl.int64,
    idx_stride_head: tl.int64,
    norm_stride_token: tl.int64,
    out_stride_token: tl.int64,
    out_stride_head: tl.int64,
    # Padded dim
    BLOCK_D: tl.constexpr,
    OUTPUT_BF16: tl.constexpr,
):
    """Decode one (token, head): codebook lookup → unrotate → scale.

    Inverse of the encode kernel: gather centroids by index, undo the
    sqrt(d) scaling, multiply by the (transposed) rotation and the stored norm.
    """
    token_idx = tl.program_id(0)
    head_idx = tl.program_id(1)

    dim_offs = tl.arange(0, BLOCK_D)
    mask = dim_offs < head_size

    # Load indices and codebook lookup
    idx_base = token_idx * idx_stride_token + head_idx * idx_stride_head
    indices = tl.load(indices_ptr + idx_base + dim_offs, mask=mask, other=0).to(tl.int32)
    # Codebook gather, then divide by scale (undo the sqrt(d) scaling)
    reconstructed = tl.load(codebook_ptr + indices) / scale
    reconstructed = tl.where(mask, reconstructed, 0.0)

    # Load norm
    norm = tl.load(norms_ptr + token_idx * norm_stride_token + head_idx).to(tl.float32)

    # Unrotate: out[j] = sum_i(reconstructed[i] * Pi[i, j]) * norm
    out_base = token_idx * out_stride_token + head_idx * out_stride_head

    for j in range(head_size):
        pi_col = tl.load(pi_ptr + dim_offs * head_size + j, mask=mask, other=0.0)
        val = tl.sum(reconstructed * pi_col, axis=0) * norm

        # NOTE(review): only bf16 vs fp16 cast paths exist here; other output
        # dtypes take the fp16 branch before the store — presumably the store
        # re-casts to the pointer's element type, but verify for fp32 outputs.
        if OUTPUT_BF16:
            tl.store(out_ptr + out_base + j, val.to(tl.bfloat16))
        else:
            tl.store(out_ptr + out_base + j, val.to(tl.float16))


# ---------------------------------------------------------------------------
# Python wrappers
# ---------------------------------------------------------------------------


def turboquant_encode(
    x: torch.Tensor,  # [num_tokens, num_kv_heads, head_size]
    pit: torch.Tensor,  # [head_size, head_size] rotation^T
    codebook: torch.Tensor,  # [num_centroids]  NOTE(review): unused by the kernel; kept for signature symmetry with decode
    boundaries: torch.Tensor,  # [num_centroids - 1]
    head_dim: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode K or V vectors using TurboQuant.

    Returns:
        indices: [num_tokens, num_kv_heads, head_size] uint8
        norms: [num_tokens, num_kv_heads] float32
    """
    num_tokens, num_kv_heads, head_size = x.shape
    if head_dim is None:
        head_dim = head_size
    num_boundaries = boundaries.shape[0]
    BLOCK_D = _next_power_of_2(head_size)
    scale = math.sqrt(head_dim)

    indices = torch.empty(
        (num_tokens, num_kv_heads, head_size),
        dtype=torch.uint8,
        device=x.device,
    )
    norms = torch.empty(
        (num_tokens, num_kv_heads),
        dtype=torch.float32,
        device=x.device,
    )

    # One program per (token, head) pair.
    grid = (num_tokens, num_kv_heads)
    _turboquant_encode_kernel[grid](
        x_ptr=x,
        pit_ptr=pit,
        boundaries_ptr=boundaries,
        indices_ptr=indices,
        norms_ptr=norms,
        head_size=head_size,
        num_boundaries=num_boundaries,
        scale=scale,
        x_stride_token=x.stride(0),
        x_stride_head=x.stride(1),
        idx_stride_token=indices.stride(0),
        idx_stride_head=indices.stride(1),
        norm_stride_token=norms.stride(0),
        BLOCK_D=BLOCK_D,
        num_warps=4,
        num_stages=1,
    )

    return indices, norms


def turboquant_decode(
    indices: torch.Tensor,  # [num_tokens, num_kv_heads, head_size] uint8
    norms: torch.Tensor,  # [num_tokens, num_kv_heads] float32
    pi: torch.Tensor,  # [head_size, head_size] rotation matrix
    codebook: torch.Tensor,  # [num_centroids]
    head_dim: int | None = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Decode TurboQuant indices back to K or V vectors.

    Returns:
        out: [num_tokens, num_kv_heads, head_size] in output_dtype
    """
    num_tokens, num_kv_heads, head_size = indices.shape
    if head_dim is None:
        head_dim = head_size
    BLOCK_D = _next_power_of_2(head_size)
    scale = math.sqrt(head_dim)

    out = torch.empty(
        (num_tokens, num_kv_heads, head_size),
        dtype=output_dtype,
        device=indices.device,
    )

    grid = (num_tokens, num_kv_heads)
    _turboquant_decode_kernel[grid](
        indices_ptr=indices,
        norms_ptr=norms,
        pi_ptr=pi,
        codebook_ptr=codebook,
        out_ptr=out,
        head_size=head_size,
        scale=scale,
        idx_stride_token=indices.stride(0),
        idx_stride_head=indices.stride(1),
        norm_stride_token=norms.stride(0),
        out_stride_token=out.stride(0),
        out_stride_head=out.stride(1),
        BLOCK_D=BLOCK_D,
        OUTPUT_BF16=(output_dtype == torch.bfloat16),
        num_warps=4,
        num_stages=1,
    )

    return out


# ---------------------------------------------------------------------------
# Bit-pack / unpack kernels
# ---------------------------------------------------------------------------


@triton.jit
def _bitpack_kernel(
    # Input: [N] uint8 codes (each in [0, 2^bits - 1])
    codes_ptr,
    # Output: [n_bytes] uint8 packed
    packed_ptr,
    N,
    n_bytes,
    bits: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Pack N code values into bit-packed bytes. Each program handles BLOCK output bytes."""
    pid = tl.program_id(0)
    byte_offs = pid * BLOCK + tl.arange(0, BLOCK)
    byte_mask = byte_offs < n_bytes

    # Each output byte covers bit range [byte_offs*8, byte_offs*8 + 8).
    # Build the byte bit-by-bit: for each of the 8 bit positions, find
    # which code and which bit within that code contributes.
    bit_start = byte_offs * 8  # vector[BLOCK]
    result = tl.zeros([BLOCK], dtype=tl.int32)

    for bit_in_byte in range(8):
        global_bit = bit_start + bit_in_byte
        code_idx = global_bit // bits
        bit_in_code = global_bit - code_idx * bits  # = global_bit % bits

        valid = (code_idx < N) & byte_mask
        code_val = tl.load(codes_ptr + code_idx, mask=valid, other=0).to(tl.int32)
        bit_val = (code_val >> bit_in_code) & 1
        result = result | (bit_val << bit_in_byte)

    tl.store(packed_ptr + byte_offs, result.to(tl.uint8), mask=byte_mask)


@triton.jit
def _bitunpack_kernel(
    # Input: [n_bytes] uint8 packed
    packed_ptr,
    # Output: [N] int64 codes
    codes_ptr,
    N,
    n_bytes,
    bits: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Unpack N code values from bit-packed bytes. Each program handles BLOCK codes."""
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N

    # Each code spans [offs*bits, offs*bits + bits) in the bit stream,
    # crossing at most 2 bytes. Load both and extract.
    bit_pos = offs * bits
    byte_lo = bit_pos >> 3  # first byte index
    bit_off = bit_pos & 7  # bit offset within that byte

    lo = tl.load(packed_ptr + byte_lo, mask=mask, other=0).to(tl.int32)
    hi_mask = mask & ((byte_lo + 1) < n_bytes)
    hi = tl.load(packed_ptr + byte_lo + 1, mask=hi_mask, other=0).to(tl.int32)
    combined = lo | (hi << 8)  # 16 bits is always enough for bits <= 8

    code_mask = (1 << bits) - 1
    codes = (combined >> bit_off) & code_mask

    tl.store(codes_ptr + offs, codes.to(tl.int64), mask=mask)


# ---------------------------------------------------------------------------
# Python wrappers for bit-pack / unpack
# ---------------------------------------------------------------------------


def triton_pack_codes(codes: torch.Tensor, bits: int) -> torch.Tensor:
    """Bit-pack a flat uint8/long tensor of codes on GPU.

    Layout matches the CPU/numpy packer in auto_round.experimental.turboquant:
    codes are laid out LSB-first in a continuous bit stream.

    Args:
        codes: 1-D tensor of quantization indices on CUDA.
        bits: bits per code (2, 3, or 4).

    Returns:
        Packed uint8 tensor on the same device.
    """
    codes_flat = codes.reshape(-1).to(torch.uint8).contiguous()
    N = codes_flat.shape[0]
    n_bytes = (N * bits + 7) // 8
    packed = torch.empty(n_bytes, dtype=torch.uint8, device=codes.device)

    BLOCK = 1024
    grid = ((n_bytes + BLOCK - 1) // BLOCK,)
    _bitpack_kernel[grid](
        codes_ptr=codes_flat,
        packed_ptr=packed,
        N=N,
        n_bytes=n_bytes,
        bits=bits,
        BLOCK=BLOCK,
    )
    return packed


def triton_unpack_codes(packed: torch.Tensor, num_values: int, bits: int) -> torch.Tensor:
    """Unpack bit-packed bytes into a flat int64 tensor of codes on GPU.

    Args:
        packed: 1-D uint8 tensor of packed bytes on CUDA.
        num_values: number of code values to extract.
        bits: bits per code (2, 3, or 4).

    Returns:
        1-D int64 tensor of codes on the same device.
    """
    codes = torch.empty(num_values, dtype=torch.int64, device=packed.device)

    BLOCK = 1024
    grid = ((num_values + BLOCK - 1) // BLOCK,)
    _bitunpack_kernel[grid](
        packed_ptr=packed.contiguous(),
        codes_ptr=codes,
        N=num_values,
        n_bytes=packed.numel(),
        bits=bits,
        BLOCK=BLOCK,
    )
    return codes
import torch
from torch import nn

from auto_round.experimental.kv_cache import (
    QuantizedKVParameterCache,
    build_turboquant_runtime_cache,
    kvcache_quant_context,
    normalize_kv_cache_backend_config,
)
from auto_round.experimental.turboquant import (
    QJLResidualConfig,
    build_turboquant_state,
    turboquant_pack,
    turboquant_qdq,
    turboquant_unpack,
)


class TinySelfAttention(nn.Module):
    # Minimal attention stand-in: records the cache object / use_cache flag it
    # was called with so tests can assert on the calibration plumbing.
    def __init__(self, hidden_size=8):
        super().__init__()
        self.layer_idx = 0
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.last_past_key_value = None
        self.last_use_cache = None

    def forward(self, hidden_states, past_key_value=None, use_cache=True):
        self.last_past_key_value = past_key_value
        self.last_use_cache = use_cache

        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # Mirrors the HF attention contract: push K/V through whatever cache
        # object was injected, if it supports update().
        if hasattr(past_key_value, "update"):
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        return key_states + value_states


class TinyAttentionModel(nn.Module):
    # Wraps TinySelfAttention under the attribute name the calibration hooks
    # look for ("self_attention").
    def __init__(self, hidden_size=8):
        super().__init__()
        self.self_attention = TinySelfAttention(hidden_size=hidden_size)

    def forward(self, hidden_states, past_key_value=None, use_cache=True):
        return self.self_attention(hidden_states, past_key_value=past_key_value, use_cache=use_cache)


def test_normalize_kv_cache_backend_config_turboquant():
    # "turboquant:<bits>" string form parses the bit width.
    config = normalize_kv_cache_backend_config("turboquant:3")
    assert config.backend == "turboquant"
    assert config.bits == 3


def test_normalize_kv_cache_backend_config_fp8():
    config = normalize_kv_cache_backend_config("fp8")
    assert config.backend == "fp8"
    assert config.dtype == torch.float8_e4m3fn


def test_turboquant_qdq_shape_and_dtype():
    tensor = torch.randn(2, 4, 3, 16, dtype=torch.float32)
    state = build_turboquant_state(head_dim=16, bits=4, seed=7, device=tensor.device)
    reconstructed, avg_norm = turboquant_qdq(tensor, state)

    assert reconstructed.shape == tensor.shape
    assert reconstructed.dtype == tensor.dtype
    assert avg_norm.shape == (1,)
    assert torch.isfinite(reconstructed).all()
    assert torch.isfinite(avg_norm).all()
    # Lossy quantization: the round trip must actually change the values.
    assert not torch.equal(reconstructed, tensor)


def test_kvcache_quant_context_with_tiny_attention_module():
    model = TinyAttentionModel(hidden_size=8)
    hidden_states = torch.randn(2, 3, 8, dtype=torch.float32)

    with kvcache_quant_context(model, static_kv_dtype="turboquant:4"):
        attention = model.self_attention
        # The injected cache should override the caller-supplied sentinel.
        output = model(hidden_states, past_key_value="sentinel", use_cache=True)

        assert output.shape == hidden_states.shape
        assert isinstance(attention.last_past_key_value, QuantizedKVParameterCache)
        assert attention.last_use_cache is False
        assert hasattr(attention, "k_scale")
        assert hasattr(attention, "v_scale")
        assert torch.isfinite(attention.k_scale).all()
        assert torch.isfinite(attention.v_scale).all()
        assert not torch.equal(attention.k_scale, torch.zeros_like(attention.k_scale))
        assert not torch.equal(attention.v_scale, torch.zeros_like(attention.v_scale))
        assert hasattr(attention, "_kv_cache_hook_handles")
        assert hasattr(attention, "kv_cache")

    # Exiting the context must remove all calibration state from the module.
    attention = model.self_attention
    assert not hasattr(attention, "_kv_cache_hook_handles")
    assert not hasattr(attention, "kv_cache")


def test_turboquant_pack_reduces_storage_bytes():
    tensor = torch.randn(2, 4, 32, 16, dtype=torch.float32)
    state = build_turboquant_state(head_dim=16, bits=4, seed=7, device=tensor.device)
    packed = turboquant_pack(tensor, state)

    assert packed.memory_bytes() < tensor.numel() * tensor.element_size()
    reconstructed = turboquant_unpack(packed, state, dtype=tensor.dtype)
    assert reconstructed.shape == tensor.shape
    assert torch.isfinite(reconstructed).all()


def test_turboquant_qjl_residual_gives_unbiased_inner_product():
    """1-bit QJL makes inner product estimation unbiased (paper Theorem 2).

    MSE may increase slightly, but the mean bias of <x_hat, y> vs <x, y> should be ~0.
    """
    head_dim = 64
    n_vectors = 200
    torch.manual_seed(42)
    x = torch.randn(n_vectors, 1, 1, head_dim, dtype=torch.float32)
    x = x / x.norm(dim=-1, keepdim=True)  # unit vectors
    y = torch.randn(n_vectors, 1, 1, head_dim, dtype=torch.float32)

    qjl_config = QJLResidualConfig(enabled=True, seed=123)
    state = build_turboquant_state(head_dim=head_dim, bits=2, seed=11, device=x.device, qjl_config=qjl_config)

    # Without QJL
    packed_plain = turboquant_pack(x, state)
    x_hat_plain = turboquant_unpack(packed_plain, state, dtype=x.dtype)

    # With QJL
    packed_qjl = turboquant_pack(x, state, residual_config=qjl_config)
    x_hat_qjl = turboquant_unpack(packed_qjl, state, dtype=x.dtype, residual_config=qjl_config)

    ip_true = (x * y).sum(dim=-1)
    ip_plain = (x_hat_plain * y).sum(dim=-1)
    ip_qjl = (x_hat_qjl * y).sum(dim=-1)

    bias_plain = (ip_plain - ip_true).mean().abs().item()
    bias_qjl = (ip_qjl - ip_true).mean().abs().item()

    # QJL should have lower inner-product bias than plain quantization
    # NOTE(review): bias_plain is computed but never asserted against bias_qjl;
    # the comment above promises a comparison the test does not make — confirm
    # whether `assert bias_qjl < bias_plain` was intended.
    assert bias_qjl < 0.1, f"QJL bias too large: {bias_qjl}"
    # Also verify both produce finite results
    assert torch.isfinite(x_hat_plain).all()
    assert torch.isfinite(x_hat_qjl).all()


def test_runtime_turboquant_packed_cache_has_benefit_over_raw_kv():
    cache = build_turboquant_runtime_cache(bits=4, residual_length=4, seed=17, qjl_residual=True)

    key_states_1 = torch.randn(1, 2, 3, 16, dtype=torch.float32)
    value_states_1 = torch.randn(1, 2, 3, 16, dtype=torch.float32)
    key_states_2 = torch.randn(1, 2, 5, 16, dtype=torch.float32)
    value_states_2 = torch.randn(1, 2, 5, 16, dtype=torch.float32)

    combined_keys = torch.cat([key_states_1, key_states_2], dim=-2)
    combined_values = torch.cat([value_states_1, value_states_2], dim=-2)

    returned_keys_1, returned_values_1 = cache.update(key_states_1, value_states_1, layer_idx=0)
    returned_keys_2, returned_values_2 = cache.update(key_states_2, value_states_2, layer_idx=0)

    # First update returns only its own tokens; second returns the full history.
    assert returned_keys_1.shape == key_states_1.shape
    assert returned_values_1.shape == value_states_1.shape
    assert returned_keys_2.shape == combined_keys.shape
    assert returned_values_2.shape == combined_values.shape
    assert cache.get_seq_length(0) == combined_keys.shape[-2]
    assert cache.packed_memory_bytes() > 0
    assert cache.total_memory_bytes() < cache.raw_memory_bytes()
    assert cache.compression_ratio() > 1.0

    reconstruction_error = torch.mean((returned_keys_2 - combined_keys) ** 2)
    assert torch.isfinite(reconstruction_error)


# ---------------------------------------------------------------------------
# NOTE(review): the definitions below belong to test/test_cpu/utils/test_utils.py
# in the same change set; only the fully visible new additions are shown here.
# ---------------------------------------------------------------------------

from auto_round.utils.model import mv_module_from_gpu


class _FakeAcceleratorParameter:
    # Duck-typed stand-in for a CUDA parameter whose .to() returns a plain
    # tensor (not a Parameter) — reproduces the regression mv_module_from_gpu
    # must guard against.
    def __init__(self):
        self.device = torch.device("cuda")
        self.requires_grad = False

    def to(self, _device):
        return torch.ones(1, dtype=torch.float32)


class _FakeMetaBuffer:
    # Meta-device buffer: mv_module_from_gpu must leave it untouched.
    def __init__(self):
        self.device = torch.device("meta")


class TestMetaMoveHelpers:
    def test_mv_module_from_gpu_preserves_parameter_type(self):
        module = torch.nn.Module()
        module._parameters["fake_weight"] = _FakeAcceleratorParameter()
        module._buffers["meta_marker"] = _FakeMetaBuffer()

        mv_module_from_gpu(module)

        # The moved value must be re-wrapped as nn.Parameter, not left a Tensor.
        assert isinstance(module._parameters["fake_weight"], torch.nn.Parameter)