1- from typing import Optional
1+ from typing import Optional , Tuple
22
33import torch
4+ import torch ._dynamo .config
45from compressed_tensors .quantization import (
56 QuantizationArgs ,
67 QuantizationStrategy ,
78)
89from compressed_tensors .quantization .lifecycle import fake_quantize
910from compressed_tensors .quantization .utils import calculate_qparams , generate_gparam
10- from compressed_tensors .utils import patch_attr
1111
1212from llmcompressor .observers .base import MinMaxTuple , Observer
13+ from llmcompressor .observers .compile_config import get_observer_compile
1314from llmcompressor .observers .moving_base import MovingAverageObserverBase
1415
1516__all__ = ["MovingAverageMSEObserver" ]
1617
# Dynamo must be allowed to capture scalar outputs because
# compressed_tensors' calculate_qparams converts tensors to Python
# scalars (float(bit_range)); mirrors the GPTQ compile path (a4f9ba2e).
torch._dynamo.config.capture_scalar_outputs = True
22+
1723
1824@Observer .register ("memoryless_mse" )
1925class MemorylessMSEObserver (Observer ):
@@ -32,7 +38,7 @@ class MemorylessMSEObserver(Observer):
3238 :param module: optional module with attached quantization parameters. This argument
3339 is required to utilize existing qparams such as global_scale or g_idx
3440 :param **observer_kwargs: keyword arguments for observer initialization\n
35- maxshrink: maximum shrink amount (in “ grid steps” ). The number of
41+ maxshrink: maximum shrink amount (in " grid steps" ). The number of
3642 search steps is int(maxshrink * grid)\n
3743 patience: number of consecutive search steps without improvement before
3844 early stopping\n
@@ -53,32 +59,39 @@ def __init__(self, *args, **kwargs):
5359 self .grid = observer_kwargs .get ("grid" , 100.0 )
5460 self .norm = observer_kwargs .get ("norm" , 2.4 )
5561
56- def get_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
57- # min[min_vals, max_vals](mse_quant_error)
58- global_scale = self ._get_module_param ("global_scale" )
62+ # Pre-create token_args to avoid patch_attr context manager
63+ # which causes torch.compile graph breaks
64+ self ._token_args = self .args .model_copy (
65+ update = {"strategy" : QuantizationStrategy .TOKEN }
66+ )
67+
68+ def _call_grid_search (
69+ self ,
70+ observed : torch .Tensor ,
71+ global_scale : Optional [torch .Tensor ],
72+ optimize_global_scale : bool ,
73+ ) -> MinMaxTuple :
5974 return _grid_search_mse (
6075 observed ,
6176 self .args ,
77+ self ._token_args ,
6278 self .maxshrink ,
6379 self .patience ,
6480 self .grid ,
6581 self .norm ,
6682 global_scale = global_scale ,
67- optimize_global_scale = False ,
83+ optimize_global_scale = optimize_global_scale ,
84+ enable_compile = get_observer_compile (),
6885 )
6986
87+ def get_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
88+ # min[min_vals, max_vals](mse_quant_error)
89+ global_scale = self ._get_module_param ("global_scale" )
90+ return self ._call_grid_search (observed , global_scale , False )
91+
7092 def get_global_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
7193 # min[min_vals, max_vals, global_scale](mse_quant_error)
72- return _grid_search_mse (
73- observed ,
74- self .args ,
75- self .maxshrink ,
76- self .patience ,
77- self .grid ,
78- self .norm ,
79- global_scale = None ,
80- optimize_global_scale = True ,
81- )
94+ return self ._call_grid_search (observed , None , True )
8295
8396
8497@Observer .register ("mse" )
@@ -98,7 +111,7 @@ class MovingAverageMSEObserver(MovingAverageObserverBase):
98111 :param module: optional module with attached quantization parameters. This argument
99112 is required to utilize existing qparams such as global_scale or g_idx
100113 :param **observer_kwargs: keyword arguments for observer initialization\n
101- maxshrink: maximum shrink amount (in “ grid steps” ). The number of
114+ maxshrink: maximum shrink amount (in " grid steps" ). The number of
102115 search steps is int(maxshrink * grid)\n
103116 patience: number of consecutive search steps without improvement before
104117 early stopping\n
@@ -119,55 +132,134 @@ def __init__(self, *args, **kwargs):
119132 self .grid = observer_kwargs .get ("grid" , 100.0 )
120133 self .norm = observer_kwargs .get ("norm" , 2.4 )
121134
122- def get_current_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
123- # min[min_vals, max_vals](mse_quant_error)
124- global_scale = self ._get_module_param ("global_scale" )
135+ # Pre-create token_args to avoid patch_attr context manager
136+ # which causes torch.compile graph breaks
137+ self ._token_args = self .args .model_copy (
138+ update = {"strategy" : QuantizationStrategy .TOKEN }
139+ )
140+
141+ def _call_grid_search (
142+ self ,
143+ observed : torch .Tensor ,
144+ global_scale : Optional [torch .Tensor ],
145+ optimize_global_scale : bool ,
146+ ) -> MinMaxTuple :
125147 return _grid_search_mse (
126148 observed ,
127149 self .args ,
150+ self ._token_args ,
128151 self .maxshrink ,
129152 self .patience ,
130153 self .grid ,
131154 self .norm ,
132155 global_scale = global_scale ,
133- optimize_global_scale = False ,
156+ optimize_global_scale = optimize_global_scale ,
157+ enable_compile = get_observer_compile (),
134158 )
135159
160+ def get_current_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
161+ # min[min_vals, max_vals](mse_quant_error)
162+ global_scale = self ._get_module_param ("global_scale" )
163+ return self ._call_grid_search (observed , global_scale , False )
164+
136165 def get_current_global_min_max (self , observed : torch .Tensor ) -> MinMaxTuple :
137166 # min[min_vals, max_vals, global_scale](mse_quant_error)
138- return _grid_search_mse (
139- observed ,
140- self .args ,
141- self .maxshrink ,
142- self .patience ,
143- self .grid ,
144- self .norm ,
145- global_scale = None ,
146- optimize_global_scale = True ,
147- )
167+ return self ._call_grid_search (observed , None , True )
168+
169+
def _compute_candidate_error(
    observed: torch.Tensor,
    args: QuantizationArgs,
    token_args: QuantizationArgs,
    min_val: torch.Tensor,
    max_val: torch.Tensor,
    p: float,
    norm: float,
    global_scale: Optional[torch.Tensor],
    optimize_global_scale: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Evaluate the quantization error for a single shrink factor ``p``.

    Shared inner step of the MSE grid search; when compilation is enabled
    via oneshot it is invoked through its torch.compiled wrapper.

    :param observed: value of shape (num_observations, *qparams_shape, group_size)
    :param args: quantization args used for computing qparams
    :param token_args: quantization args with strategy set to TOKEN, pre-created
        to avoid the patch_attr context manager, which breaks torch.compile graphs
    :param min_val: per-channel minimum values
    :param max_val: per-channel maximum values
    :param p: shrink factor (1 - i/grid)
    :param norm: exponent used when computing the error
    :param global_scale: precomputed global scale to use for quantization
    :param optimize_global_scale: if True, recompute global_scale from candidates
    :return: (error, shrinked_min_val, shrinked_max_val)
    """
    cand_min = min_val * p
    cand_max = max_val * p

    # Optionally derive the global scale from this candidate range
    if optimize_global_scale:
        global_scale = generate_gparam(cand_min, cand_max)

    cand_scales, cand_zero_points = calculate_qparams(
        min_vals=cand_min,
        max_vals=cand_max,
        quantization_args=args,
        global_scale=global_scale,
    )

    # token_args is passed explicitly (rather than patching args in place)
    # so the body stays free of graph-breaking context managers
    dequantized = fake_quantize(
        observed,
        cand_scales.unsqueeze(-1),
        cand_zero_points.unsqueeze(-1),
        token_args,
        global_scale=global_scale,
    ).to(observed.dtype)

    error = (dequantized - observed).abs().pow(norm).sum(dim=(0, -1))
    return error, cand_min, cand_max
225+
226+
# Compiled wrapper around the per-candidate computation. Only this inner
# step is compiled; the grid-search loop itself stays eager so that
# early stopping (data-dependent control flow) keeps working.
_compute_candidate_error_compiled = torch.compile(
    _compute_candidate_error, dynamic=True
)
148233
149234
150235def _grid_search_mse (
151236 observed : torch .Tensor ,
152237 args : QuantizationArgs ,
238+ token_args : QuantizationArgs ,
153239 maxshrink : float ,
154240 patience : float ,
155241 grid : float ,
156242 norm : float ,
157243 global_scale : Optional [torch .Tensor ] = None ,
158244 optimize_global_scale : bool = False ,
245+ enable_compile : bool = False ,
159246) -> MinMaxTuple :
160247 """
161248 Perform a 1-D grid search to find per-channel min/max ranges that minimize
162249 mean-squared quantization error.
163250
164- This routine progressively “shrinks” the absolute min/max ranges of the
165- observed tensor and evaluates the quantization error at each candidate
166- range. For each shrink factor ``p = 1 - i/grid`` up to ``maxshrink``.
251+ Progressively shrinks the absolute min/max ranges of the observed tensor
252+ and evaluates the quantization error at each candidate. Early stopping
253+ exits when no improvement is found for ``patience`` consecutive steps.
254+
255+ When enable_compile is True, the inner error computation is executed
256+ through a torch.compiled wrapper for accelerated execution while
257+ preserving early stopping in the outer loop.
167258
168259 :param observed: value of shape (num_observations, *qparams_shape, group_size)
169260 :param args: quantization args used for computing qparams and fake quant
170- :param maxshrink: maximum shrink amount (in “grid steps”). The number of
261+ :param token_args: quantization args with strategy set to TOKEN
262+ :param maxshrink: maximum shrink amount (in "grid steps"). The number of
171263 search steps is int(maxshrink * grid)
172264 :param patience: number of consecutive search steps without improvement before
173265 early stopping
@@ -178,50 +270,35 @@ def _grid_search_mse(
178270 `optimize_global_scale` is True
179271 :param optimize_global_scale: If True, recompute ``global_scale`` from the
180272 candidate min/max during each step of the search
273+ :param enable_compile: If True, use torch.compiled inner computation
181274 """
182275 min_val = torch .amin (observed , dim = (0 , - 1 ))
183276 max_val = torch .amax (observed , dim = (0 , - 1 ))
184277 best_error = torch .full_like (min_val , torch .finfo (min_val .dtype ).max )
185278 best_min_val = min_val .clone ()
186279 best_max_val = max_val .clone ()
187280
188- # Early stopping params
281+ compute_fn = (
282+ _compute_candidate_error_compiled if enable_compile
283+ else _compute_candidate_error
284+ )
189285 no_improve_count = 0
190286
191287 # @ksayers @HGCharles: investigate searching over separate shrinking factors
192288 for i in range (int (maxshrink * grid )):
193289 p = 1 - i / grid
194- shrinked_min_val = p * min_val
195- shrinked_max_val = p * max_val
196-
197- if optimize_global_scale :
198- global_scale = generate_gparam (shrinked_min_val , shrinked_max_val )
199-
200- candidate_scales , candidate_zero_points = calculate_qparams (
201- min_vals = shrinked_min_val ,
202- max_vals = shrinked_max_val ,
203- quantization_args = args ,
204- global_scale = global_scale ,
290+ err , shrinked_min_val , shrinked_max_val = compute_fn (
291+ observed ,
292+ args ,
293+ token_args ,
294+ min_val ,
295+ max_val ,
296+ p ,
297+ norm ,
298+ global_scale ,
299+ optimize_global_scale ,
205300 )
206301
207- # Note that observed.shape = (num_observations, *qparams_shape, group_size).
208- # For the purposes of fake quantization, this is equivalent to token quant
209- with patch_attr (args , "strategy" , QuantizationStrategy .TOKEN ):
210- q = fake_quantize (
211- observed ,
212- candidate_scales .unsqueeze (- 1 ),
213- candidate_zero_points .unsqueeze (- 1 ),
214- args ,
215- global_scale = global_scale ,
216- ).to (observed .dtype )
217- # Note that due to forward quantization implementation, token quant,
218- # unlike tensor_group, requires extra dtype cast
219-
220- q -= observed
221- q .abs_ ()
222- q .pow_ (norm )
223- err = torch .sum (q , dim = (0 , - 1 ))
224-
225302 tmp = err < best_error
226303 if torch .any (tmp ):
227304 best_error [tmp ] = err [tmp ]
0 commit comments