Skip to content

Commit 316114a

Browse files
committed
refactor: always disable quantization during calibration, re-enable for propagation
- pipeline.py: remove disable_qac / DISABLE_QAC_MODIFIERS conditional logic; quantization is now unconditionally disabled during the calibration pass and re-enabled during the propagation pass so downstream subgraphs receive quantized inputs
- quantization/base.py: remove erroneous disable_quantization call from on_start; control now lives entirely in the pipeline layer
- observers/base.py: move update_offload_parameter to a top-level import
- calibration.py: fix hook docstrings to accurately describe stats-only behavior

Signed-off-by: dqzhengAP <dqzheng1996@gmail.com>
1 parent 26c29ad commit 316114a

File tree

4 files changed

+25
-42
lines changed

4 files changed

+25
-42
lines changed

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,6 @@ def calibrate_activations(
195195
# min/max stats but do NOT write scale/zero_point yet.
196196
# Qparams are written once at epoch end via flush_activation_qparams.
197197
if stats_only:
198-
# Deferred mode: accumulate global min/max into the observer's
199-
# _deferred_min / _deferred_max. Works for ALL observer types,
200-
# including MemorylessMinMaxObserver which has no past_min_vals.
201-
# Qparams are written once at epoch end via flush_activation_qparams.
202198
observer = getattr(module, f"{base_name}_observer", None)
203199
if observer is not None:
204200
observer.update_deferred_stats(value)
@@ -215,21 +211,20 @@ def calibrate_activations(
215211

216212
def calibrate_input_hook(module: Module, args: Any):
217213
"""
218-
Hook to calibrate input activations.
219-
Accumulates running min/max statistics in the observer without computing
220-
scale/zero_point. Qparams are computed once at epoch end via
221-
flush_activation_qparams (deferred mode).
214+
Hook to accumulate input activation statistics (min/max) in the observer.
215+
Scale and zero_point are not written here; they are computed once per subgraph
216+
at epoch end via flush_activation_qparams.
222217
"""
223218
args = args[0] if isinstance(args, tuple) else args
224219
calibrate_activations(module, value=args, base_name="input", stats_only=True)
225220

226221

227222
def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
228223
"""
229-
Hook to calibrate output activations.
230-
Accumulates running min/max statistics only (deferred qparam mode).
231-
Qparams are computed at epoch end; forward_quantize is skipped during
232-
calibration batches since quantization is disabled in the sequential pipeline.
224+
Hook to accumulate output activation statistics (min/max) in the observer.
225+
Scale and zero_point are not written here; they are computed once per subgraph
226+
at epoch end via flush_activation_qparams.
227+
Note: forward_quantize is intentionally absent — hooks only collect statistics.
233228
"""
234229
calibrate_activations(
235230
module,

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
6767
def on_start(self, state: State, event: Event, **kwargs):
6868
"""
6969
Begin calibrating activations and weights. Calibrate weights only once on start.
70-
Quantization is kept DISABLED during calibration batches so that forward passes
71-
run in fp32. Activation qparams are computed once per subgraph at
72-
SEQUENTIAL_EPOCH_END via flush_activation_qparams (deferred mode).
70+
Activation qparams are computed once per subgraph at SEQUENTIAL_EPOCH_END via
71+
flush_activation_qparams, rather than per batch.
7372
"""
7473
self.started_ = True
7574
QuantizationMixin.start_calibration(self, state.model)
@@ -94,21 +93,14 @@ def on_start(self, state: State, event: Event, **kwargs):
9493
for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
9594
update_weight_zp_scale(module)
9695

97-
# Disable quantization during calibration batches so that fp32 activations
98-
# flow through the model unmodified while hooks accumulate running stats.
99-
# Re-enable once after epoch end when qparams have been flushed.
100-
from compressed_tensors.quantization import disable_quantization
101-
102-
state.model.apply(disable_quantization)
103-
10496
def on_event(self, state: State, event: Event, **kwargs):
10597
if event.type_ == EventType.CALIBRATION_EPOCH_START:
10698
if not self.started_:
10799
self.on_start(state, None)
108100

109101
if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
110-
# Deferred qparam flush: compute scale/zero_point from accumulated
111-
# running statistics, then free those stats to reduce memory.
102+
# Compute scale/zero_point once from accumulated running statistics,
103+
# then free those stats to reduce memory.
112104
for _, module in match_named_modules(
113105
state.model, self.resolved_targets, self.ignore
114106
):

src/llmcompressor/observers/base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
88
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
99
from compressed_tensors.registry.registry import RegistryMixin
10-
from compressed_tensors.utils import align_module_device
11-
10+
from compressed_tensors.utils import align_module_device, update_offload_parameter
1211
from llmcompressor.observers.helpers import flatten_for_calibration
1312

1413
__all__ = ["Observer", "MinMaxTuple", "ScaleZpTuple", "calibrate_module_from_observer"]
@@ -213,8 +212,6 @@ def calibrate_module_from_observer(
213212
:param base_name: one of "input", "output", "q", "k", "v"
214213
:return: True if qparams were updated, False if observer had no accumulated stats
215214
"""
216-
from compressed_tensors.utils import align_module_device, update_offload_parameter
217-
218215
observer: Optional[Observer] = getattr(module, f"{base_name}_observer", None)
219216
if observer is None:
220217
return False

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Iterator
44

55
import torch
6+
from compressed_tensors.quantization import disable_quantization, enable_quantization
67
from compressed_tensors.utils import disable_offloading
78
from torch.utils.data.dataloader import DataLoader
89
from tqdm import tqdm
@@ -19,7 +20,6 @@
1920
)
2021
from llmcompressor.utils.dev import get_main_device
2122
from llmcompressor.utils.helpers import (
22-
DISABLE_QAC_MODIFIERS,
2323
DisableQuantization,
2424
calibration_forward_context,
2525
)
@@ -111,18 +111,13 @@ def __call__(
111111

112112
LifecycleCallbacks.calibration_epoch_start()
113113

114-
# TODO: remove this to enable quantization aware calibration
115-
# for GPTQ, AWQ and AutoRound.
116-
disable_qac = any(
117-
type(mod).__name__ in DISABLE_QAC_MODIFIERS
118-
for mod in session.lifecycle.recipe.modifiers
119-
)
120-
121114
with contextlib.ExitStack() as stack:
122115
stack.enter_context(calibration_forward_context(model))
123-
# Optionally disable quantization
124-
if not dataset_args.quantization_aware_calibration or disable_qac:
125-
stack.enter_context(DisableQuantization(model))
116+
# Always disable quantization during calibration so that observer hooks
117+
# accumulate statistics from unquantized activations. Quantization is
118+
# re-enabled during the propagation pass so that downstream subgraphs
119+
# receive realistic (quantized) inputs.
120+
stack.enter_context(DisableQuantization(model))
126121

127122
# prepare intermediates cache
128123
activations = IntermediatesCache.from_dataloader(
@@ -148,7 +143,7 @@ def __call__(
148143
num_batches = len(dataloader)
149144
use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
150145
with disable_offloading():
151-
# do a preliminary pass to trigger modifier hooks
146+
# calibration pass: hooks accumulate activation statistics
152147
for batch_idx, inputs in _get_batches(
153148
activations,
154149
num_batches,
@@ -159,10 +154,13 @@ def __call__(
159154
session.state.current_batch_idx = batch_idx
160155
subgraph.forward(model, **inputs)
161156

157+
# flush accumulated stats -> write scale/zero_point once per subgraph
162158
LifecycleCallbacks.sequential_epoch_end(subgraph)
163159

164-
# this pass does not trigger modifier hooks
165-
# and is only used for capturing outputs of newly compressed modules
160+
# propagation pass: modifier hooks are disabled but quantization is
161+
# re-enabled so that compressed module outputs are quantized.
162+
# This ensures downstream subgraphs receive realistic inputs.
163+
model.apply(enable_quantization)
166164
with HooksMixin.disable_hooks():
167165
for batch_idx, inputs in _get_batches(
168166
activations,
@@ -175,6 +173,7 @@ def __call__(
175173
if subgraph_index < num_subgraphs - 1:
176174
activations.update(batch_idx, output)
177175
activations.delete(batch_idx, subgraph.consumed_names)
176+
model.apply(disable_quantization)
178177

179178
# redundant, finish any remaining compression
180179
LifecycleCallbacks.calibration_epoch_end()

0 commit comments

Comments (0)