From 5352b1cec7c4b4469df9a62be1efe0aaa45cfcf9 Mon Sep 17 00:00:00 2001 From: Avishek Goswami Date: Thu, 11 Dec 2025 10:14:17 +0530 Subject: [PATCH 1/4] Add MSE vs MinMax observer comparison tests - Add comprehensive test suite comparing MSE and MinMax observers - Test on random tensors with various distributions - Test on real model weights from transformers - Add 'slow' pytest marker to pyproject.toml for long-running tests Signed-off-by: Avishek Goswami --- pyproject.toml | 1 + .../observers/test_mse_vs_minmax.py | 309 ++++++++++++++++++ 2 files changed, 310 insertions(+) create mode 100644 tests/llmcompressor/observers/test_mse_vs_minmax.py diff --git a/pyproject.toml b/pyproject.toml index 9398eddf8e..b79779eeba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,5 +23,6 @@ markers = [ "unit: tests to ensure code correctness and regression test functionality", "example: tests for content in the 'examples' folder", "multi_gpu: tests that require multiple GPUs", + "slow: tests that take a long time to run (e.g., downloading models)", ] tmp_path_retention_policy = "failed" diff --git a/tests/llmcompressor/observers/test_mse_vs_minmax.py b/tests/llmcompressor/observers/test_mse_vs_minmax.py new file mode 100644 index 0000000000..110e15754b --- /dev/null +++ b/tests/llmcompressor/observers/test_mse_vs_minmax.py @@ -0,0 +1,309 @@ +""" +Test to verify that MSE observer performs equal to or better than MinMax observer +on various tensor distributions, including normal distributions (similar to real weights) +and actual model weights. + +This test checks that the quantization error (MSE) from using MSE observer +is less than or equal to the error from using MinMax observer. +""" + +import pytest +import torch +from compressed_tensors.quantization import fake_quantize +from compressed_tensors.quantization.quant_args import QuantizationArgs + +from llmcompressor.observers import Observer + + +def _create_base_quantization_args(num_bits, strategy, symmetric, group_size): + """Helper to create base QuantizationArgs without observer field.""" + return QuantizationArgs( + num_bits=num_bits, + strategy=strategy, + symmetric=symmetric, + group_size=group_size, + ) + + +def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None): + """ + Helper function to run observer and compute quantization error. + + Returns: (scale, zero_point, quantized_tensor, mse, global_scale) + """ + weights = _create_base_quantization_args(num_bits, strategy, symmetric, group_size) + weights.observer = observer_name + + observer = Observer.load_from_registry( + observer_name, base_name="weight", args=weights, module=module + ) + + global_scale = None + if strategy == "tensor_group" and module is not None: + global_scale = observer.get_global_scale(tensor) + module.weight_global_scale = global_scale + + scale, zero_point = observer(tensor) + + # Sanity check: scales should be non-negative + assert (scale >= 0).all(), "Scale values should be non-negative" + + weights_clean = _create_base_quantization_args(num_bits, strategy, symmetric, group_size) + quantized = fake_quantize( + tensor, scale, zero_point, weights_clean, + global_scale=global_scale if strategy == "tensor_group" else None + ) + mse = torch.nn.functional.mse_loss(quantized, tensor) + + return scale, zero_point, quantized, mse, global_scale + + +def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False): + """ + Assert MSE observer performance with appropriate slack. + + For tensor+symmetric: strict assertion (MSE should be better) + For others: allow slack (10% for synthetic, 20% for real weights) + Also add epsilon to handle cases where minmax_mse is near 0. + """ + epsilon = 1e-8 + slack = 1.20 if is_real_weights else 1.10 + + if strategy == "tensor" and symmetric: + # Cases where MSE SHOULD be better + assert mse_mse <= minmax_mse + epsilon, ( + f"MSE observer performed worse than MinMax observer!\n" + f"Strategy: {strategy}, Symmetric: {symmetric}\n" + f"MinMax MSE: {minmax_mse.item():.6e}\n" + f"MSE Observer MSE: {mse_mse.item():.6e}\n" + f"Difference: {(mse_mse - minmax_mse).item():.6e}" + ) + else: + # Not guaranteed, but ensure not catastrophically worse + assert mse_mse <= minmax_mse * slack + epsilon, ( + f"MSE observer performed significantly worse than MinMax observer!\n" + f"Strategy: {strategy}, Symmetric: {symmetric}\n" + f"MinMax MSE: {minmax_mse.item():.6e}\n" + f"MSE Observer MSE: {mse_mse.item():.6e}\n" + f"Difference: {(mse_mse - minmax_mse).item():.6e}\n" + f"Ratio: {(mse_mse / (minmax_mse + epsilon)).item():.4f}x" + ) + + +@pytest.mark.parametrize( + "strategy,symmetric,num_bits", + [ + ("tensor", True, 8), + ("tensor", False, 8), + ("channel", True, 8), + ("channel", False, 8), + ("tensor_group", True, 4), + ("tensor_group", False, 4), + ("channel", True, 4), + ("channel", False, 4), + ], +) +@pytest.mark.parametrize( + "std", + [0.05, 0.2, 1.0], + ids=["narrow", "medium", "wide"], +) +def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): + """ + Test that MSE observer produces quantization error <= MinMax observer + on random tensors with normal distribution (similar to real model weights). + + Real model weights typically follow a normal distribution with: + - Mean near 0 + - Standard deviation around 0.02-0.1 for initialized weights + - Range roughly [-0.5, 0.5] for most layers + + Testing with different std values exposes cases where MinMax performs poorly + on wide or heavy-tailed distributions, where MSE should shine. + """ + # Generate random tensor with normal distribution similar to real weights + torch.manual_seed(42) + # Use different std values to test various distribution widths + tensor = torch.randn(128, 256) * std # Normal distribution with specified std + + group_size = 32 if strategy == "tensor_group" else None + + # Create separate modules for tensor_group to avoid shared mutable state + module_minmax = None + module_mse = None + if strategy == "tensor_group": + module_minmax = torch.nn.Linear(256, 128) + module_minmax.weight.data = tensor.T + module_mse = torch.nn.Linear(256, 128) + module_mse.weight.data = tensor.T + + # Test with MinMax observer + _, _, _, minmax_mse, _ = _run_observer_test( + tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax + ) + + # Test with MSE observer + _, _, _, mse_mse, _ = _run_observer_test( + tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse + ) + + # Assert with appropriate slack for synthetic data + _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False) + + +@pytest.mark.parametrize( + "tensor_shape", + [ + (64, 128), + (128, 256), + (256, 512), + (32, 64, 128), # 3D tensor + ], +) +def test_mse_vs_minmax_various_shapes(tensor_shape): + """ + Test MSE vs MinMax on tensors of various shapes with normal distribution. + Uses realistic weight distribution parameters. + """ + torch.manual_seed(42) + # Use realistic weight distribution: mean=0, std=0.05 + tensor = torch.randn(*tensor_shape) * 0.05 + + # MinMax + _, _, _, minmax_mse, _ = _run_observer_test( + tensor, "memoryless_minmax", "channel", True, 8, None, None + ) + + # MSE + _, _, _, mse_mse, _ = _run_observer_test( + tensor, "memoryless_mse", "channel", True, 8, None, None + ) + + # Channel quantization: MSE not guaranteed better, allow 10% slack + _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) + + +def test_mse_vs_minmax_extreme_values(): + """Test MSE vs MinMax on tensors with extreme values.""" + torch.manual_seed(42) + + # Test with very small values + tensor_small = torch.randn(64, 128) * 0.01 + # Test with very large values + tensor_large = torch.randn(64, 128) * 100.0 + # Test with skewed distribution + tensor_skewed = torch.cat([ + torch.randn(64, 100) * 0.1, + torch.randn(64, 28) * 10.0 + ], dim=1) + + for tensor, name in [ + (tensor_small, "small"), + (tensor_large, "large"), + (tensor_skewed, "skewed"), + ]: + weights = QuantizationArgs( + num_bits=8, + strategy="channel", + symmetric=True, + observer="memoryless_minmax", + ) + + # MinMax + _, _, _, minmax_mse, _ = _run_observer_test( + tensor, "memoryless_minmax", "channel", True, 8, None, None + ) + + # MSE + _, _, _, mse_mse, _ = _run_observer_test( + tensor, "memoryless_mse", "channel", True, 8, None, None + ) + + # Channel quantization: MSE not guaranteed better, allow 10% slack + _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "strategy,symmetric,num_bits", + [ + ("channel", True, 8), + ("channel", False, 8), + ("tensor_group", True, 4), + ("tensor_group", False, 4), + ], +) +def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): + """ + Test that MSE observer produces quantization error <= MinMax observer + on actual model weights from a real neural network. + + This test loads weights from a small model to verify observer behavior + on real weight distributions, which may differ from synthetic data. + """ + try: + from transformers import AutoModelForCausalLM + except ImportError: + pytest.skip("transformers not available") + + # Use a small, publicly available model for testing + model_id = "nm-testing/tinysmokellama-3.2" + + try: + # Load model and extract a weight tensor + # Use no_grad context to avoid unnecessary gradient computation + with torch.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32 + ) + + # Get a representative weight tensor (e.g., from first Linear layer) + weight_tensor = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and weight_tensor is None: + weight_tensor = module.weight.data.clone() + break + + if weight_tensor is None: + pytest.skip("No Linear layer found in model") + + # Flatten or reshape to 2D if needed for testing + if weight_tensor.dim() > 2: + weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1]) + elif weight_tensor.dim() == 1: + weight_tensor = weight_tensor.unsqueeze(0) + + # Limit size for faster testing + if weight_tensor.shape[0] > 512: + weight_tensor = weight_tensor[:512, :] + if weight_tensor.shape[1] > 512: + weight_tensor = weight_tensor[:, :512] + + except Exception as e: + pytest.skip(f"Could not load model {model_id}: {e}") + + group_size = 32 if strategy == "tensor_group" else None + + # Create separate modules for tensor_group to avoid shared mutable state + module_minmax = None + module_mse = None + if strategy == "tensor_group": + module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) + module_minmax.weight.data = weight_tensor.T + module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) + module_mse.weight.data = weight_tensor.T + + # Test with MinMax observer + _, _, _, minmax_mse, _ = _run_observer_test( + weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax + ) + + # Test with MSE observer + _, _, _, mse_mse, _ = _run_observer_test( + weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse + ) + + # For channel and tensor_group strategies, MSE is not guaranteed to be better + # Allow 20% slack for real model weights (more structure & extreme channels) + _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True) + From 9e16c3c295d665a1cf17dfa0c9ccfb8d049fb0c4 Mon Sep 17 00:00:00 2001 From: Avishek Goswami Date: Thu, 11 Dec 2025 10:20:21 +0530 Subject: [PATCH 2/4] Fix shape mismatches and remove dead code in MSE vs MinMax tests - Fix tensor shape mismatch: use tensor directly instead of tensor.T - Fix weight_tensor shape mismatch: use weight_tensor directly instead of weight_tensor.T - Remove unused weights variable in test_mse_vs_minmax_extreme_values Signed-off-by: Avishek Goswami --- .../llmcompressor/observers/test_mse_vs_minmax.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/llmcompressor/observers/test_mse_vs_minmax.py b/tests/llmcompressor/observers/test_mse_vs_minmax.py index 110e15754b..c13c815e58 100644 --- a/tests/llmcompressor/observers/test_mse_vs_minmax.py +++ b/tests/llmcompressor/observers/test_mse_vs_minmax.py @@ -133,9 +133,9 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): module_mse = None if strategy == "tensor_group": module_minmax = torch.nn.Linear(256, 128) - module_minmax.weight.data = tensor.T + module_minmax.weight.data = tensor module_mse = torch.nn.Linear(256, 128) - module_mse.weight.data = tensor.T + module_mse.weight.data = tensor # Test with MinMax observer _, _, _, minmax_mse, _ = _run_observer_test( @@ -202,13 +202,6 @@ def test_mse_vs_minmax_extreme_values(): (tensor_large, "large"), (tensor_skewed, "skewed"), ]: - weights = QuantizationArgs( - num_bits=8, - strategy="channel", - symmetric=True, - observer="memoryless_minmax", - ) - # MinMax _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", "channel", True, 8, None, None @@ -289,9 +282,9 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): module_mse = None if strategy == "tensor_group": module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) - module_minmax.weight.data = weight_tensor.T + module_minmax.weight.data = weight_tensor module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) - module_mse.weight.data = weight_tensor.T + module_mse.weight.data = weight_tensor # Test with MinMax observer _, _, _, minmax_mse, _ = _run_observer_test( From 30fff569e2465115170d9ea0915bde30ff82c56e Mon Sep 17 00:00:00 2001 From: Avishek Goswami Date: Thu, 11 Dec 2025 10:37:15 +0530 Subject: [PATCH 3/4] Remove excessive comments from test file Signed-off-by: Avishek Goswami --- .../observers/test_mse_vs_minmax.py | 67 ++----------------- 1 file changed, 5 insertions(+), 62 deletions(-) diff --git a/tests/llmcompressor/observers/test_mse_vs_minmax.py b/tests/llmcompressor/observers/test_mse_vs_minmax.py index c13c815e58..7541f68059 100644 --- a/tests/llmcompressor/observers/test_mse_vs_minmax.py +++ b/tests/llmcompressor/observers/test_mse_vs_minmax.py @@ -44,8 +44,6 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro module.weight_global_scale = global_scale scale, zero_point = observer(tensor) - - # Sanity check: scales should be non-negative assert (scale >= 0).all(), "Scale values should be non-negative" weights_clean = _create_base_quantization_args(num_bits, strategy, symmetric, group_size) @@ -59,18 +57,11 @@ def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, gro def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False): - """ - Assert MSE observer performance with appropriate slack. - - For tensor+symmetric: strict assertion (MSE should be better) - For others: allow slack (10% for synthetic, 20% for real weights) - Also add epsilon to handle cases where minmax_mse is near 0. - """ + """Assert MSE observer performance with appropriate slack.""" epsilon = 1e-8 slack = 1.20 if is_real_weights else 1.10 if strategy == "tensor" and symmetric: - # Cases where MSE SHOULD be better assert mse_mse <= minmax_mse + epsilon, ( f"MSE observer performed worse than MinMax observer!\n" f"Strategy: {strategy}, Symmetric: {symmetric}\n" @@ -79,7 +70,6 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei f"Difference: {(mse_mse - minmax_mse).item():.6e}" ) else: - # Not guaranteed, but ensure not catastrophically worse assert mse_mse <= minmax_mse * slack + epsilon, ( f"MSE observer performed significantly worse than MinMax observer!\n" f"Strategy: {strategy}, Symmetric: {symmetric}\n" @@ -109,26 +99,12 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei ids=["narrow", "medium", "wide"], ) def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): - """ - Test that MSE observer produces quantization error <= MinMax observer - on random tensors with normal distribution (similar to real model weights). - - Real model weights typically follow a normal distribution with: - - Mean near 0 - - Standard deviation around 0.02-0.1 for initialized weights - - Range roughly [-0.5, 0.5] for most layers - - Testing with different std values exposes cases where MinMax performs poorly - on wide or heavy-tailed distributions, where MSE should shine. - """ - # Generate random tensor with normal distribution similar to real weights + """Test that MSE observer produces quantization error <= MinMax observer on random tensors.""" torch.manual_seed(42) - # Use different std values to test various distribution widths - tensor = torch.randn(128, 256) * std # Normal distribution with specified std + tensor = torch.randn(128, 256) * std group_size = 32 if strategy == "tensor_group" else None - # Create separate modules for tensor_group to avoid shared mutable state module_minmax = None module_mse = None if strategy == "tensor_group": @@ -137,17 +113,14 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): module_mse = torch.nn.Linear(256, 128) module_mse.weight.data = tensor - # Test with MinMax observer _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax ) - # Test with MSE observer _, _, _, mse_mse, _ = _run_observer_test( tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse ) - # Assert with appropriate slack for synthetic data _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False) @@ -161,25 +134,18 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): ], ) def test_mse_vs_minmax_various_shapes(tensor_shape): - """ - Test MSE vs MinMax on tensors of various shapes with normal distribution. - Uses realistic weight distribution parameters. - """ + """Test MSE vs MinMax on tensors of various shapes.""" torch.manual_seed(42) - # Use realistic weight distribution: mean=0, std=0.05 tensor = torch.randn(*tensor_shape) * 0.05 - # MinMax _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", "channel", True, 8, None, None ) - # MSE _, _, _, mse_mse, _ = _run_observer_test( tensor, "memoryless_mse", "channel", True, 8, None, None ) - # Channel quantization: MSE not guaranteed better, allow 10% slack _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) @@ -187,11 +153,8 @@ def test_mse_vs_minmax_extreme_values(): """Test MSE vs MinMax on tensors with extreme values.""" torch.manual_seed(42) - # Test with very small values tensor_small = torch.randn(64, 128) * 0.01 - # Test with very large values tensor_large = torch.randn(64, 128) * 100.0 - # Test with skewed distribution tensor_skewed = torch.cat([ torch.randn(64, 100) * 0.1, torch.randn(64, 28) * 10.0 @@ -202,17 +165,14 @@ def test_mse_vs_minmax_extreme_values(): (tensor_large, "large"), (tensor_skewed, "skewed"), ]: - # MinMax _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", "channel", True, 8, None, None ) - # MSE _, _, _, mse_mse, _ = _run_observer_test( tensor, "memoryless_mse", "channel", True, 8, None, None ) - # Channel quantization: MSE not guaranteed better, allow 10% slack _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) @@ -227,30 +187,20 @@ def test_mse_vs_minmax_extreme_values(): ], ) def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): - """ - Test that MSE observer produces quantization error <= MinMax observer - on actual model weights from a real neural network. - - This test loads weights from a small model to verify observer behavior - on real weight distributions, which may differ from synthetic data. - """ + """Test that MSE observer produces quantization error <= MinMax observer on real model weights.""" try: from transformers import AutoModelForCausalLM except ImportError: pytest.skip("transformers not available") - # Use a small, publicly available model for testing model_id = "nm-testing/tinysmokellama-3.2" try: - # Load model and extract a weight tensor - # Use no_grad context to avoid unnecessary gradient computation with torch.no_grad(): model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float32 ) - # Get a representative weight tensor (e.g., from first Linear layer) weight_tensor = None for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and weight_tensor is None: @@ -260,13 +210,11 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): if weight_tensor is None: pytest.skip("No Linear layer found in model") - # Flatten or reshape to 2D if needed for testing if weight_tensor.dim() > 2: weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1]) elif weight_tensor.dim() == 1: weight_tensor = weight_tensor.unsqueeze(0) - # Limit size for faster testing if weight_tensor.shape[0] > 512: weight_tensor = weight_tensor[:512, :] if weight_tensor.shape[1] > 512: @@ -277,7 +225,6 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): group_size = 32 if strategy == "tensor_group" else None - # Create separate modules for tensor_group to avoid shared mutable state module_minmax = None module_mse = None if strategy == "tensor_group": @@ -286,17 +233,13 @@ def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) module_mse.weight.data = weight_tensor - # Test with MinMax observer _, _, _, minmax_mse, _ = _run_observer_test( weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax ) - # Test with MSE observer _, _, _, mse_mse, _ = _run_observer_test( weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse ) - # For channel and tensor_group strategies, MSE is not guaranteed to be better - # Allow 20% slack for real model weights (more structure & extreme channels) _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True) From 2c9b5a19e9f92b647f531338e7c8dd81c6c0bcc1 Mon Sep 17 00:00:00 2001 From: Avishek Goswami Date: Wed, 17 Dec 2025 08:20:04 +0530 Subject: [PATCH 4/4] Fix ruff formatting in MSE vs MinMax observer tests Signed-off-by: Avishek Goswami --- .../observers/test_mse_vs_minmax.py | 132 ++++++++++++------ 1 file changed, 87 insertions(+), 45 deletions(-) diff --git a/tests/llmcompressor/observers/test_mse_vs_minmax.py b/tests/llmcompressor/observers/test_mse_vs_minmax.py index 7541f68059..d8a9609e2c 100644 --- a/tests/llmcompressor/observers/test_mse_vs_minmax.py +++ b/tests/llmcompressor/observers/test_mse_vs_minmax.py @@ -1,6 +1,7 @@ """ Test to verify that MSE observer performs equal to or better than MinMax observer -on various tensor distributions, including normal distributions (similar to real weights) +on various tensor distributions, including normal distributions (similar to real +weights) and actual model weights. This test checks that the quantization error (MSE) from using MSE observer @@ -25,42 +26,51 @@ def _create_base_quantization_args(num_bits, strategy, symmetric, group_size): ) -def _run_observer_test(tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None): +def _run_observer_test( + tensor, observer_name, strategy, symmetric, num_bits, group_size, module=None +): """ Helper function to run observer and compute quantization error. - + Returns: (scale, zero_point, quantized_tensor, mse, global_scale) """ weights = _create_base_quantization_args(num_bits, strategy, symmetric, group_size) weights.observer = observer_name - + observer = Observer.load_from_registry( observer_name, base_name="weight", args=weights, module=module ) - + global_scale = None if strategy == "tensor_group" and module is not None: global_scale = observer.get_global_scale(tensor) module.weight_global_scale = global_scale - + scale, zero_point = observer(tensor) assert (scale >= 0).all(), "Scale values should be non-negative" - - weights_clean = _create_base_quantization_args(num_bits, strategy, symmetric, group_size) + + weights_clean = _create_base_quantization_args( + num_bits, strategy, symmetric, group_size + ) quantized = fake_quantize( - tensor, scale, zero_point, weights_clean, + tensor, + scale, + zero_point, + weights_clean, global_scale=global_scale if strategy == "tensor_group" else None ) mse = torch.nn.functional.mse_loss(quantized, tensor) - + return scale, zero_point, quantized, mse, global_scale -def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False): +def _assert_mse_comparison( + mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False +): """Assert MSE observer performance with appropriate slack.""" epsilon = 1e-8 slack = 1.20 if is_real_weights else 1.10 - + if strategy == "tensor" and symmetric: assert mse_mse <= minmax_mse + epsilon, ( f"MSE observer performed worse than MinMax observer!\n" @@ -99,12 +109,12 @@ def _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_wei ids=["narrow", "medium", "wide"], ) def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): - """Test that MSE observer produces quantization error <= MinMax observer on random tensors.""" + """Test MSE observer error <= MinMax observer error on random tensors.""" torch.manual_seed(42) tensor = torch.randn(128, 256) * std - + group_size = 32 if strategy == "tensor_group" else None - + module_minmax = None module_mse = None if strategy == "tensor_group": @@ -112,16 +122,30 @@ def test_mse_vs_minmax_on_random_tensor(strategy, symmetric, num_bits, std): module_minmax.weight.data = tensor module_mse = torch.nn.Linear(256, 128) module_mse.weight.data = tensor - + _, _, _, minmax_mse, _ = _run_observer_test( - tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax + tensor, + "memoryless_minmax", + strategy, + symmetric, + num_bits, + group_size, + module_minmax, ) - + _, _, _, mse_mse, _ = _run_observer_test( - tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse + tensor, + "memoryless_mse", + strategy, + symmetric, + num_bits, + group_size, + module_mse, + ) + + _assert_mse_comparison( + mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False ) - - _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=False) @pytest.mark.parametrize( @@ -137,29 +161,29 @@ def test_mse_vs_minmax_various_shapes(tensor_shape): """Test MSE vs MinMax on tensors of various shapes.""" torch.manual_seed(42) tensor = torch.randn(*tensor_shape) * 0.05 - + _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", "channel", True, 8, None, None ) - + _, _, _, mse_mse, _ = _run_observer_test( tensor, "memoryless_mse", "channel", True, 8, None, None ) - + _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) def test_mse_vs_minmax_extreme_values(): """Test MSE vs MinMax on tensors with extreme values.""" torch.manual_seed(42) - + tensor_small = torch.randn(64, 128) * 0.01 tensor_large = torch.randn(64, 128) * 100.0 tensor_skewed = torch.cat([ torch.randn(64, 100) * 0.1, torch.randn(64, 28) * 10.0 ], dim=1) - + for tensor, name in [ (tensor_small, "small"), (tensor_large, "large"), @@ -168,12 +192,14 @@ def test_mse_vs_minmax_extreme_values(): _, _, _, minmax_mse, _ = _run_observer_test( tensor, "memoryless_minmax", "channel", True, 8, None, None ) - + _, _, _, mse_mse, _ = _run_observer_test( tensor, "memoryless_mse", "channel", True, 8, None, None ) - - _assert_mse_comparison(mse_mse, minmax_mse, "channel", True, is_real_weights=False) + + _assert_mse_comparison( + mse_mse, minmax_mse, "channel", True, is_real_weights=False + ) @pytest.mark.slow @@ -187,59 +213,75 @@ def test_mse_vs_minmax_extreme_values(): ], ) def test_mse_vs_minmax_on_real_model_weights(strategy, symmetric, num_bits): - """Test that MSE observer produces quantization error <= MinMax observer on real model weights.""" + """Test MSE observer error <= MinMax observer error on real model weights.""" try: from transformers import AutoModelForCausalLM except ImportError: pytest.skip("transformers not available") model_id = "nm-testing/tinysmokellama-3.2" - + try: with torch.no_grad(): model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float32 ) - + weight_tensor = None for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and weight_tensor is None: weight_tensor = module.weight.data.clone() break - + if weight_tensor is None: pytest.skip("No Linear layer found in model") - + if weight_tensor.dim() > 2: weight_tensor = weight_tensor.view(-1, weight_tensor.shape[-1]) elif weight_tensor.dim() == 1: weight_tensor = weight_tensor.unsqueeze(0) - + if weight_tensor.shape[0] > 512: weight_tensor = weight_tensor[:512, :] if weight_tensor.shape[1] > 512: weight_tensor = weight_tensor[:, :512] - + except Exception as e: pytest.skip(f"Could not load model {model_id}: {e}") - + group_size = 32 if strategy == "tensor_group" else None - + module_minmax = None module_mse = None if strategy == "tensor_group": - module_minmax = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) + module_minmax = torch.nn.Linear( + weight_tensor.shape[1], weight_tensor.shape[0] + ) module_minmax.weight.data = weight_tensor module_mse = torch.nn.Linear(weight_tensor.shape[1], weight_tensor.shape[0]) module_mse.weight.data = weight_tensor - + _, _, _, minmax_mse, _ = _run_observer_test( - weight_tensor, "memoryless_minmax", strategy, symmetric, num_bits, group_size, module_minmax + weight_tensor, + "memoryless_minmax", + strategy, + symmetric, + num_bits, + group_size, + module_minmax, ) - + _, _, _, mse_mse, _ = _run_observer_test( - weight_tensor, "memoryless_mse", strategy, symmetric, num_bits, group_size, module_mse + weight_tensor, + "memoryless_mse", + strategy, + symmetric, + num_bits, + group_size, + module_mse, + ) + + _assert_mse_comparison( + mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True ) - - _assert_mse_comparison(mse_mse, minmax_mse, strategy, symmetric, is_real_weights=True)