@@ -301,13 +301,14 @@ def main(args):
             subfolder='float')
 
     if len(args.few_shot_calibration) > 0:
+        args.few_shot_calibration = list(map(int, args.few_shot_calibration))
         pipe.set_progress_bar_config(disable=True)
-        new_calib_set = []
+        few_shot_calibration_prompts = []
         counter = [0]
 
         def calib_hook(module, inp, inp_kwargs):
            if counter[0] in args.few_shot_calibration:
-                new_calib_set.append((inp, inp_kwargs))
+                few_shot_calibration_prompts.append((inp, inp_kwargs))
            counter[0] += 1
            if counter[0] == args.calibration_steps:
                counter[0] = 0
@@ -329,6 +330,8 @@ def calib_hook(module, inp, inp_kwargs):
             is_unet=is_unet,
             batch=args.calibration_batch_size)
         h.remove()
+    else:
+        few_shot_calibration_prompts = calibration_prompts
 
     # Detect Stable Diffusion XL pipeline
     is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
@@ -358,15 +361,18 @@ def calib_hook(module, inp, inp_kwargs):
         if hasattr(m, 'lora_layer') and m.lora_layer is not None:
             raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
-    def calibration_step(calibration_prompts, force_full_evaluation=False):
-        if len(args.few_shot_calibration) > 0 or not force_full_evaluation:
-            for i, (inp_args, inp_kwargs) in enumerate(new_calib_set):
+    def calibration_step(force_full_calibration=False, num_prompts=None):
+        if len(args.few_shot_calibration) > 0 or not force_full_calibration:
+            for i, (inp_args, inp_kwargs) in enumerate(few_shot_calibration_prompts):
                 denoising_network(*inp_args, **inp_kwargs)
+                if num_prompts is not None and i == num_prompts:
+                    break
         else:
+            prompts_subset = calibration_prompts[:num_prompts] if num_prompts is not None else calibration_prompts
             run_val_inference(
                 pipe,
                 args.resolution,
-                calibration_prompts,
+                prompts_subset,
                 test_seeds,
                 args.device,
                 dtype,
@@ -389,12 +395,11 @@ def calibration_step(calibration_prompts, force_full_evaluation=False):
     for m in denoising_network.modules():
         if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
             m.in_features = m.module.in_features
-
-    if args.dry_run or args.load_checkpoint is not None:
-        calibration_prompts = [calibration_prompts[0]]
+    act_eq_num_prompts = 1 if args.dry_run or args.load_checkpoint else len(
+        calibration_prompts)
 
     # SmoothQuant seems to make better use of all the timesteps
-    calibration_step(calibration_prompts, force_full_evaluation=True)
+    calibration_step(force_full_calibration=True, num_prompts=act_eq_num_prompts)
 
     # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
     for m in denoising_network.modules():
@@ -601,7 +606,7 @@ def sdpa_zp_stats_type():
     pipe.set_progress_bar_config(disable=True)
 
     with torch.no_grad():
-        calibration_step([calibration_prompts[0]])
+        calibration_step(num_prompts=1)
 
     if args.load_checkpoint is not None:
         with load_quant_model_mode(denoising_network):
@@ -616,7 +621,7 @@ def sdpa_zp_stats_type():
     if needs_calibration:
         print("Applying activation calibration")
         with torch.no_grad(), calibration_mode(denoising_network):
-            calibration_step(calibration_prompts)
+            calibration_step()
 
     if args.svd_quant:
         print("Apply SVDQuant...")
@@ -634,24 +639,21 @@ def sdpa_zp_stats_type():
                 m.compile_quant()
     if args.gptq:
         print("Applying GPTQ. It can take several hours")
-        gptq_subset = calibration_prompts[:128]
         with torch.no_grad(), quant_inference_mode(denoising_network, compile=True):
-            calibration_step([gptq_subset[0]])
             with gptq_mode(denoising_network,
                            create_weight_orig=False,
                            use_quant_activations=True,
                            return_forward_output=False,
                            act_order=True) as gptq:
                 for _ in tqdm(range(gptq.num_layers)):
-                    calibration_step(gptq_subset)
+                    calibration_step(num_prompts=128)
                     gptq.update()
 
     if args.bias_correction:
         print("Applying bias correction")
         with torch.no_grad(), quant_inference_mode(denoising_network, compile=True):
-            calibration_step([calibration_prompts[0]])
             with bias_correction_mode(denoising_network):
-                calibration_step(calibration_prompts)
+                calibration_step()
 
     if args.vae_fp16_fix and is_sd_xl:
         vae_fix_scale = 128
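A note on the capture-and-replay trick this diff relies on: during a few full sampling runs, a hook on the denoising network records the `(args, kwargs)` it receives at the timesteps listed in `few_shot_calibration`, and `calibration_step` then replays those cached inputs directly, so later calibration passes (activation calibration, GPTQ, bias correction) never re-run the whole diffusion loop. The sketch below reproduces the pattern with a toy module; `toy_denoiser`, `run_pipeline`, `few_shot_steps`, and `num_steps` are illustrative stand-ins rather than names from this repository, and the hook registration with `register_forward_pre_hook(..., with_kwargs=True)` is an assumption inferred from the `calib_hook(module, inp, inp_kwargs)` signature and the `h.remove()` call in the diff.

```python
# Minimal, self-contained sketch of the capture-and-replay pattern,
# using a toy nn.Module in place of the real denoising network.
import torch
from torch import nn

toy_denoiser = nn.Linear(4, 4)   # stand-in for the UNet / transformer
few_shot_steps = {0, 5}          # timesteps whose inputs we keep
num_steps = 10                   # denoising steps per "pipeline" run
captured = []                    # plays the role of few_shot_calibration_prompts
counter = [0]                    # mutable cell shared with the hook


def calib_hook(module, args, kwargs):
    # Forward pre-hook (with_kwargs=True): record the inputs seen at the
    # selected timesteps, then advance/reset the per-run step counter.
    if counter[0] in few_shot_steps:
        captured.append((args, kwargs))
    counter[0] += 1
    if counter[0] == num_steps:
        counter[0] = 0


def run_pipeline(prompt_batches):
    # Stand-in for run_val_inference: one forward per "timestep" per batch.
    for x in prompt_batches:
        for _ in range(num_steps):
            toy_denoiser(x)


# Capture phase: run a few full sampling loops with the hook attached.
handle = toy_denoiser.register_forward_pre_hook(calib_hook, with_kwargs=True)
run_pipeline([torch.randn(2, 4) for _ in range(3)])
handle.remove()

# Replay phase: feed only the cached inputs through the module, optionally
# truncating to the first few captures, mirroring calibration_step(num_prompts=...).
num_prompts = 4
with torch.no_grad():
    for i, (inp_args, inp_kwargs) in enumerate(captured):
        toy_denoiser(*inp_args, **inp_kwargs)
        if num_prompts is not None and i == num_prompts:
            break
```

Bounding the replay with `num_prompts` is what lets the GPTQ loop above, which calls `calibration_step(num_prompts=128)` once per layer, limit how many cached inputs it pushes through the network on each iteration instead of slicing the prompt list itself.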