1- from typing import Optional
1+ from typing import Optional , Tuple
22
33import torch
44from compressed_tensors .quantization import (
77)
88from compressed_tensors .quantization .lifecycle import fake_quantize
99from compressed_tensors .quantization .utils import calculate_qparams , generate_gparam
10- from compressed_tensors .utils import patch_attr
1110
1211from llmcompressor .observers .base import MinMaxTuple , Observer
1312from llmcompressor .observers .moving_base import MovingAverageObserverBase
@@ -56,11 +55,11 @@ def __init__(self, *args, **kwargs):
5655 self .norm = observer_kwargs .get ("norm" , 2.4 )
5756 self .enable_torch_compile = observer_kwargs .get ("enable_torch_compile" , False )
5857
59- # Pre-create token_args for compiled path to avoid patch_attr
60- if self . enable_torch_compile :
61- self ._token_args = self .args .model_copy (
62- update = {"strategy" : QuantizationStrategy .TOKEN }
63- )
58+ # Pre-create token_args to avoid patch_attr context manager
59+ # which causes torch.compile graph breaks
60+ self ._token_args = self .args .model_copy (
61+ update = {"strategy" : QuantizationStrategy .TOKEN }
62+ )
6463
6564 def _call_grid_search (
6665 self ,
@@ -82,6 +81,7 @@ def _call_grid_search(
8281 return _grid_search_mse (
8382 observed ,
8483 self .args ,
84+ self ._token_args ,
8585 self .maxshrink ,
8686 self .patience ,
8787 self .grid ,
@@ -141,10 +141,11 @@ def __init__(self, *args, **kwargs):
141141 self .norm = observer_kwargs .get ("norm" , 2.4 )
142142 self .enable_torch_compile = observer_kwargs .get ("enable_torch_compile" , False )
143143
144- if self .enable_torch_compile :
145- self ._token_args = self .args .model_copy (
146- update = {"strategy" : QuantizationStrategy .TOKEN }
147- )
144+ # Pre-create token_args to avoid patch_attr context manager
145+ # which causes torch.compile graph breaks
146+ self ._token_args = self .args .model_copy (
147+ update = {"strategy" : QuantizationStrategy .TOKEN }
148+ )
148149
149150 def _call_grid_search (
150151 self ,
@@ -166,6 +167,7 @@ def _call_grid_search(
166167 return _grid_search_mse (
167168 observed ,
168169 self .args ,
170+ self ._token_args ,
169171 self .maxshrink ,
170172 self .patience ,
171173 self .grid ,
@@ -184,9 +186,68 @@ def get_current_global_min_max(self, observed: torch.Tensor) -> MinMaxTuple:
184186 return self ._call_grid_search (observed , None , True )
185187
186188
def _compute_candidate_error(
    observed: torch.Tensor,
    args: QuantizationArgs,
    token_args: QuantizationArgs,
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    p: float,
    norm: float,
    global_scale: Optional[torch.Tensor],
    optimize_global_scale: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate the quantization error obtained by shrinking the observed
    min/max range by a factor ``p``.

    Shared by the default and torch.compile-compatible grid search paths so
    the per-candidate logic is not duplicated. ``token_args`` is a pre-built
    copy of ``args`` with strategy set to TOKEN, passed in instead of using a
    patch_attr context manager (which causes torch.compile graph breaks).

    :param observed: value of shape (num_observations, *qparams_shape, group_size)
    :param args: quantization args used for computing qparams
    :param token_args: quantization args with strategy set to TOKEN, pre-created
        to avoid patch_attr context manager which causes torch.compile graph breaks
    :param min_val: per-channel minimum values
    :param max_val: per-channel maximum values
    :param p: shrink factor (1 - i/grid)
    :param norm: exponent used when computing the error
    :param global_scale: precomputed global scale to use for quantization
    :param optimize_global_scale: If True, recompute global_scale from candidates
    :return: (error, shrinked_min_val, shrinked_max_val)
    """
    scaled_min = min_val * p
    scaled_max = max_val * p

    # Optionally re-derive the global scale from the shrunken candidate range
    if optimize_global_scale:
        global_scale = generate_gparam(scaled_min, scaled_max)

    scales, zero_points = calculate_qparams(
        min_vals=scaled_min,
        max_vals=scaled_max,
        quantization_args=args,
        global_scale=global_scale,
    )

    # Pre-created token_args keeps this call torch.compile-friendly.
    # Note that due to forward quantization implementation, token quant,
    # unlike tensor_group, requires an extra dtype cast.
    quantized = fake_quantize(
        observed,
        scales.unsqueeze(-1),
        zero_points.unsqueeze(-1),
        token_args,
        global_scale=global_scale,
    ).to(observed.dtype)

    # L-norm error, reduced over observations (dim 0) and group (last dim)
    residual = quantized - observed
    err = residual.abs().pow(norm).sum(dim=(0, -1))
    return err, scaled_min, scaled_max
245+
246+
187247def _grid_search_mse (
188248 observed : torch .Tensor ,
189249 args : QuantizationArgs ,
250+ token_args : QuantizationArgs ,
190251 maxshrink : float ,
191252 patience : float ,
192253 grid : float ,
@@ -202,8 +263,12 @@ def _grid_search_mse(
202263 observed tensor and evaluates the quantization error at each candidate
203264 range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``.
204265
266+ Uses early stopping to skip unnecessary search steps when no improvement
267+ is found for ``patience`` consecutive steps.
268+
205269 :param observed: value of shape (num_observations, *qparams_shape, group_size)
206270 :param args: quantization args used for computing qparams and fake quant
271+ :param token_args: quantization args with strategy set to TOKEN
207272 :param maxshrink: maximum shrink amount (in "grid steps"). The number of
208273 search steps is int(maxshrink * grid)
209274 :param patience: number of consecutive search steps without improvement before
@@ -222,43 +287,16 @@ def _grid_search_mse(
222287 best_min_val = min_val .clone ()
223288 best_max_val = max_val .clone ()
224289
225- # Early stopping params
226290 no_improve_count = 0
227291
228292 # @ksayers @HGCharles: investigate searching over separate shrinking factors
229293 for i in range (int (maxshrink * grid )):
230294 p = 1 - i / grid
231- shrinked_min_val = p * min_val
232- shrinked_max_val = p * max_val
233-
234- if optimize_global_scale :
235- global_scale = generate_gparam (shrinked_min_val , shrinked_max_val )
236-
237- candidate_scales , candidate_zero_points = calculate_qparams (
238- min_vals = shrinked_min_val ,
239- max_vals = shrinked_max_val ,
240- quantization_args = args ,
241- global_scale = global_scale ,
295+ err , shrinked_min_val , shrinked_max_val = _compute_candidate_error (
296+ observed , args , token_args , min_val , max_val , p , norm ,
297+ global_scale , optimize_global_scale ,
242298 )
243299
244- # Note that observed.shape = (num_observations, *qparams_shape, group_size).
245- # For the purposes of fake quantization, this is equivalent to token quant
246- with patch_attr (args , "strategy" , QuantizationStrategy .TOKEN ):
247- q = fake_quantize (
248- observed ,
249- candidate_scales .unsqueeze (- 1 ),
250- candidate_zero_points .unsqueeze (- 1 ),
251- args ,
252- global_scale = global_scale ,
253- ).to (observed .dtype )
254- # Note that due to forward quantization implementation, token quant,
255- # unlike tensor_group, requires extra dtype cast
256-
257- q -= observed
258- q .abs_ ()
259- q .pow_ (norm )
260- err = torch .sum (q , dim = (0 , - 1 ))
261-
262300 tmp = err < best_error
263301 if torch .any (tmp ):
264302 best_error [tmp ] = err [tmp ]
@@ -287,16 +325,13 @@ def _grid_search_mse_compiled(
287325 torch.compile-compatible version of _grid_search_mse.
288326
289327 Differences from the default path:
290- - Uses pre-created token_args instead of patch_attr context manager
291- (patch_attr causes graph breaks)
292328 - Uses torch.where instead of data-dependent control flow
293- (early stopping causes graph breaks)
329+ (early stopping and torch.any cause graph breaks)
294330 - No early stopping: runs all search steps for deterministic compilation
295331
296332 :param observed: value of shape (num_observations, *qparams_shape, group_size)
297333 :param args: quantization args used for computing qparams
298- :param token_args: quantization args with strategy set to TOKEN, pre-created
299- to avoid patch_attr context manager which causes torch.compile graph breaks
334+ :param token_args: quantization args with strategy set to TOKEN
300335 :param maxshrink: maximum shrink amount. The number of search steps is
301336 int(maxshrink * grid)
302337 :param grid: resolution of the shrink search. Larger values give finer granularity
@@ -316,31 +351,11 @@ def _grid_search_mse_compiled(
316351 num_steps = int (maxshrink * grid )
317352 for i in range (num_steps ):
318353 p = 1 - i / grid
319- shrinked_min_val = p * min_val
320- shrinked_max_val = p * max_val
321-
322- if optimize_global_scale :
323- global_scale = generate_gparam (shrinked_min_val , shrinked_max_val )
324-
325- candidate_scales , candidate_zero_points = calculate_qparams (
326- min_vals = shrinked_min_val ,
327- max_vals = shrinked_max_val ,
328- quantization_args = args ,
329- global_scale = global_scale ,
354+ err , shrinked_min_val , shrinked_max_val = _compute_candidate_error (
355+ observed , args , token_args , min_val , max_val , p , norm ,
356+ global_scale , optimize_global_scale ,
330357 )
331358
332- # Use pre-created token_args instead of patch_attr context manager
333- # to maintain torch.compile compatibility
334- q = fake_quantize (
335- observed ,
336- candidate_scales .unsqueeze (- 1 ),
337- candidate_zero_points .unsqueeze (- 1 ),
338- token_args ,
339- global_scale = global_scale ,
340- ).to (observed .dtype )
341-
342- err = torch .sum ((q - observed ).abs ().pow (norm ), dim = (0 , - 1 ))
343-
344359 # Use torch.where instead of boolean indexing + torch.any for
345360 # torch.compile compatibility (avoids data-dependent control flow)
346361 improved = err < best_error
0 commit comments