Skip to content

Commit bf63a4c

Browse files
committed
perf: make MSE observer compatible with torch.compile
Compile the inner _compute_candidate_error via torch.compile(dynamic=True). Early stopping is preserved in the outer loop. A compile flag is added as a oneshot argument. Requires: vllm-project/compressed-tensors#627. Related: pytorch/pytorch#177131. Signed-off-by: Jaewoo Kim <pewpewplay315@gmail.com>
1 parent 370c04c commit bf63a4c

File tree

4 files changed

+198
-67
lines changed

4 files changed

+198
-67
lines changed

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from llmcompressor.core.session_functions import active_session
2323
from llmcompressor.datasets import get_calibration_dataloader
2424
from llmcompressor.entrypoints.utils import post_process, pre_process
25+
from llmcompressor.observers.compile_config import set_observer_compile
2526
from llmcompressor.modeling.moe_context import moe_calibration_context
2627
from llmcompressor.pipelines import CalibrationPipeline
2728

@@ -300,6 +301,7 @@ def oneshot(
300301
sequential_offload_device: str = "cpu",
301302
quantization_aware_calibration: bool = True,
302303
sequential_prefetch: bool = False,
304+
enable_observer_compile: bool = False,
303305
# Miscellaneous arguments
304306
output_dir: str | None = None,
305307
log_dir: str | None = None,
@@ -406,9 +408,10 @@ def oneshot(
406408

407409
# pass all args directly into Oneshot
408410
local_args = {
409-
k: v for k, v in locals().items() if k not in ("local_args", "kwargs")
411+
k: v for k, v in locals().items() if k not in ("local_args", "kwargs", "enable_observer_compile")
410412
}
411413
one_shot = Oneshot(**local_args, **kwargs)
414+
set_observer_compile(enable_observer_compile)
412415
one_shot()
413416

414417
return one_shot.model
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
Global configuration for observer torch.compile support.
3+
4+
The compile flag is set by the oneshot entrypoint and read by observer
5+
instances at call time. This avoids threading the flag through recipe
6+
and modifier layers.
7+
"""
8+
9+
_enable_observer_compile: bool = False
10+
11+
12+
def set_observer_compile(enabled: bool) -> None:
13+
global _enable_observer_compile
14+
_enable_observer_compile = enabled
15+
16+
17+
def get_observer_compile() -> bool:
18+
return _enable_observer_compile

src/llmcompressor/observers/mse.py

Lines changed: 143 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1-
from typing import Optional
1+
from typing import Optional, Tuple
22

33
import torch
4+
import torch._dynamo.config
45
from compressed_tensors.quantization import (
56
QuantizationArgs,
67
QuantizationStrategy,
78
)
89
from compressed_tensors.quantization.lifecycle import fake_quantize
910
from compressed_tensors.quantization.utils import calculate_qparams, generate_gparam
10-
from compressed_tensors.utils import patch_attr
1111

1212
from llmcompressor.observers.base import MinMaxTuple, Observer
13+
from llmcompressor.observers.compile_config import get_observer_compile
1314
from llmcompressor.observers.moving_base import MovingAverageObserverBase
1415

1516
__all__ = ["MovingAverageMSEObserver"]
1617

18+
# Allow torch.compile to handle scalar conversions inside
19+
# compressed_tensors' calculate_qparams (float(bit_range)).
20+
# Same approach as GPTQ compile path (commit a4f9ba2e).
21+
torch._dynamo.config.capture_scalar_outputs = True
22+
1723

1824
@Observer.register("memoryless_mse")
1925
class MemorylessMSEObserver(Observer):
@@ -32,7 +38,7 @@ class MemorylessMSEObserver(Observer):
3238
:param module: optional module with attached quantization parameters. This argument
3339
is required to utilize existing qparams such as global_scale or g_idx
3440
:param **observer_kwargs: keyword arguments for observer initialization\n
35-
maxshrink: maximum shrink amount (in grid steps). The number of
41+
maxshrink: maximum shrink amount (in "grid steps"). The number of
3642
search steps is int(maxshrink * grid)\n
3743
patience: number of consecutive search steps without improvement before
3844
early stopping\n
@@ -53,32 +59,39 @@ def __init__(self, *args, **kwargs):
5359
self.grid = observer_kwargs.get("grid", 100.0)
5460
self.norm = observer_kwargs.get("norm", 2.4)
5561

56-
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
57-
# min[min_vals, max_vals](mse_quant_error)
58-
global_scale = self._get_module_param("global_scale")
62+
# Pre-create token_args to avoid patch_attr context manager
63+
# which causes torch.compile graph breaks
64+
self._token_args = self.args.model_copy(
65+
update={"strategy": QuantizationStrategy.TOKEN}
66+
)
67+
68+
def _call_grid_search(
69+
self,
70+
observed: torch.Tensor,
71+
global_scale: Optional[torch.Tensor],
72+
optimize_global_scale: bool,
73+
) -> MinMaxTuple:
5974
return _grid_search_mse(
6075
observed,
6176
self.args,
77+
self._token_args,
6278
self.maxshrink,
6379
self.patience,
6480
self.grid,
6581
self.norm,
6682
global_scale=global_scale,
67-
optimize_global_scale=False,
83+
optimize_global_scale=optimize_global_scale,
84+
enable_compile=get_observer_compile(),
6885
)
6986

87+
def get_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
88+
# min[min_vals, max_vals](mse_quant_error)
89+
global_scale = self._get_module_param("global_scale")
90+
return self._call_grid_search(observed, global_scale, False)
91+
7092
def get_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
7193
# min[min_vals, max_vals, global_scale](mse_quant_error)
72-
return _grid_search_mse(
73-
observed,
74-
self.args,
75-
self.maxshrink,
76-
self.patience,
77-
self.grid,
78-
self.norm,
79-
global_scale=None,
80-
optimize_global_scale=True,
81-
)
94+
return self._call_grid_search(observed, None, True)
8295

8396

8497
@Observer.register("mse")
@@ -98,7 +111,7 @@ class MovingAverageMSEObserver(MovingAverageObserverBase):
98111
:param module: optional module with attached quantization parameters. This argument
99112
is required to utilize existing qparams such as global_scale or g_idx
100113
:param **observer_kwargs: keyword arguments for observer initialization\n
101-
maxshrink: maximum shrink amount (in grid steps). The number of
114+
maxshrink: maximum shrink amount (in "grid steps"). The number of
102115
search steps is int(maxshrink * grid)\n
103116
patience: number of consecutive search steps without improvement before
104117
early stopping\n
@@ -119,55 +132,134 @@ def __init__(self, *args, **kwargs):
119132
self.grid = observer_kwargs.get("grid", 100.0)
120133
self.norm = observer_kwargs.get("norm", 2.4)
121134

122-
def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
123-
# min[min_vals, max_vals](mse_quant_error)
124-
global_scale = self._get_module_param("global_scale")
135+
# Pre-create token_args to avoid patch_attr context manager
136+
# which causes torch.compile graph breaks
137+
self._token_args = self.args.model_copy(
138+
update={"strategy": QuantizationStrategy.TOKEN}
139+
)
140+
141+
def _call_grid_search(
142+
self,
143+
observed: torch.Tensor,
144+
global_scale: Optional[torch.Tensor],
145+
optimize_global_scale: bool,
146+
) -> MinMaxTuple:
125147
return _grid_search_mse(
126148
observed,
127149
self.args,
150+
self._token_args,
128151
self.maxshrink,
129152
self.patience,
130153
self.grid,
131154
self.norm,
132155
global_scale=global_scale,
133-
optimize_global_scale=False,
156+
optimize_global_scale=optimize_global_scale,
157+
enable_compile=get_observer_compile(),
134158
)
135159

160+
def get_current_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
161+
# min[min_vals, max_vals](mse_quant_error)
162+
global_scale = self._get_module_param("global_scale")
163+
return self._call_grid_search(observed, global_scale, False)
164+
136165
def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
137166
# min[min_vals, max_vals, global_scale](mse_quant_error)
138-
return _grid_search_mse(
139-
observed,
140-
self.args,
141-
self.maxshrink,
142-
self.patience,
143-
self.grid,
144-
self.norm,
145-
global_scale=None,
146-
optimize_global_scale=True,
147-
)
167+
return self._call_grid_search(observed, None, True)
168+
169+
170+
def _compute_candidate_error(
171+
observed: torch.Tensor,
172+
args: QuantizationArgs,
173+
token_args: QuantizationArgs,
174+
min_val: torch.Tensor,
175+
max_val: torch.Tensor,
176+
p: float,
177+
norm: float,
178+
global_scale: Optional[torch.Tensor],
179+
optimize_global_scale: bool,
180+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
181+
"""
182+
Compute the quantization error for a single shrink factor.
183+
184+
Shared helper used by the grid search. When enable_compile is set
185+
via oneshot, this function is called through its compiled wrapper
186+
for accelerated execution.
187+
188+
:param observed: value of shape (num_observations, *qparams_shape, group_size)
189+
:param args: quantization args used for computing qparams
190+
:param token_args: quantization args with strategy set to TOKEN, pre-created
191+
to avoid patch_attr context manager which causes torch.compile graph breaks
192+
:param min_val: per-channel minimum values
193+
:param max_val: per-channel maximum values
194+
:param p: shrink factor (1 - i/grid)
195+
:param norm: exponent used when computing the error
196+
:param global_scale: precomputed global scale to use for quantization
197+
:param optimize_global_scale: If True, recompute global_scale from candidates
198+
:return: (error, shrinked_min_val, shrinked_max_val)
199+
"""
200+
shrinked_min_val = p * min_val
201+
shrinked_max_val = p * max_val
202+
203+
if optimize_global_scale:
204+
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
205+
206+
candidate_scales, candidate_zero_points = calculate_qparams(
207+
min_vals=shrinked_min_val,
208+
max_vals=shrinked_max_val,
209+
quantization_args=args,
210+
global_scale=global_scale,
211+
)
212+
213+
# Use pre-created token_args instead of patch_attr context manager
214+
# to maintain torch.compile compatibility
215+
q = fake_quantize(
216+
observed,
217+
candidate_scales.unsqueeze(-1),
218+
candidate_zero_points.unsqueeze(-1),
219+
token_args,
220+
global_scale=global_scale,
221+
).to(observed.dtype)
222+
223+
err = torch.sum((q - observed).abs().pow(norm), dim=(0, -1))
224+
return err, shrinked_min_val, shrinked_max_val
225+
226+
227+
# Compiled variant of the inner computation.
228+
# The outer grid search loop stays in eager mode to preserve
229+
# early stopping (data-dependent control flow).
230+
_compute_candidate_error_compiled = torch.compile(
231+
_compute_candidate_error, dynamic=True
232+
)
148233

149234

150235
def _grid_search_mse(
151236
observed: torch.Tensor,
152237
args: QuantizationArgs,
238+
token_args: QuantizationArgs,
153239
maxshrink: float,
154240
patience: float,
155241
grid: float,
156242
norm: float,
157243
global_scale: Optional[torch.Tensor] = None,
158244
optimize_global_scale: bool = False,
245+
enable_compile: bool = False,
159246
) -> MinMaxTuple:
160247
"""
161248
Perform a 1-D grid search to find per-channel min/max ranges that minimize
162249
mean-squared quantization error.
163250
164-
This routine progressively “shrinks” the absolute min/max ranges of the
165-
observed tensor and evaluates the quantization error at each candidate
166-
range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``.
251+
Progressively shrinks the absolute min/max ranges of the observed tensor
252+
and evaluates the quantization error at each candidate. Early stopping
253+
exits when no improvement is found for ``patience`` consecutive steps.
254+
255+
When enable_compile is True, the inner error computation is executed
256+
through a torch.compiled wrapper for accelerated execution while
257+
preserving early stopping in the outer loop.
167258
168259
:param observed: value of shape (num_observations, *qparams_shape, group_size)
169260
:param args: quantization args used for computing qparams and fake quant
170-
:param maxshrink: maximum shrink amount (in “grid steps”). The number of
261+
:param token_args: quantization args with strategy set to TOKEN
262+
:param maxshrink: maximum shrink amount (in "grid steps"). The number of
171263
search steps is int(maxshrink * grid)
172264
:param patience: number of consecutive search steps without improvement before
173265
early stopping
@@ -178,50 +270,35 @@ def _grid_search_mse(
178270
`optimize_global_scale` is True
179271
:param optimize_global_scale: If True, recompute ``global_scale`` from the
180272
candidate min/max during each step of the search
273+
:param enable_compile: If True, use torch.compiled inner computation
181274
"""
182275
min_val = torch.amin(observed, dim=(0, -1))
183276
max_val = torch.amax(observed, dim=(0, -1))
184277
best_error = torch.full_like(min_val, torch.finfo(min_val.dtype).max)
185278
best_min_val = min_val.clone()
186279
best_max_val = max_val.clone()
187280

188-
# Early stopping params
281+
compute_fn = (
282+
_compute_candidate_error_compiled if enable_compile
283+
else _compute_candidate_error
284+
)
189285
no_improve_count = 0
190286

191287
# @ksayers @HGCharles: investigate searching over separate shrinking factors
192288
for i in range(int(maxshrink * grid)):
193289
p = 1 - i / grid
194-
shrinked_min_val = p * min_val
195-
shrinked_max_val = p * max_val
196-
197-
if optimize_global_scale:
198-
global_scale = generate_gparam(shrinked_min_val, shrinked_max_val)
199-
200-
candidate_scales, candidate_zero_points = calculate_qparams(
201-
min_vals=shrinked_min_val,
202-
max_vals=shrinked_max_val,
203-
quantization_args=args,
204-
global_scale=global_scale,
290+
err, shrinked_min_val, shrinked_max_val = compute_fn(
291+
observed,
292+
args,
293+
token_args,
294+
min_val,
295+
max_val,
296+
p,
297+
norm,
298+
global_scale,
299+
optimize_global_scale,
205300
)
206301

207-
# Note that observed.shape = (num_observations, *qparams_shape, group_size).
208-
# For the purposes of fake quantization, this is equivalent to token quant
209-
with patch_attr(args, "strategy", QuantizationStrategy.TOKEN):
210-
q = fake_quantize(
211-
observed,
212-
candidate_scales.unsqueeze(-1),
213-
candidate_zero_points.unsqueeze(-1),
214-
args,
215-
global_scale=global_scale,
216-
).to(observed.dtype)
217-
# Note that due to forward quantization implementation, token quant,
218-
# unlike tensor_group, requires extra dtype cast
219-
220-
q -= observed
221-
q.abs_()
222-
q.pow_(norm)
223-
err = torch.sum(q, dim=(0, -1))
224-
225302
tmp = err < best_error
226303
if torch.any(tmp):
227304
best_error[tmp] = err[tmp]

0 commit comments

Comments
 (0)