Commit 6fbebea

Last touches
1 parent be753ec commit 6fbebea

2 files changed (+36, -24 lines)

src/brevitas/graph/gpxq.py

Lines changed: 17 additions & 7 deletions
@@ -252,10 +252,20 @@ def get_quant_weights(self, i, i1, permutation_list, with_quant_history=False):
         # We need to recompute quant weights at runtime since our float weights are being updated
         # Add offset in case of blockwise computation
         i = i1 + i
+
         # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
         # of quantizing only a subset of the entire matrix, speeding up the computation of GPxQ
+        no_slice = False
+        # Groupwise quantization does not support slicing
+        no_slice = no_slice or self.layer.weight_quant.is_groupwise
+        # If we need quantization of past channels, we do not use slicing
+        no_slice = no_slice or with_quant_history
+        # If we are in export mode (i.e., inference mode), we do not slice, for torch.compile
+        # compatibility
+        no_slice = no_slice or self.layer.weight_quant.export_mode
+
         if isinstance(self.layer, qnn.QuantLinear):
-            if True:  #self.layer.weight_quant.is_groupwise or with_quant_history:
+            if no_slice:
 
                 # No slicing, not optimized
                 q = self.layer.quant_weight(quant_input=self.quant_metadata)
@@ -272,12 +282,12 @@ def get_quant_weights(self, i, i1, permutation_list, with_quant_history=False):
                         subtensor_slice_list=subtensor_slice_list,
                         quant_input=self.quant_metadata)).unsqueeze(0)  # [1, OC, 1]
         elif isinstance(self.layer, SUPPORTED_CONV_OP):
-            # For depthwise and ConvTranspose we fall back to quantizing the entire matrix.
-            # For all other cases, we create a mask that represents the slicing we will perform on the weight matrix
-            # and we quantize only the selected dimensions.
-            if True:  #self.layer.weight_quant.is_groupwise or with_quant_history or self.groups > 1 or (
-            #         self.groups == 1 and
-            #         isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d))):
+            # Depthwise and ConvTranspose layers do not support slicing
+            no_slice_conv = no_slice or (
+                self.groups > 1 or
+                isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)))
+
+            if no_slice_conv:
 
                 quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
                 quant_weight = _unpack_quant_tensor(quant_weight)
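
For context, the change above folds every condition under which GPxQ must skip weight-matrix slicing into a single no_slice flag, which convolutional layers then extend with their own constraints. A minimal standalone sketch of that decision logic, with plain booleans standing in for the real weight_quant and layer attributes:

# Illustrative sketch only: `is_groupwise`, `export_mode`, `with_quant_history`,
# `groups`, and `is_conv_transpose` stand in for the corresponding attributes
# of the Brevitas layer and its weight quantizer.

def should_skip_slicing(is_groupwise: bool,
                        with_quant_history: bool,
                        export_mode: bool) -> bool:
    """Slicing must be skipped whenever any blocking condition holds."""
    no_slice = False
    # Groupwise quantization does not support slicing
    no_slice = no_slice or is_groupwise
    # Quantizing past channels needs the full weight matrix
    no_slice = no_slice or with_quant_history
    # Export (inference) mode avoids slicing for torch.compile compatibility
    no_slice = no_slice or export_mode
    return no_slice


def should_skip_slicing_conv(no_slice: bool,
                             groups: int,
                             is_conv_transpose: bool) -> bool:
    """Convolutions add two more blockers: grouped/depthwise and transposed
    convolutions always quantize the full weight tensor."""
    return no_slice or groups > 1 or is_conv_transpose

Expressed this way, each per-layer branch reduces to a single boolean test, which is exactly what replaces the hard-coded if True: guards in the diff.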

src/brevitas_examples/stable_diffusion/main.py

Lines changed: 19 additions & 17 deletions
@@ -301,13 +301,14 @@ def main(args):
         subfolder='float')

     if len(args.few_shot_calibration) > 0:
+        args.few_shot_calibration = list(map(int, args.few_shot_calibration))
         pipe.set_progress_bar_config(disable=True)
-        new_calib_set = []
+        few_shot_calibration_prompts = []
         counter = [0]

         def calib_hook(module, inp, inp_kwargs):
             if counter[0] in args.few_shot_calibration:
-                new_calib_set.append((inp, inp_kwargs))
+                few_shot_calibration_prompts.append((inp, inp_kwargs))
             counter[0] += 1
             if counter[0] == args.calibration_steps:
                 counter[0] = 0
@@ -329,6 +330,8 @@ def calib_hook(module, inp, inp_kwargs):
             is_unet=is_unet,
             batch=args.calibration_batch_size)
         h.remove()
+    else:
+        few_shot_calibration_prompts = calibration_prompts

     # Detect Stable Diffusion XL pipeline
     is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
@@ -358,15 +361,18 @@ def calib_hook(module, inp, inp_kwargs):
         if hasattr(m, 'lora_layer') and m.lora_layer is not None:
             raise RuntimeError("LoRA layers should be fused in before calling into quantization.")

-    def calibration_step(calibration_prompts, force_full_evaluation=False):
-        if len(args.few_shot_calibration) > 0 or not force_full_evaluation:
-            for i, (inp_args, inp_kwargs) in enumerate(new_calib_set):
+    def calibration_step(force_full_calibration=False, num_prompts=None):
+        if len(args.few_shot_calibration) > 0 or not force_full_calibration:
+            for i, (inp_args, inp_kwargs) in enumerate(few_shot_calibration_prompts):
                 denoising_network(*inp_args, **inp_kwargs)
+                if num_prompts is not None and i == num_prompts:
+                    break
         else:
+            prompts_subset = calibration_prompts[:num_prompts] if num_prompts is not None else calibration_prompts
             run_val_inference(
                 pipe,
                 args.resolution,
-                calibration_prompts,
+                prompts_subset,
                 test_seeds,
                 args.device,
                 dtype,
@@ -389,12 +395,11 @@ def calibration_step(calibration_prompts, force_full_evaluation=False):
     for m in denoising_network.modules():
         if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
             m.in_features = m.module.in_features
-
-    if args.dry_run or args.load_checkpoint is not None:
-        calibration_prompts = [calibration_prompts[0]]
+    act_eq_num_prompts = 1 if args.dry_run or args.load_checkpoint else len(
+        calibration_prompts)

     # SmoothQuant seems to make better use of all the timesteps
-    calibration_step(calibration_prompts, force_full_evaluation=True)
+    calibration_step(force_full_calibration=True, num_prompts=act_eq_num_prompts)

     # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
     for m in denoising_network.modules():
@@ -601,7 +606,7 @@ def sdpa_zp_stats_type():
     pipe.set_progress_bar_config(disable=True)

     with torch.no_grad():
-        calibration_step([calibration_prompts[0]])
+        calibration_step(num_prompts=1)

     if args.load_checkpoint is not None:
         with load_quant_model_mode(denoising_network):
@@ -616,7 +621,7 @@ def sdpa_zp_stats_type():
     if needs_calibration:
         print("Applying activation calibration")
         with torch.no_grad(), calibration_mode(denoising_network):
-            calibration_step(calibration_prompts)
+            calibration_step()

     if args.svd_quant:
         print("Apply SVDQuant...")
@@ -634,24 +639,21 @@ def sdpa_zp_stats_type():
             m.compile_quant()
     if args.gptq:
         print("Applying GPTQ. It can take several hours")
-        gptq_subset = calibration_prompts[:128]
         with torch.no_grad(), quant_inference_mode(denoising_network, compile=True):
-            calibration_step([gptq_subset[0]])
             with gptq_mode(denoising_network,
                            create_weight_orig=False,
                            use_quant_activations=True,
                            return_forward_output=False,
                            act_order=True) as gptq:
                 for _ in tqdm(range(gptq.num_layers)):
-                    calibration_step(gptq_subset)
+                    calibration_step(num_prompts=128)
                 gptq.update()

     if args.bias_correction:
         print("Applying bias correction")
         with torch.no_grad(), quant_inference_mode(denoising_network, compile=True):
-            calibration_step([calibration_prompts[0]])
             with bias_correction_mode(denoising_network):
-                calibration_step(calibration_prompts)
+                calibration_step()

     if args.vae_fp16_fix and is_sd_xl:
         vae_fix_scale = 128
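
Taken together, these hunks change calibration_step from receiving a prompt list to receiving an optional num_prompts cap, so call sites such as the GPTQ loop no longer slice calibration_prompts themselves. Below is a simplified, self-contained sketch of the new control flow; few_shot_prompts, full_prompts, run_model, and make_calibration_step are hypothetical stand-ins for the closure state and inference calls inside main():

# Hypothetical simplification of the refactored calibration_step; the real
# function closes over args, few_shot_calibration_prompts, calibration_prompts,
# denoising_network, and run_val_inference inside main().

def make_calibration_step(few_shot_prompts, full_prompts, run_model):

    def calibration_step(force_full_calibration=False, num_prompts=None):
        if len(few_shot_prompts) > 0 or not force_full_calibration:
            # Few-shot path: replay the captured (args, kwargs) pairs,
            # stopping once the loop index reaches num_prompts.
            for i, (inp_args, inp_kwargs) in enumerate(few_shot_prompts):
                run_model(*inp_args, **inp_kwargs)
                if num_prompts is not None and i == num_prompts:
                    break
        else:
            # Full path: run inference over a prefix of the prompt list.
            subset = full_prompts[:num_prompts] if num_prompts is not None else full_prompts
            for prompt in subset:
                run_model(prompt)

    return calibration_step

Under this refactor, the old warm-up call calibration_step([calibration_prompts[0]]) becomes calibration_step(num_prompts=1), and the GPTQ loop's calibration_step(gptq_subset) becomes calibration_step(num_prompts=128).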
