Skip to content

Commit 6a415e6

Browse files
committed
refactor: extract shared helper, merge compile test into test_mse
- Extract _compute_candidate_error() shared by both code paths - Remove patch_attr from both paths (use pre-created token_args) - Preserve early stopping in non-compiled path - Move torch.compile test into test_mse.py per review feedback - Remove separate test_observer_compile.py Signed-off-by: Jaewoo Kim <pewpewplay315@gmail.com>
1 parent 0dee4e0 commit 6a415e6

File tree

3 files changed

+106
-90
lines changed

3 files changed

+106
-90
lines changed

src/llmcompressor/observers/mse.py

Lines changed: 84 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
44
from compressed_tensors.quantization import (
@@ -7,7 +7,6 @@
77
)
88
from compressed_tensors.quantization.lifecycle import fake_quantize
99
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
10-
from compressed_tensors.utils import patch_attr
1110

1211
from llmcompressor.observers.base import MinMaxTuple, Observer
1312
from llmcompressor.observers.moving_base import MovingAverageObserverBase
@@ -56,11 +55,11 @@ def __init__(self, *args, **kwargs):
5655
self.norm = observer_kwargs.get("norm", 2.4)
5756
self.enable_torch_compile = observer_kwargs.get("enable_torch_compile", False)
5857

59-
# Pre-create token_args for compiled path to avoid patch_attr
60-
if self.enable_torch_compile:
61-
self._token_args = self.args.model_copy(
62-
update={"strategy": QuantizationStrategy.TOKEN}
63-
)
58+
# Pre-create token_args to avoid patch_attr context manager
59+
# which causes torch.compile graph breaks
60+
self._token_args = self.args.model_copy(
61+
update={"strategy": QuantizationStrategy.TOKEN}
62+
)
6463

6564
def _call_grid_search(
6665
self,
@@ -82,6 +81,7 @@ def _call_grid_search(
8281
return _grid_search_mse(
8382
observed,
8483
self.args,
84+
self._token_args,
8585
self.maxshrink,
8686
self.patience,
8787
self.grid,
@@ -141,10 +141,11 @@ def __init__(self, *args, **kwargs):
141141
self.norm = observer_kwargs.get("norm", 2.4)
142142
self.enable_torch_compile = observer_kwargs.get("enable_torch_compile", False)
143143

144-
if self.enable_torch_compile:
145-
self._token_args = self.args.model_copy(
146-
update={"strategy": QuantizationStrategy.TOKEN}
147-
)
144+
# Pre-create token_args to avoid patch_attr context manager
145+
# which causes torch.compile graph breaks
146+
self._token_args = self.args.model_copy(
147+
update={"strategy": QuantizationStrategy.TOKEN}
148+
)
148149

149150
def _call_grid_search(
150151
self,
@@ -166,6 +167,7 @@ def _call_grid_search(
166167
return _grid_search_mse(
167168
observed,
168169
self.args,
170+
self._token_args,
169171
self.maxshrink,
170172
self.patience,
171173
self.grid,
@@ -184,9 +186,68 @@ def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
184186
return self._call_grid_search(observed, None, True)
185187

186188

189+
def _compute_candidate_error(
190+
observed: torch.Tensor,
191+
args: QuantizationArgs,
192+
token_args: QuantizationArgs,
193+
min_val: torch.Tensor,
194+
max_val: torch.Tensor,
195+
p: float,
196+
norm: float,
197+
global_scale: Optional[torch.Tensor],
198+
optimize_global_scale: bool,
199+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
200+
"""
201+
Compute the quantization error for a single shrink factor.
202+
203+
Shared helper used by both the default and torch.compile-compatible
204+
grid search paths to avoid code duplication.
205+
206+
:param observed: value of shape (num_observations, *qparams_shape, group_size)
207+
:param args: quantization args used for computing qparams
208+
:param token_args: quantization args with strategy set to TOKEN, pre-created
209+
to avoid patch_attr context manager which causes torch.compile graph breaks
210+
:param min_val: per-channel minimum values
211+
:param max_val: per-channel maximum values
212+
:param p: shrink factor (1 - i/grid)
213+
:param norm: exponent used when computing the error
214+
:param global_scale: precomputed global scale to use for quantization
215+
:param optimize_global_scale: If True, recompute global_scale from candidates
216+
:return: (error, shrinked_min_val, shrinked_max_val)
217+
"""
218+
shrinked_min_val = p * min_val
219+
shrinked_max_val = p * max_val
220+
221+
if optimize_global_scale:
222+
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
223+
224+
candidate_scales, candidate_zero_points = calculate_qparams(
225+
min_vals=shrinked_min_val,
226+
max_vals=shrinked_max_val,
227+
quantization_args=args,
228+
global_scale=global_scale,
229+
)
230+
231+
# Use pre-created token_args instead of patch_attr context manager
232+
# to maintain torch.compile compatibility
233+
q = fake_quantize(
234+
observed,
235+
candidate_scales.unsqueeze(-1),
236+
candidate_zero_points.unsqueeze(-1),
237+
token_args,
238+
global_scale=global_scale,
239+
).to(observed.dtype)
240+
# Note that due to forward quantization implementation, token quant,
241+
# unlike tensor_group, requires extra dtype cast
242+
243+
err = torch.sum((q - observed).abs().pow(norm), dim=(0, -1))
244+
return err, shrinked_min_val, shrinked_max_val
245+
246+
187247
def _grid_search_mse(
188248
observed: torch.Tensor,
189249
args: QuantizationArgs,
250+
token_args: QuantizationArgs,
190251
maxshrink: float,
191252
patience: float,
192253
grid: float,
@@ -202,8 +263,12 @@ def _grid_search_mse(
202263
observed tensor and evaluates the quantization error at each candidate
203264
range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``.
204265
266+
Uses early stopping to skip unnecessary search steps when no improvement
267+
is found for ``patience`` consecutive steps.
268+
205269
:param observed: value of shape (num_observations, *qparams_shape, group_size)
206270
:param args: quantization args used for computing qparams and fake quant
271+
:param token_args: quantization args with strategy set to TOKEN
207272
:param maxshrink: maximum shrink amount (in "grid steps"). The number of
208273
search steps is int(maxshrink * grid)
209274
:param patience: number of consecutive search steps without improvement before
@@ -222,43 +287,16 @@ def _grid_search_mse(
222287
best_min_val = min_val.clone()
223288
best_max_val = max_val.clone()
224289

225-
# Early stopping params
226290
no_improve_count = 0
227291

228292
# @ksayers @HGCharles: investigate searching over separate shrinking factors
229293
for i in range(int(maxshrink * grid)):
230294
p = 1 - i / grid
231-
shrinked_min_val = p * min_val
232-
shrinked_max_val = p * max_val
233-
234-
if optimize_global_scale:
235-
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
236-
237-
candidate_scales, candidate_zero_points = calculate_qparams(
238-
min_vals=shrinked_min_val,
239-
max_vals=shrinked_max_val,
240-
quantization_args=args,
241-
global_scale=global_scale,
295+
err, shrinked_min_val, shrinked_max_val = _compute_candidate_error(
296+
observed, args, token_args, min_val, max_val, p, norm,
297+
global_scale, optimize_global_scale,
242298
)
243299

244-
# Note that observed.shape = (num_observations, *qparams_shape, group_size).
245-
# For the purposes of fake quantization, this is equivalent to token quant
246-
with patch_attr(args, "strategy", QuantizationStrategy.TOKEN):
247-
q = fake_quantize(
248-
observed,
249-
candidate_scales.unsqueeze(-1),
250-
candidate_zero_points.unsqueeze(-1),
251-
args,
252-
global_scale=global_scale,
253-
).to(observed.dtype)
254-
# Note that due to forward quantization implementation, token quant,
255-
# unlike tensor_group, requires extra dtype cast
256-
257-
q -= observed
258-
q.abs_()
259-
q.pow_(norm)
260-
err = torch.sum(q, dim=(0, -1))
261-
262300
tmp = err < best_error
263301
if torch.any(tmp):
264302
best_error[tmp] = err[tmp]
@@ -287,16 +325,13 @@ def _grid_search_mse_compiled(
287325
torch.compile-compatible version of _grid_search_mse.
288326
289327
Differences from the default path:
290-
- Uses pre-created token_args instead of patch_attr context manager
291-
(patch_attr causes graph breaks)
292328
- Uses torch.where instead of data-dependent control flow
293-
(early stopping causes graph breaks)
329+
(early stopping and torch.any cause graph breaks)
294330
- No early stopping: runs all search steps for deterministic compilation
295331
296332
:param observed: value of shape (num_observations, *qparams_shape, group_size)
297333
:param args: quantization args used for computing qparams
298-
:param token_args: quantization args with strategy set to TOKEN, pre-created
299-
to avoid patch_attr context manager which causes torch.compile graph breaks
334+
:param token_args: quantization args with strategy set to TOKEN
300335
:param maxshrink: maximum shrink amount. The number of search steps is
301336
int(maxshrink * grid)
302337
:param grid: resolution of the shrink search. Larger values give finer granularity
@@ -316,31 +351,11 @@ def _grid_search_mse_compiled(
316351
num_steps = int(maxshrink * grid)
317352
for i in range(num_steps):
318353
p = 1 - i / grid
319-
shrinked_min_val = p * min_val
320-
shrinked_max_val = p * max_val
321-
322-
if optimize_global_scale:
323-
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
324-
325-
candidate_scales, candidate_zero_points = calculate_qparams(
326-
min_vals=shrinked_min_val,
327-
max_vals=shrinked_max_val,
328-
quantization_args=args,
329-
global_scale=global_scale,
354+
err, shrinked_min_val, shrinked_max_val = _compute_candidate_error(
355+
observed, args, token_args, min_val, max_val, p, norm,
356+
global_scale, optimize_global_scale,
330357
)
331358

332-
# Use pre-created token_args instead of patch_attr context manager
333-
# to maintain torch.compile compatibility
334-
q = fake_quantize(
335-
observed,
336-
candidate_scales.unsqueeze(-1),
337-
candidate_zero_points.unsqueeze(-1),
338-
token_args,
339-
global_scale=global_scale,
340-
).to(observed.dtype)
341-
342-
err = torch.sum((q - observed).abs().pow(norm), dim=(0, -1))
343-
344359
# Use torch.where instead of boolean indexing + torch.any for
345360
# torch.compile compatibility (avoids data-dependent control flow)
346361
improved = err < best_error

tests/llmcompressor/observers/test_mse.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,25 @@ def test_mse_fp4():
8787
module.weight, scale, zero_point, weights, global_scale=global_scale
8888
)
8989
assert torch.nn.functional.mse_loss(qdq_tensor, module.weight) <= 0.0015 # 0.0013
90+
91+
92+
def test_mse_observer_torch_compile():
93+
"""Test that MSE observer works with torch.compile when enable_torch_compile=True"""
94+
args = QuantizationArgs(
95+
num_bits=8,
96+
type="int",
97+
symmetric=True,
98+
strategy="tensor",
99+
observer="memoryless_mse",
100+
observer_kwargs={"enable_torch_compile": True},
101+
)
102+
observer = Observer.load_from_registry("memoryless_mse", base_name="weight", args=args)
103+
104+
x = torch.randn(1, 1, 128)
105+
eager_scale, eager_zp = observer(x)
106+
107+
compiled = torch.compile(observer, fullgraph=True, backend="eager")
108+
compiled_scale, compiled_zp = compiled(x)
109+
110+
torch.testing.assert_close(eager_scale, compiled_scale)
111+
torch.testing.assert_close(eager_zp, compiled_zp)

tests/llmcompressor/observers/test_observer_compile.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

0 commit comments

Comments
 (0)