"""
Test that the MSE observer performs at least as well as the MinMax observer on
various tensor distributions, including normal distributions (similar to real
weights) and actual model weights.

These tests check that the quantization error (MSE) from using the MSE
observer is less than or equal to the error from using the MinMax observer.
"""
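
# Background (illustrative only; not executed by the tests): for symmetric
# int8, a MinMax observer derives the scale from the observed absolute
# maximum, roughly
#
#     scale = x.abs().max() / 127
#
# while an MSE observer searches over shrunken candidate ranges, e.g.
#
#     scale_c = c * x.abs().max() / 127   for c in some grid in (0, 1]
#
# and keeps the candidate that minimizes ||fake_quantize(x) - x||^2. The
# exact search used by "memoryless_mse" is an implementation detail of
# llmcompressor; the formulas above are only a sketch of the idea.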

import pytest
import torch
from compressed_tensors.quantization import fake_quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

from llmcompressor.observers import Observer


def _create_base_quantization_args(num_bits, strategy, symmetric, group_size):
    """Helper to create base QuantizationArgs without the observer field."""
    return QuantizationArgs(
        num_bits=num_bits,
        strategy=strategy,
        symmetric=symmetric,
        group_size=group_size,
    )


def _run_observer_test(
    tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None
):
    """
    Helper function to run an observer and compute the quantization error.

    Returns: (scale, zero_point, quantized_tensor, mse, global_scale)
    """
    weights = _create_base_quantization_args(num_bits, strategy, symmetric, group_size)
    weights.observer = observer_name

    observer = Observer.load_from_registry(
        observer_name, base_name="weight", args=weights, module=module
    )

    global_scale = None
    if strategy == "tensor_group" and module is not None:
        global_scale = observer.get_global_scale(tensor)
        module.weight_global_scale = global_scale

    scale, zero_point = observer(tensor)

    # Sanity check: scales should be non-negative
    assert (scale >= 0).all(), "Scale values should be non-negative"

    weights_clean = _create_base_quantization_args(
        num_bits, strategy, symmetric, group_size
    )
    # global_scale is already None unless strategy == "tensor_group", so it can
    # be passed through unconditionally
    quantized = fake_quantize(
        tensor, scale, zero_point, weights_clean, global_scale=global_scale
    )
    mse = torch.nn.functional.mse_loss(quantized, tensor)

    return scale, zero_point, quantized, mse, global_scale
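
# Illustrative usage of the helper above (shapes and arguments are arbitrary
# examples, not values the tests depend on):
#
#     _, _, _, err, _ = _run_observer_test(
#         torch.randn(16, 32), "memoryless_minmax", "channel", True, 8, None
#     )
#
# where `err` is the reconstruction MSE after fake-quantizing with the
# observed scale/zero-point.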


def _assert_mse_comparison(
    mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False
):
    """
    Assert MSE-observer performance with appropriate slack.

    For tensor + symmetric: strict assertion (the MSE observer should be at
    least as good). For other configurations: allow slack (10% for synthetic
    data, 20% for real weights). An epsilon guards cases where minmax_mse is
    near zero.
    """
    epsilon = 1e-8
    slack = 1.20 if is_real_weights else 1.10

    if strategy == "tensor" and symmetric:
        # Case where the MSE observer SHOULD be at least as good
        assert mse_mse <= minmax_mse + epsilon, (
            f"MSE observer performed worse than MinMax observer!\n"
            f"Strategy: {strategy}, Symmetric: {symmetric}\n"
            f"MinMax MSE: {minmax_mse.item():.6e}\n"
            f"MSE Observer MSE: {mse_mse.item():.6e}\n"
            f"Difference: {(mse_mse - minmax_mse).item():.6e}"
        )
    else:
        # Not guaranteed, but ensure it is not catastrophically worse
        assert mse_mse <= minmax_mse * slack + epsilon, (
            f"MSE observer performed significantly worse than MinMax observer!\n"
            f"Strategy: {strategy}, Symmetric: {symmetric}\n"
            f"MinMax MSE: {minmax_mse.item():.6e}\n"
            f"MSE Observer MSE: {mse_mse.item():.6e}\n"
            f"Difference: {(mse_mse - minmax_mse).item():.6e}\n"
            f"Ratio: {(mse_mse / (minmax_mse + epsilon)).item():.4f}x"
        )
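
# Worked example of the slack arithmetic above (numbers are illustrative):
# with minmax_mse = 2.0e-6, synthetic slack 1.10, and epsilon = 1e-8, the
# non-strict branch accepts any mse_mse <= 2.0e-6 * 1.10 + 1e-8 = 2.21e-6.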


@pytest.mark.parametrize(
    "strategy,symmetric,num_bits",
    [
        ("tensor", True, 8),
        ("tensor", False, 8),
        ("channel", True, 8),
        ("channel", False, 8),
        ("tensor_group", True, 4),
        ("tensor_group", False, 4),
        ("channel", True, 4),
        ("channel", False, 4),
    ],
)
@pytest.mark.parametrize(
    "std",
    [0.05, 0.2, 1.0],
    ids=["narrow", "medium", "wide"],
)
def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std):
    """
    Test that the MSE observer produces quantization error <= the MinMax
    observer on random tensors with a normal distribution (similar to real
    model weights).

    Real model weights typically follow a normal distribution with:
    - Mean near 0
    - Standard deviation around 0.02-0.1 for initialized weights
    - Range roughly [-0.5, 0.5] for most layers

    Testing with different std values exposes cases where MinMax performs
    poorly on wide or heavy-tailed distributions, where the MSE observer
    should have an advantage.
    """
    # Generate a random tensor with a normal distribution similar to real
    # weights; different std values cover various distribution widths
    torch.manual_seed(42)
    tensor = torch.randn(128, 256) * std

    group_size = 32 if strategy == "tensor_group" else None

    # Create separate modules for tensor_group to avoid shared mutable state.
    # Linear(256, 128) has weight shape (128, 256), which matches `tensor`, so
    # it is assigned directly (not transposed).
    module_minmax = None
    module_mse = None
    if strategy == "tensor_group":
        module_minmax = torch.nn.Linear(256, 128)
        module_minmax.weight.data = tensor
        module_mse = torch.nn.Linear(256, 128)
        module_mse.weight.data = tensor
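        # _run_observer_test stores the computed global scale on the module
        # (module.weight_global_scale), so sharing one module would let one
        # observer's global scale leak into the other run.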

    # Test with MinMax observer
    _, _, _, minmax_mse, _ = _run_observer_test(
        tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size,
        module_minmax,
    )

    # Test with MSE observer
    _, _, _, mse_mse, _ = _run_observer_test(
        tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size,
        module_mse,
    )

    # Assert with the slack appropriate for synthetic data
    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False)


@pytest.mark.parametrize(
    "tensor_shape",
    [
        (64, 128),
        (128, 256),
        (256, 512),
        (32, 64, 128),  # 3D tensor
    ],
)
def test_mse_vs_minmax_various_shapes(tensor_shape):
    """
    Test MSE vs MinMax on tensors of various shapes with a normal
    distribution, using realistic weight-distribution parameters.
    """
    torch.manual_seed(42)
    # Realistic weight distribution: mean=0, std=0.05
    tensor = torch.randn(*tensor_shape) * 0.05

    # MinMax
    _, _, _, minmax_mse, _ = _run_observer_test(
        tensor, "memoryless_minmax", "channel", True, 8, None, None
    )

    # MSE
    _, _, _, mse_mse, _ = _run_observer_test(
        tensor, "memoryless_mse", "channel", True, 8, None, None
    )

    # Channel quantization: MSE not guaranteed better, allow 10% slack
    _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)


def test_mse_vs_minmax_extreme_values():
    """Test MSE vs MinMax on tensors with extreme values."""
    torch.manual_seed(42)

    # Very small values
    tensor_small = torch.randn(64, 128) * 0.01
    # Very large values
    tensor_large = torch.randn(64, 128) * 100.0
    # Skewed distribution
    tensor_skewed = torch.cat([
        torch.randn(64, 100) * 0.1,
        torch.randn(64, 28) * 10.0,
    ], dim=1)
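    # The skewed tensor concatenates along dim=1, so every row mixes values
    # with std 0.1 and std 10.0. With per-channel (per-row) quantization, a
    # few large entries dominate the MinMax range, which is the regime an MSE
    # search over clipping ranges is meant to handle.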

    for tensor, name in [
        (tensor_small, "small"),
        (tensor_large, "large"),
        (tensor_skewed, "skewed"),
    ]:
        # MinMax
        _, _, _, minmax_mse, _ = _run_observer_test(
            tensor, "memoryless_minmax", "channel", True, 8, None, None
        )

        # MSE
        _, _, _, mse_mse, _ = _run_observer_test(
            tensor, "memoryless_mse", "channel", True, 8, None, None
        )

        # Channel quantization: MSE not guaranteed better, allow 10% slack
        # (`name` identifies the case when debugging a failure)
        _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False)


@pytest.mark.slow
@pytest.mark.parametrize(
    "strategy,symmetric,num_bits",
    [
        ("channel", True, 8),
        ("channel", False, 8),
        ("tensor_group", True, 4),
        ("tensor_group", False, 4),
    ],
)
def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits):
    """
    Test that the MSE observer produces quantization error <= the MinMax
    observer on actual weights from a real neural network.

    This test loads weights from a small model to verify observer behavior on
    real weight distributions, which may differ from synthetic data.
    """
    try:
        from transformers import AutoModelForCausalLM
    except ImportError:
        pytest.skip("transformers not available")

    # Use a small, publicly available model for testing
    model_id = "nm-testing/tinysmokellama-3.2"

    try:
        # Load the model; no_grad avoids unnecessary autograd bookkeeping
        with torch.no_grad():
            model = AutoModelForCausalLM.from_pretrained(
                model_id, torch_dtype=torch.float32
            )

        # Take a representative weight tensor from the first Linear layer
        weight_tensor = None
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                weight_tensor = module.weight.data.clone()
                break

        if weight_tensor is None:
            pytest.skip("No Linear layer found in model")

        # Reshape to 2D if needed
        if weight_tensor.dim() > 2:
            weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1])
        elif weight_tensor.dim() == 1:
            weight_tensor = weight_tensor.unsqueeze(0)

        # Limit size for faster testing
        if weight_tensor.shape[0] > 512:
            weight_tensor = weight_tensor[:512, :]
        if weight_tensor.shape[1] > 512:
            weight_tensor = weight_tensor[:, :512]

    except Exception as e:
        pytest.skip(f"Could not load model {model_id}: {e}")

    group_size = 32 if strategy == "tensor_group" else None

    # Create separate modules for tensor_group to avoid shared mutable state.
    # Linear(in_features, out_features) has weight shape (out, in), which
    # matches weight_tensor, so it is assigned directly (not transposed).
    module_minmax = None
    module_mse = None
    if strategy == "tensor_group":
        module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
        module_minmax.weight.data = weight_tensor
        module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0])
        module_mse.weight.data = weight_tensor

    # Test with MinMax observer
    _, _, _, minmax_mse, _ = _run_observer_test(
        weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits,
        group_size, module_minmax,
    )

    # Test with MSE observer
    _, _, _, mse_mse, _ = _run_observer_test(
        weight_tensor, "memoryless_mse", strategy, symmetric, num_bits,
        group_size, module_mse,
    )

    # For channel and tensor_group strategies, MSE is not guaranteed to be
    # better; allow 20% slack for real model weights (more structure and
    # extreme channels)
    _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True)