Feat (llm/awq): activation-aware weight scaling #1213
base: dev
Conversation
src/brevitas/graph/equalize.py
Outdated
@@ -780,9 +781,11 @@ def _no_equalize():
for module in chain(src_axes.values(), sink_axes.values()):
rewriters.extend(module.instantiate_rewriters(rewriter_class, scaling_factors))

# Apply rewriters before offloading
# Apply rewriters before offloading, if parametrize_inplace is True. Note that parametrizations
# are not immediately applied, to prevent potential errors if the model is offloaded.
Can you elaborate a bit more on the issue here?
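For context, a minimal sketch (not the actual Brevitas code) of the difference between registering a parametrization and applying the transform in place, using torch.nn.utils.parametrize:

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Illustrative scaling parametrization: it rescales the weight every time the
# weight attribute is accessed.
class _Scale(nn.Module):
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("scale", scale)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.scale

linear = nn.Linear(4, 4)
scale = torch.full((4, 1), 2.0)

# Registering keeps the original weight and re-applies the scaling lazily on
# every access, which can interact badly with offloading hooks that move or
# replace parameters.
parametrize.register_parametrization(linear, "weight", _Scale(scale))

# Applying in place bakes the scaled values into the stored tensor and removes
# the parametrization, so offloading sees a plain parameter again.
parametrize.remove_parametrizations(linear, "weight", leave_parametrized=True)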
raise ValueError  # early exit to break later inference

# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
blocks[0] = Catcher(blocks[0])
I don't think we need this part of the codebase; why can't we do what we do in GPTQ to catch the input to the first block?
We can also move that piece of code to some utils in examples/common/generative.
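For reference, a minimal sketch of the input-catching pattern being discussed (names mirror the diff, but this is an illustration rather than the PR code):

import torch.nn as nn

class Catcher(nn.Module):
    """Wraps the first transformer block to record its inputs, then aborts the
    forward pass so the rest of the model never runs."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self.inputs = []
        self.kwargs = []

    def forward(self, x, **kwargs):
        self.inputs.append(x)
        self.kwargs.append(kwargs)
        raise ValueError  # early exit to break later inference

# Usage: wrap block 0, run the calibration samples while catching the
# ValueError, then unwrap to restore the original block:
#   blocks[0] = Catcher(blocks[0])
#   ...
#   blocks[0] = blocks[0].module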
src/brevitas/utils/python_utils.py
Outdated
@@ -64,3 +65,30 @@ def run(*args, **kwargs):
return function(*args, **kwargs)

return run


def longest_common_prefix(strings: List[str]):
This seems overly specific to AWQ; I'm not sure if this should live here.
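For reference, a possible implementation of the helper under discussion (a sketch; the actual code in the PR may differ):

import os
from typing import List

def longest_common_prefix(strings: List[str]) -> str:
    # Character-wise longest common prefix of all strings; empty input yields "".
    if not strings:
        return ""
    return os.path.commonprefix(strings)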
"ffn.act": block.ffn.act, | ||
"ffn.down_proj": block.ffn.down_proj,}, | ||
)) | ||
elif "falcon" in str(block.__class__).lower(): |
Only Llama for now
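To illustrate the point, a sketch of the class-name dispatch with only the Llama branch kept (attribute names follow the Hugging Face Llama block and may differ for other families):

def build_awq_regions(block):
    cls_name = str(block.__class__).lower()
    if "llama" in cls_name:
        return {
            "mlp.act_fn": block.mlp.act_fn,
            "mlp.down_proj": block.mlp.down_proj,}
    raise NotImplementedError(f"AWQ region mapping is only defined for Llama, got {cls_name}")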
@@ -370,6 +370,18 @@ def create_llm_args_parser():
default=[],
nargs='*',
help='A list of module names to expand with hadamard rotation. Default: %(default)s')
parser.add_argument(
Please also add this to the README.
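As a reminder of what needs documenting, a hedged sketch of how such a flag is typically exposed (the flag name here is hypothetical, not the one added in the diff above):

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag for illustration; whatever argument the diff above adds
# should be listed in the README together with its default.
parser.add_argument(
    '--apply-awq',
    action='store_true',
    help='Apply AWQ activation-aware weight scaling before quantization. Default: %(default)s')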
src/brevitas/graph/calibrate.py
Outdated
@@ -251,6 +287,65 @@ def apply(self, model, is_training, quantization_enabled):
self.enable_param_quantization(model, is_training)


class disable_enable_quantization:
We have another class that does this as well, not in a context-manager fashion.
I think we might consider just switching to this new class everywhere?
The main consideration is that we need to handle disabling quantization for activation calibration.
I'll handle it in a separate PR.
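For reference, a minimal sketch of the context-manager pattern being discussed (the toggling helper is a placeholder, not the Brevitas API):

import torch.nn as nn

def _toggle_quantization(model: nn.Module, enabled: bool) -> None:
    # Placeholder: in Brevitas this would enable/disable the act/param
    # quantization proxies; here we just flip a flag on modules that expose one.
    for m in model.modules():
        if hasattr(m, "quantization_enabled"):
            m.quantization_enabled = enabled

class disable_enable_quantization:
    """Disables quantization on enter and restores it on exit."""

    def __init__(self, model: nn.Module):
        self.model = model

    def __enter__(self):
        _toggle_quantization(self.model, False)
        return self.model

    def __exit__(self, exc_type, exc_value, traceback):
        _toggle_quantization(self.model, True)
        return False  # do not suppress exceptions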
src/brevitas/graph/equalize.py
Outdated
for r in rewriters:
model = r.apply(model)
if parametrize_inplace or not isinstance(r, ModuleInstanceRegisterParametrization):
I don't understand this. The comment above doesn't address the parametrize_inplace flag and how the two combine.
It was a leftover. I've removed it.
Force-pushed from 9789efe to 8a155a5
src/brevitas_examples/llm/main.py
Outdated
model=model,
tokenizer=tokenizer,
args=args,
n_samples=128,
Should these be n_samples=args.n_samples and seqlen=args.seqlen instead of hard-coded to 128 and 512, respectively?
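A sketch of the suggested fix, with the function name and argument wiring assumed from the diff and the comment above rather than taken from the PR:

awq_results = apply_awq(
    model=model,
    tokenizer=tokenizer,
    args=args,
    n_samples=args.n_samples,  # instead of the hard-coded 128
    seqlen=args.seqlen,        # instead of the hard-coded 512
)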
Force-pushed from 170b873 to 8d65b34
Force-pushed from 24f93ee to 500fe77
Reason for this PR
Implementation of AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
Using weight-only quantization and the configuration:
*Minor differences observed in perplexity between the original repository and Brevitas are due to order of operations/differences in quantizers.
Changes Made in this PR
- RegionAWQ, inheriting from Region, to aggregate the information of the modules on which AWQ optimizes the scale (a sketch of the scale search follows this list).
- auto_scale and auto_clip to rely on Brevitas quantizers.
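For reference, a minimal sketch of the activation-aware scale search that auto_scale performs, with placeholder callables standing in for the Brevitas quantizer and the block evaluation:

import torch

def search_awq_scale(weight, act_mean, quantize_fn, block_forward, x, n_grid=20):
    # Grid search over alpha in [0, 1): scale = act_mean ** alpha, keeping the
    # scale that minimizes the output error of the fake-quantized block.
    with torch.no_grad():
        ref_out = block_forward(weight, x)
        best_err, best_scale = float("inf"), None
        for i in range(n_grid):
            alpha = i / n_grid
            scale = act_mean.clamp(min=1e-4) ** alpha
            scale = scale / (scale.max() * scale.min()).sqrt()  # normalization used in the AWQ paper's code
            q_weight = quantize_fn(weight * scale) / scale
            err = (block_forward(q_weight, x) - ref_out).pow(2).mean().item()
            if err < best_err:
                best_err, best_scale = err, scale
        return best_scale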
Testing Summary
Testing apply_awq against the author's repository.
Risk Highlight
Checklist
- Pull request is made against the dev branch.