@@ -170,30 +170,26 @@ def run_val_inference(
         batch=1):
     with torch.no_grad():
 
-        if test_latents is None:
-            test_latents = generate_latents([seeds[0]] * batch,
-                                            device,
-                                            dtype,
-                                            unet_input_shape(resolution))
         extra_kwargs = {}
 
-        if is_unet:
+        if is_unet and test_latents is not None:
             extra_kwargs['latents'] = test_latents
 
         generator = torch.Generator(device).manual_seed(0) if deterministic else None
         neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else []
-        for index in range(0, len(prompts), batch):
-            curr_prompts = prompts[index]
-            # We don't want to generate any image, so we return only the latent encoding pre VAE
-            pipe([curr_prompts] * batch,
-                 negative_prompt=neg_prompts,
-                 output_type=output_type,
-                 guidance_scale=guidance_scale,
-                 height=resolution,
-                 width=resolution,
-                 num_inference_steps=total_steps,
-                 generator=generator,
-                 **extra_kwargs)
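+        # Iterate over prompts in batches; slicing (rather than indexing a single
+        # prompt) also handles a final batch smaller than `batch`.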
+        for index in tqdm(range(0, len(prompts), batch)):
+            curr_prompts = prompts[index:index + batch]
+
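+            # Tile the negative prompts to match the number of prompts in this batch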
+            pipe(
+                curr_prompts,
+                negative_prompt=neg_prompts * len(curr_prompts),
+                output_type=output_type,
+                guidance_scale=guidance_scale,
+                height=resolution,
+                width=resolution,
+                num_inference_steps=total_steps,
+                generator=generator,
+                **extra_kwargs)
 
 
 def collect_vae_calibration(pipe, calibration, test_seeds, dtype, latents, args):
@@ -256,16 +252,20 @@ def main(args):
     # Extend seeds based on batch_size
     test_seeds = [TEST_SEED] + [TEST_SEED + i for i in range(1, args.batch_size)]
 
-    is_flux_model = 'flux' in args.model.lower()
+    is_flux = 'flux' in args.model.lower()
     # Load model from float checkpoint
     print(f"Loading model from {args.model}...")
-    if is_flux_model:
+    if is_flux:
         extra_kwargs = {}
     else:
         extra_kwargs = {'variant': 'fp16' if dtype == torch.float16 else None}
 
     pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, use_safetensors=True)
-    is_unet = isinstance(pipe, (StableDiffusionXLPipeline, StableDiffusionPipeline))
+
+    # Detect whether the pipeline is UNet-based
+    is_unet = hasattr(pipe, 'unet')
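+    # (transformer-based pipelines such as Flux expose `transformer` instead)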
+    # Detect Stable Diffusion XL pipeline
+    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
 
     if is_unet:
         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
@@ -300,6 +300,8 @@ def main(args):
             use_negative_prompts=args.use_negative_prompts,
             subfolder='float')
 
+    # Few-shot calibration set, collected below when few-shot calibration is enabled
+    few_shot_calibration_prompts = None
     if len(args.few_shot_calibration) > 0:
         args.few_shot_calibration = list(map(int, args.few_shot_calibration))
         pipe.set_progress_bar_config(disable=True)
@@ -330,11 +332,6 @@ def calib_hook(module, inp, inp_kwargs):
             is_unet=is_unet,
             batch=args.calibration_batch_size)
         h.remove()
-    else:
-        few_shot_calibration_prompts = calibration_prompts
-
-    # Detect Stable Diffusion XL pipeline
-    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
 
     # Enable attention slicing
     if args.attention_slicing:
@@ -362,7 +359,7 @@ def calib_hook(module, inp, inp_kwargs):
         raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
     def calibration_step(force_full_calibration=False, num_prompts=None):
-        if len(args.few_shot_calibration) > 0 or not force_full_calibration:
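+        # `and`, not `or`: the few-shot path runs only when few-shot calibration
+        # is configured and a full calibration pass has not been forced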
+        if len(args.few_shot_calibration) > 0 and not force_full_calibration:
             for i, (inp_args, inp_kwargs) in enumerate(few_shot_calibration_prompts):
                 denoising_network(*inp_args, **inp_kwargs)
                 if num_prompts is not None and i == num_prompts:
@@ -385,6 +382,7 @@ def calibration_step(force_full_calibration=False, num_prompts=None):
 
     if args.activation_equalization:
         pipe.set_progress_bar_config(disable=True)
+        print("Applying Activation Equalization")
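+        # Rescales weights and activations so activation outliers are easier to quantize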
         with torch.no_grad(), activation_equalization_mode(
                 denoising_network,
                 alpha=args.act_eq_alpha,
@@ -563,7 +561,7 @@ def sdpa_zp_stats_type():
             lambda module: module.cross_attention_dim
             if module.is_cross_attention else None}
 
-    if is_flux_model:
+    if is_flux:
         extra_kwargs['qk_norm'] = 'rms_norm'
         extra_kwargs['bias'] = True
         extra_kwargs['processor'] = FusedFluxAttnProcessor2_0()
@@ -624,14 +622,13 @@ def sdpa_zp_stats_type():
         calibration_step()
 
     if args.svd_quant:
-        print("Apply SVDQuant...")
+        print("Applying SVDQuant...")
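+        # SVDQuant absorbs weight outliers into a low-rank branch so that the
+        # residual weights quantize with less error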
         denoising_network = apply_svd_quant(
             denoising_network,
             blacklist=None,
             rank=args.svd_quant_rank,
             iters=args.svd_quant_iters,
             dtype=torch.float32)
-        print("SVDQuant applied.")
 
     if args.compile_ptq:
         for m in denoising_network.modules():
@@ -674,7 +671,7 @@ def sdpa_zp_stats_type():
         print(f"Corrected layers in VAE: {corrected_layers}")
 
     if args.vae_quantize:
-        assert not is_flux_model, "Not supported yet"
+        assert not is_flux, "Not supported yet"
         print("Quantizing VAE")
         vae_calibration = collect_vae_calibration(
             pipe, calibration_prompts, test_seeds, dtype, latents, args)
@@ -801,20 +798,15 @@ def sdpa_zp_stats_type():
         device = next(iter(denoising_network.parameters())).device
         dtype = next(iter(denoising_network.parameters())).dtype
 
-        # Define tracing input
-        if is_sd_xl:
-            generate_fn = generate_unet_xl_rand_inputs
-            shape = SD_XL_EMBEDDINGS_SHAPE
-        else:
-            generate_fn = generate_unet_21_rand_inputs
-            shape = SD_2_1_EMBEDDINGS_SHAPE
-        trace_inputs = generate_fn(
-            embedding_shape=shape,
-            unet_input_shape=unet_input_shape(args.resolution),
-            device=device,
-            dtype=dtype)
-
         if args.export_target == 'onnx':
+            assert is_sd_xl, "Only SDXL ONNX export is currently supported. If this impacts you, feel free to open an issue"
+
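+            # Random inputs matching the SDXL UNet signature, used only for tracing the export graph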
+            trace_inputs = generate_unet_xl_rand_inputs(
+                embedding_shape=SD_XL_EMBEDDINGS_SHAPE,
+                unet_input_shape=unet_input_shape(args.resolution),
+                device=device,
+                dtype=dtype)
+
             if args.weight_quant_granularity == 'per_group':
                 export_manager = BlockQuantProxyLevelManager
             else:
@@ -901,10 +893,19 @@ def sdpa_zp_stats_type():
             fid.update(quant_image.unsqueeze(0).to('cuda'), real=False)
 
         print(f"Torchmetrics FID: {float(fid.compute())}")
+        torchmetrics_fid = float(fid.compute())
+        # Dump args to JSON
+        with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
+            json.dump(vars(args), fp)
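+        # clean_fid stays 0. when the optional cleanfid package is not available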
+        clean_fid = 0.
         if cleanfid is not None:
             score = cleanfid.compute_fid(
                 os.path.join(output_dir, 'float'), os.path.join(output_dir, 'quant'))
             print(f"Cleanfid FID: {float(score)}")
+            clean_fid = float(score)
+        results = {'torchmetrics_fid': torchmetrics_fid, 'clean_fid': clean_fid}
+        with open(os.path.join(output_dir, 'results.json'), 'w') as fp:
+            json.dump(results, fp)
 
     elif args.inference_pipeline == 'reference_images':
         pipe.set_progress_bar_config(disable=True)