Feat (brevitas_examples/sdxl): better GPTQ #1250
base: dev
Conversation
    (self.groups, self.columns, self.columns),
    device='cpu',
    dtype=torch.float32,
)
Maybe I should re-introduce this, but I was getting weird CUDA errors. Also, since we are storing on CPU, I'm not sure why we need to use pin_memory.
@i-colbert, maybe you can comment before we merge.
From what I can remember, it was used to improve data transfer speeds from GPU to CPU, which is why it is only enabled if a GPU is available.
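For reference, a minimal sketch of the pattern being discussed (buffer shape and dtype taken from the diff above; the sizes are placeholders), assuming the intent is to request pinned host memory only when a CUDA device is present:

```python
import torch

groups, columns = 1, 4096  # placeholder sizes for illustration

# Page-locked (pinned) host memory enables faster, asynchronous copies
# between device and host, so it only pays off when a GPU is present.
use_pinned = torch.cuda.is_available()

hessian = torch.zeros(
    (groups, columns, columns),
    device='cpu',
    dtype=torch.float32,
    pin_memory=use_pinned,
)
```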
A few comments / questions, but in principle, approved.
@@ -242,13 +242,26 @@ def single_layer_update(self):
        pass

    def get_quant_weights(self, i, i1, permutation_list, with_quant_history=False):
        from brevitas.quant_tensor import _unpack_quant_tensor
Not required - already imported.
@@ -252,16 +251,21 @@ def main(args):
    # Extend seeds based on batch_size
    test_seeds = [TEST_SEED] + [TEST_SEED + i for i in range(1, args.batch_size)]

    is_flux_model = 'flux' in args.model.lower()
    is_flux = 'flux' in args.model.lower()
Is there a possibly more generic name for this flag? I can see this applying to more models than just Flux...
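One possible shape for that suggestion (purely illustrative; the flag name and the backbone list below are hypothetical and not part of the PR, and `args.model` comes from the script's argument parser):

```python
# Hypothetical: key the flag off a family of transformer-based denoisers
# rather than a single model name, so future pipelines are covered too.
TRANSFORMER_DENOISERS = ('flux',)  # extend as new models are supported

model_name = args.model.lower()
is_transformer_denoiser = any(name in model_name for name in TRANSFORMER_DENOISERS)
```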
    create_weight_orig=False,
    use_quant_activations=False,
    return_forward_output=True,
with torch.no_grad(), quant_inference_mode(denoising_network, compile=args.compile_eval):
Do we always want quant_inference_mode to be applied?
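If the answer is "not always", one way to gate it would be something like the sketch below (`quant_inference_mode`, `denoising_network`, and `args.compile_eval` are used as in the diff; the `use_quant_inference` flag and `run_validation` helper are hypothetical):

```python
import contextlib

import torch

# Hypothetical flag: fall back to a no-op context when quantized inference
# mode is not wanted for this evaluation run.
inference_ctx = (
    quant_inference_mode(denoising_network, compile=args.compile_eval)
    if args.use_quant_inference
    else contextlib.nullcontext())

with torch.no_grad(), inference_ctx:
    run_validation()  # placeholder for the existing generation / eval loop
```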
    test_latents=latents,
    guidance_scale=args.guidance_scale,
    is_unet=is_unet)
with torch.no_grad(), quant_inference_mode(denoising_network, compile=args.compile_eval):
Do we always want quant_inference_mode to be applied?
Reason for this PR
Few-shot calibration + GPTQ.
Some details still need refining.
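For context, a rough sketch of the intended flow, based on the usual Brevitas `gptq_mode` usage and the options visible in the diff; the few-shot loop is simplified, and `denoising_network` / `calib_batches` stand in for the script's actual model and pre-computed calibration inputs:

```python
import torch
from brevitas.graph.gptq import gptq_mode

# calib_batches: a small number of pre-computed calibration inputs (few-shot).
with torch.no_grad(), gptq_mode(denoising_network,
                                create_weight_orig=False,
                                use_quant_activations=False,
                                return_forward_output=True) as gptq:
    for _ in range(gptq.num_layers):
        for batch in calib_batches:
            gptq.model(batch)
        gptq.update()
```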
Changes Made in this PR
Testing Summary
Risk Highlight
Checklist
This PR is targeted against the dev branch.