
Commit 349376e

Cleanup
1 parent 6fbebea commit 349376e

File tree

  • src/brevitas_examples/stable_diffusion

1 file changed (+45, -44 lines)


src/brevitas_examples/stable_diffusion/main.py

Lines changed: 45 additions & 44 deletions
@@ -170,30 +170,26 @@ def run_val_inference(
         batch=1):
     with torch.no_grad():
 
-        if test_latents is None:
-            test_latents = generate_latents([seeds[0]] * batch,
-                                            device,
-                                            dtype,
-                                            unet_input_shape(resolution))
         extra_kwargs = {}
 
-        if is_unet:
+        if is_unet and test_latents is not None:
             extra_kwargs['latents'] = test_latents
 
         generator = torch.Generator(device).manual_seed(0) if deterministic else None
         neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else []
-        for index in range(0, len(prompts), batch):
-            curr_prompts = prompts[index]
-            # We don't want to generate any image, so we return only the latent encoding pre VAE
-            pipe([curr_prompts] * batch,
-                 negative_prompt=neg_prompts,
-                 output_type=output_type,
-                 guidance_scale=guidance_scale,
-                 height=resolution,
-                 width=resolution,
-                 num_inference_steps=total_steps,
-                 generator=generator,
-                 **extra_kwargs)
+        for index in tqdm(range(0, len(prompts), batch)):
+            curr_prompts = prompts[index:index + batch]
+
+            pipe(
+                curr_prompts,
+                negative_prompt=neg_prompts * len(curr_prompts),
+                output_type=output_type,
+                guidance_scale=guidance_scale,
+                height=resolution,
+                width=resolution,
+                num_inference_steps=total_steps,
+                generator=generator,
+                **extra_kwargs)
 
 
 def collect_vae_calibration(pipe, calibration, test_seeds, dtype, latents, args):
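Note: the loop above now slices the prompt list instead of repeating a single prompt per batch, and the negative prompt list is repeated to match the number of prompts in the slice. A minimal standalone sketch of the batching difference (placeholder prompts, no pipeline involved):

# Old behaviour: one prompt repeated `batch` times per pipe() call.
# New behaviour: a slice of up to `batch` distinct prompts per call.
prompts = ["a cat", "a dog", "a tree", "a house", "a boat"]
batch = 2

old_batches = [[prompts[i]] * batch for i in range(0, len(prompts), batch)]
new_batches = [prompts[i:i + batch] for i in range(0, len(prompts), batch)]

print(old_batches)  # [['a cat', 'a cat'], ['a tree', 'a tree'], ['a boat', 'a boat']]
print(new_batches)  # [['a cat', 'a dog'], ['a tree', 'a house'], ['a boat']]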
@@ -256,16 +252,20 @@ def main(args):
     # Extend seeds based on batch_size
     test_seeds = [TEST_SEED] + [TEST_SEED + i for i in range(1, args.batch_size)]
 
-    is_flux_model = 'flux' in args.model.lower()
+    is_flux = 'flux' in args.model.lower()
     # Load model from float checkpoint
     print(f"Loading model from {args.model}...")
-    if is_flux_model:
+    if is_flux:
         extra_kwargs = {}
     else:
         extra_kwargs = {'variant': 'fp16' if dtype == torch.float16 else None}
 
     pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, use_safetensors=True)
-    is_unet = isinstance(pipe, (StableDiffusionXLPipeline, StableDiffusionPipeline))
+
+    # Detect if is unet-based pipeline
+    is_unet = hasattr(pipe, 'unet')
+    # Detect Stable Diffusion XL pipeline
+    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
 
     if is_unet:
         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
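Note: pipeline detection switches from an isinstance() check against specific classes to duck typing on the unet attribute, so any UNet-based diffusers pipeline is covered, while Flux-style pipelines (which carry a transformer module rather than a unet) fall through. A small standalone sketch of the idea; DummyUNetPipe and DummyFluxPipe are stand-ins, not diffusers classes:

class DummyUNetPipe:
    def __init__(self):
        self.unet = object()         # UNet-based pipelines expose a `unet` module

class DummyFluxPipe:
    def __init__(self):
        self.transformer = object()  # transformer-based pipelines do not

print(hasattr(DummyUNetPipe(), 'unet'))  # True
print(hasattr(DummyFluxPipe(), 'unet'))  # False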
@@ -300,6 +300,8 @@ def main(args):
         use_negative_prompts=args.use_negative_prompts,
         subfolder='float')
 
+    # Compute a few-shot calibration set
+    few_shot_calibration_prompts = None
     if len(args.few_shot_calibration) > 0:
         args.few_shot_calibration = list(map(int, args.few_shot_calibration))
         pipe.set_progress_bar_config(disable=True)
@@ -330,11 +332,6 @@ def calib_hook(module, inp, inp_kwargs):
             is_unet=is_unet,
             batch=args.calibration_batch_size)
         h.remove()
-    else:
-        few_shot_calibration_prompts = calibration_prompts
-
-    # Detect Stable Diffusion XL pipeline
-    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
 
     # Enable attention slicing
     if args.attention_slicing:
@@ -362,7 +359,7 @@ def calib_hook(module, inp, inp_kwargs):
         raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
     def calibration_step(force_full_calibration=False, num_prompts=None):
-        if len(args.few_shot_calibration) > 0 or not force_full_calibration:
+        if len(args.few_shot_calibration) > 0 and not force_full_calibration:
             for i, (inp_args, inp_kwargs) in enumerate(few_shot_calibration_prompts):
                 denoising_network(*inp_args, **inp_kwargs)
                 if num_prompts is not None and i == num_prompts:
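Note: the predicate fix above changes when the few-shot path runs. With `or`, any call without force_full_calibration=True took the few-shot branch, even when no few-shot prompts had been collected (few_shot_calibration_prompts is now initialised to None); with `and`, the branch only runs when few-shot calibration was actually requested. A standalone sketch of the two predicates:

few_shot_requested = False   # stands in for len(args.few_shot_calibration) > 0
force_full = False

old = few_shot_requested or not force_full    # True  -> would iterate over None prompts
new = few_shot_requested and not force_full   # False -> falls through to full calibration
print(old, new)  # True False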
@@ -385,6 +382,7 @@ def calibration_step(force_full_calibration=False, num_prompts=None):
 
     if args.activation_equalization:
         pipe.set_progress_bar_config(disable=True)
+        print("Applying Activation Equalization")
         with torch.no_grad(), activation_equalization_mode(
                 denoising_network,
                 alpha=args.act_eq_alpha,
@@ -563,7 +561,7 @@ def sdpa_zp_stats_type():
                 lambda module: module.cross_attention_dim
                 if module.is_cross_attention else None}
 
-    if is_flux_model:
+    if is_flux:
         extra_kwargs['qk_norm'] = 'rms_norm'
         extra_kwargs['bias'] = True
         extra_kwargs['processor'] = FusedFluxAttnProcessor2_0()
@@ -624,14 +622,13 @@ def sdpa_zp_stats_type():
         calibration_step()
 
     if args.svd_quant:
-        print("Apply SVDQuant...")
+        print("Applying SVDQuant...")
         denoising_network = apply_svd_quant(
             denoising_network,
             blacklist=None,
             rank=args.svd_quant_rank,
             iters=args.svd_quant_iters,
             dtype=torch.float32)
-        print("SVDQuant applied.")
 
     if args.compile_ptq:
         for m in denoising_network.modules():
@@ -674,7 +671,7 @@ def sdpa_zp_stats_type():
         print(f"Corrected layers in VAE: {corrected_layers}")
 
     if args.vae_quantize:
-        assert not is_flux_model, "Not supported yet"
+        assert not is_flux, "Not supported yet"
         print("Quantizing VAE")
         vae_calibration = collect_vae_calibration(
             pipe, calibration_prompts, test_seeds, dtype, latents, args)
@@ -801,20 +798,15 @@ def sdpa_zp_stats_type():
         device = next(iter(denoising_network.parameters())).device
         dtype = next(iter(denoising_network.parameters())).dtype
 
-        # Define tracing input
-        if is_sd_xl:
-            generate_fn = generate_unet_xl_rand_inputs
-            shape = SD_XL_EMBEDDINGS_SHAPE
-        else:
-            generate_fn = generate_unet_21_rand_inputs
-            shape = SD_2_1_EMBEDDINGS_SHAPE
-        trace_inputs = generate_fn(
-            embedding_shape=shape,
-            unet_input_shape=unet_input_shape(args.resolution),
-            device=device,
-            dtype=dtype)
-
         if args.export_target == 'onnx':
+            assert is_sd_xl, "Only SDXL ONNX export is currently supported. If this impacts you, feel free to open an issue"
+
+            trace_inputs = generate_unet_xl_rand_inputs(
+                embedding_shape=SD_XL_EMBEDDINGS_SHAPE,
+                unet_input_shape=unet_input_shape(args.resolution),
+                device=device,
+                dtype=dtype)
+
             if args.weight_quant_granularity == 'per_group':
                 export_manager = BlockQuantProxyLevelManager
             else:
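Note: the tracing inputs are now built only inside the ONNX branch, and only the SDXL path is kept (the SD 2.1 fallback is removed and guarded by the assert). For reference, a hedged sketch of the kind of latent input shape such a helper typically produces for SD-style UNets (4 latent channels, spatial size divided by the VAE downscale factor of 8); the repository's unet_input_shape may differ in layout or batch handling:

def unet_input_shape_sketch(resolution: int):
    # Assumed layout: (channels, height, width) in latent space.
    return (4, resolution // 8, resolution // 8)

print(unet_input_shape_sketch(1024))  # (4, 128, 128)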
@@ -901,10 +893,19 @@ def sdpa_zp_stats_type():
                 fid.update(quant_image.unsqueeze(0).to('cuda'), real=False)
 
             print(f"Torchmetrics FID: {float(fid.compute())}")
+            torchmetrics_fid = float(fid.compute())
+            # Dump args to json
+            with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
+                json.dump(vars(args), fp)
+            clean_fid = 0.
             if cleanfid is not None:
                 score = cleanfid.compute_fid(
                     os.path.join(output_dir, 'float'), os.path.join(output_dir, 'quant'))
                 print(f"Cleanfid FID: {float(score)}")
+                clean_fid = float(score)
+            results = {'torchmetrics_fid': torchmetrics_fid, 'clean_fid': clean_fid}
+            with open(os.path.join(output_dir, 'results.json'), 'w') as fp:
+                json.dump(results, fp)
 
         elif args.inference_pipeline == 'reference_images':
             pipe.set_progress_bar_config(disable=True)
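Note: the FID branch now also persists the run configuration and the computed metrics as JSON next to the generated images. A minimal sketch of reading them back; the output directory path is a placeholder, the file names follow the diff:

import json
import os

output_dir = "/path/to/output_dir"  # placeholder

with open(os.path.join(output_dir, "results.json")) as fp:
    results = json.load(fp)
with open(os.path.join(output_dir, "args.json")) as fp:
    run_args = json.load(fp)

print(results["torchmetrics_fid"], results["clean_fid"])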
