-
Notifications
You must be signed in to change notification settings - Fork 490
Observers refactor #2585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Observers refactor #2585
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,6 @@ | |
| from loguru import logger | ||
|
|
||
| from llmcompressor.modifiers.utils import SPARSITY_THRESHOLD | ||
| from llmcompressor.observers.base import Observer | ||
| from llmcompressor.pytorch.utils.helpers import tensor_sparsity | ||
|
|
||
| GPTQ_PRECISION = torch.float32 | ||
|
|
@@ -85,31 +84,17 @@ def quantize_weight( | |
| """ | ||
| strategy = quant_args.strategy | ||
| actorder = quant_args.actorder | ||
| global_scale = getattr(module, "weight_global_scale", None) | ||
| final_shape = module.weight.shape | ||
| final_dtype = module.weight.dtype | ||
| W = module.weight.clone() | ||
| H = hessian | ||
|
|
||
| # create observer for calculating quantization parameters | ||
| observer = Observer.load_from_registry( | ||
| quant_args.observer if quant_args.observer else "memoryless_minmax", | ||
| base_name="weight", | ||
| args=quant_args, | ||
| module=module, | ||
| ) | ||
| observer = module.weight_observer | ||
|
|
||
| # standardize shape and dtype | ||
| match module: | ||
| case torch.nn.Conv2d(): | ||
| W = W.flatten(1) | ||
| case transformers.Conv1D(): | ||
| W.transpose_(0, 1) | ||
| W = W.to(dtype=GPTQ_PRECISION) | ||
| num_rows = W.shape[0] | ||
|
Comment on lines
86
to
95
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Confirm whether Conv1D / Conv2d are still routed through quantize_weight
rg -nP -C3 '\b(Conv1D|Conv2d)\b' src/llmcompressor/modifiers/gptq/
rg -nP -C3 'quantize_weight\s*\(' --type=py
# Look for any guards that exclude Conv layers before calling quantize_weight
rg -nP -C3 'isinstance\([^)]+,\s*(transformers\.)?Conv1D' --type=pyRepository: vllm-project/llm-compressor Length of output: 8173 🏁 Script executed: # Read the quantize_weight function from start to verify shape handling
sed -n '68,160p' src/llmcompressor/modifiers/gptq/gptq_quantize.py | cat -nRepository: vllm-project/llm-compressor Length of output: 4064 🏁 Script executed: sed -n '38,66p' src/llmcompressor/modifiers/gptq/gptq_quantize.py | cat -nRepository: vllm-project/llm-compressor Length of output: 1055 🏁 Script executed: # Check sparsegpt_sparsify.py to see how it handles shape standardization
sed -n '80,110p' src/llmcompressor/modifiers/pruning/sparsegpt/sgpt_sparsify.py | cat -nRepository: vllm-project/llm-compressor Length of output: 1372 🏁 Script executed: # Check if there are any guards/filters that prevent Conv2d/Conv1D from reaching quantize_weight
rg -B5 'quantize_weight\(' src/llmcompressor/modifiers/gptq/base.py | head -40Repository: vllm-project/llm-compressor Length of output: 328 🏁 Script executed: # Get more context around the quantize_weight call in base.py
sed -n '290,320p' src/llmcompressor/modifiers/gptq/base.py | cat -nRepository: vllm-project/llm-compressor Length of output: 1592 🏁 Script executed: # Check the logic for which modules get quantized - look for module type checks
rg -B10 'compress_module_list\(' src/llmcompressor/modifiers/gptq/base.py | head -60Repository: vllm-project/llm-compressor Length of output: 1011 🏁 Script executed: # Find where accumulate_hessian is called to see module filtering
rg -B5 'accumulate_hessian\(' src/llmcompressor/modifiers/gptq/ | head -80Repository: vllm-project/llm-compressor Length of output: 1103 🏁 Script executed: # Check weight_observer implementation to verify global_scale key
rg -A10 'class.*Observer' src/llmcompressor/modifiers/gptq/ | head -80Repository: vllm-project/llm-compressor Length of output: 53 🏁 Script executed: # Search for weight_observer setup and get_qparams implementation
rg 'weight_observer' src/llmcompressor/modifiers/gptq/ -A3 | head -50Repository: vllm-project/llm-compressor Length of output: 371 🏁 Script executed: # Find where weight_observer is created
rg 'weight_observer\s*=' src/llmcompressor/ -B3 -A3 | head -100Repository: vllm-project/llm-compressor Length of output: 53 🏁 Script executed: # Search for observer setup in the quantization scheme
rg 'Observer' src/llmcompressor/modifiers/gptq/ --type=py -lRepository: vllm-project/llm-compressor Length of output: 53 🏁 Script executed: # Find where get_qparams is defined
rg 'def get_qparams' src/llmcompressor/ --type=py -B2 -A10 | head -100Repository: vllm-project/llm-compressor Length of output: 916 🏁 Script executed: # Find compute_qparams_from_statistics to verify it returns global_scale
rg 'def compute_qparams_from_statistics' src/llmcompressor/ --type=py -A30 | head -100Repository: vllm-project/llm-compressor Length of output: 2580 🏁 Script executed: # Check test to see if Conv2d is tested with GPTQ
cat -n src/llmcompressor/modifiers/gptq/test_gptq_quantize.py | head -100Repository: vllm-project/llm-compressor Length of output: 158 🏁 Script executed: # Find test files for GPTQ
find . -name '*test*gptq*' -o -name '*gptq*test*' 2>/dev/null | head -20Repository: vllm-project/llm-compressor Length of output: 191 🏁 Script executed: # Check test_gptq_quantize.py to see what module types are tested
cat -n tests/llmcompressor/modifiers/gptq/test_gptq_quantize.py | head -80Repository: vllm-project/llm-compressor Length of output: 1422 Restore upfront shape standardization for Conv1D/Conv2d weights. The code path removed the initial
Compare with if isinstance(module, torch.nn.Conv2d):
W = W.flatten(1)
elif isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
Add equivalent guards at the start of quantize_weight. 🤖 Prompt for AI Agents
||
| num_columns = W.shape[1] | ||
|
|
||
| scale, zero_point = observer(W) | ||
| # handle g_idx and activation ordering | ||
| if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP): | ||
| # mapping from column index to group index | ||
|
|
@@ -121,14 +106,21 @@ def quantize_weight( | |
| if actorder == ActivationOrdering.GROUP: | ||
| W, H, perm = _apply_activation_ordering(W, H) | ||
| # actually need scale/zp for permuted weight for this format | ||
| scale, zero_point = observer(W) | ||
| observer(W) | ||
| # use identity g_idx (invert permutation later) | ||
|
|
||
| elif actorder == ActivationOrdering.WEIGHT: | ||
| # permute weights and g_idx | ||
| W, H, perm = _apply_activation_ordering(W, H) | ||
| g_idx = g_idx[perm] | ||
|
|
||
| qparams = observer.get_qparams() | ||
| scale, zero_point, global_scale = ( | ||
| qparams["scale"], | ||
| qparams["zero_point"], | ||
| qparams["global_scale"], | ||
| ) | ||
|
|
||
| # sparsity mask | ||
| sparsity = tensor_sparsity(W) | ||
| preserve_zeros = sparsity >= SPARSITY_THRESHOLD | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fused observer cleanup — dangling references after freeze.
After
freeze_module_quantization(module) at Line 213 deletes each module's weight_observer, the remaining fused observers in other modules still hold references to the deleted observer via their _fused_observers list (populated by Observer.fuse). In this code path, update_qparams on the fused set has already completed before any freeze, so no incorrect results will occur. However, if the logic is ever refactored to recompute qparams after partial freezing, the stale references could surface as subtle bugs (e.g., referencing a detached observer's statistics). Consider either clearing
_fused_observers in Observer.detach/freeze_module_quantization, or documenting that freezing must only happen after all fused peers have completed qparam computation. 🤖 Prompt for AI Agents