Skip to content

Commit f19fdd0

Browse files
authored
Use KV cache constant names provided by compressed tensors (#1200)
## Purpose ## * Harden code related to the kv_cache parameter names --------- Signed-off-by: Kyle Sayers <[email protected]>
1 parent f861145 commit f19fdd0

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from typing import Any, Dict, Optional, Tuple
22

33
import torch
4-
from compressed_tensors.quantization import QuantizationStatus, is_attention_module
4+
from compressed_tensors.quantization import (
5+
KVCacheScaleType,
6+
QuantizationStatus,
7+
is_attention_module,
8+
)
59
from compressed_tensors.quantization.lifecycle.forward import forward_quantize
610
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
711
from compressed_tensors.utils.offload import is_module_offloaded, update_parameter_data
@@ -194,8 +198,10 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Tensor):
194198
Hook to update k_scale and v_scale parameters when running kv_cache quantization.
195199
"""
196200
kv_cache = getattr(module, "kv_cache")
197-
update_parameter_data(module, kv_cache.k_scales[module.layer_idx], "k_scale")
198-
update_parameter_data(module, kv_cache.v_scales[module.layer_idx], "v_scale")
201+
k_scale = kv_cache.k_scales[module.layer_idx]
202+
v_scale = kv_cache.v_scales[module.layer_idx]
203+
update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
204+
update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
199205

200206

201207
def set_unset_kv_cache(module: Module):

0 commit comments

Comments (0)