diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index d22e0d902..e889dcc3b 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -52,14 +52,16 @@ def __init__(
         self.blocksize = math.ceil(self.columns / num_blocks)
 
         # Initialize Hessian matrix and counter. We need it in float32 to compute the inverse
-        self.H = torch.zeros((self.groups, self.columns, self.columns),
-                             device='cpu',
-                             dtype=torch.float32,
-                             pin_memory=torch.cuda.is_available())
-        self.B = torch.zeros((self.groups, self.columns, self.columns),
-                             device='cpu',
-                             dtype=torch.float32,
-                             pin_memory=torch.cuda.is_available())
+        self.H = torch.zeros(
+            (self.groups, self.columns, self.columns),
+            device='cpu',
+            dtype=torch.float32,
+        )
+        self.B = torch.zeros(
+            (self.groups, self.columns, self.columns),
+            device='cpu',
+            dtype=torch.float32,
+        )
         self.nsamples = 0
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
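For context on the allocation above: `H` and `B` are the per-group, float32 CPU buffers that accumulate second-order (Hessian) input statistics for GPTQ, and the change only drops `pin_memory` from their allocation. A minimal sketch of the standard GPTQ-style running update that such a buffer backs (the function name and signature are illustrative, not the Brevitas API):

```python
import math

import torch


# Illustrative sketch (not the Brevitas implementation): the standard GPTQ
# running Hessian estimate. Keeping the buffer in float32 matters because it
# is later inverted, which is numerically fragile in half precision.
def update_hessian(H: torch.Tensor, x: torch.Tensor, nsamples: int) -> int:
    """H: [columns, columns] float32 buffer; x: [batch, columns] layer input."""
    batch = x.shape[0]
    # Decay the running average before folding in the new batch
    H.mul_(nsamples / (nsamples + batch))
    nsamples += batch
    x = math.sqrt(2 / nsamples) * x.to(torch.float32)
    H.add_(x.t() @ x)
    return nsamples
```

Dropping `pin_memory` gives up slightly faster host-to-device copies in exchange for not reserving page-locked host memory for two `columns x columns` buffers per group.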
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index b0b48d0d6..05aa2b654 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -242,13 +242,26 @@ def single_layer_update(self):
         pass
 
     def get_quant_weights(self, i, i1, permutation_list, with_quant_history=False):
+        from brevitas.quant_tensor import _unpack_quant_tensor
+
        # We need to recompute quant weights at runtime since our float weights are being updated
        # Add offset in case of blockwise computation
        i = i1 + i
+
        # For QuantLinear and for some QuantConvolutional layers, we exploit the possibility
        # of quantizing only a subset of the entire matrix speeding up the computation of GPxQ
+        no_slice = False
+        # Groupwise quantization does not support slicing
+        no_slice = no_slice or self.layer.weight_quant.is_groupwise
+        # If we need quantization of past channels, we do not use slicing
+        no_slice = no_slice or with_quant_history
+        # If we are in export mode (i.e., inference mode), we do not slice for torch.compile
+        # compatibility
+        no_slice = no_slice or self.layer.weight_quant.export_mode
+
        if isinstance(self.layer, qnn.QuantLinear):
-            if self.layer.weight_quant.is_groupwise or with_quant_history:
+            if no_slice:
+                # No slicing: quantize the full weight tensor (unoptimized path)
                q = self.layer.quant_weight(quant_input=self.quant_metadata)
                q = _unpack_quant_tensor(q).unsqueeze(0)  # [1, OC, IC]
@@ -264,11 +277,11 @@ def get_quant_weights(self, i, i1, permutation_list, with_quant_history=False):
                    subtensor_slice_list=subtensor_slice_list,
                    quant_input=self.quant_metadata)).unsqueeze(0)  # [1, OC, 1]
        elif isinstance(self.layer, SUPPORTED_CONV_OP):
-            # For depthwise and ConvTranspose we fall back to quantizing the entire martix.
-            # For all other cases, we create a mask that represent the slicing we will perform on the weight matrix
-            # and we quantize only the selected dimensions.
-            if self.layer.weight_quant.is_groupwise or with_quant_history or self.groups > 1 or (
-                    self.groups == 1 and is_conv_transposed(self.layer)):
+            # Depthwise and ConvTranspose do not support slicing
+            no_slice_conv = no_slice or (self.groups > 1 or is_conv_transposed(self.layer))
+
+            if no_slice_conv:
+
                quant_weight = self.layer.quant_weight(quant_input=self.quant_metadata)
                quant_weight = _unpack_quant_tensor(quant_weight)
diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md
index 35e01e755..5f3131df4 100644
--- a/src/brevitas_examples/stable_diffusion/README.md
+++ b/src/brevitas_examples/stable_diffusion/README.md
@@ -97,7 +97,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--weight-quant-format WEIGHT_QUANT_FORMAT]
                [--input-quant-format INPUT_QUANT_FORMAT]
                [--weight-quant-granularity {per_channel,per_tensor,per_group}]
-               [--input-quant-granularity {per_tensor,per_group}]
+               [--input-quant-granularity {per_tensor,per_group,per_row}]
                [--input-scale-type {static,dynamic}]
                [--weight-group-size WEIGHT_GROUP_SIZE]
                [--input-group-size INPUT_GROUP_SIZE]
@@ -116,6 +116,8 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--inference-pipeline {samples,reference_images,mlperf}]
                [--caption-path CAPTION_PATH]
                [--reference-images-path REFERENCE_IMAGES_PATH]
+               [--few-shot-calibration [FEW_SHOT_CALIBRATION ...]]
+               [--calibration-batch-size CALIBRATION_BATCH_SIZE]
                [--quantize-weight-zero-point | --no-quantize-weight-zero-point]
                [--exclude-blacklist-act-eq | --no-exclude-blacklist-act-eq]
                [--quantize-input-zero-point | --no-quantize-input-zero-point]
@@ -251,7 +253,7 @@ options:
   --weight-quant-granularity {per_channel,per_tensor,per_group}
                         Granularity for scales/zero-point of weights. Default:
                         per_channel.
-  --input-quant-granularity {per_tensor,per_group}
+  --input-quant-granularity {per_tensor,per_group,per_row}
                         Granularity for scales/zero-point of inputs. Default:
                         per_tensor.
   --input-scale-type {static,dynamic}
@@ -307,6 +309,11 @@ options:
                         Inference pipeline for evaluation. Default: None
   --reference-images-path REFERENCE_IMAGES_PATH
                         Inference pipeline for evaluation. Default: None
+  --few-shot-calibration [FEW_SHOT_CALIBRATION ...]
+                        Which timesteps to use for few-shot calibration.
+                        Default: []
+  --calibration-batch-size CALIBRATION_BATCH_SIZE
+                        Batch size for few-shot calibration. Default: 1
   --quantize-weight-zero-point
                         Enable Quantize weight zero-point.
                         Default: Enabled
   --no-quantize-weight-zero-point
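The `main.py` changes below thread a `batch` argument through calibration inference. A minimal sketch of the prompt-batching pattern `run_val_inference` adopts (`iter_prompt_batches` is illustrative, not part of the codebase, and it assumes `NEGATIVE_PROMPTS` holds a single entry, as the replication below implies):

```python
from typing import Iterator, List, Tuple


def iter_prompt_batches(prompts: List[str], neg_prompts: List[str],
                        batch: int) -> Iterator[Tuple[List[str], List[str]]]:
    # Consume prompts in slices of `batch`; the final slice may be shorter.
    for index in range(0, len(prompts), batch):
        curr_prompts = prompts[index:index + batch]
        # Replicate the (single) negative prompt once per prompt in the slice
        yield curr_prompts, neg_prompts * len(curr_prompts)
```

Note that `neg_prompts * len(curr_prompts)` only pairs one negative prompt per prompt when the source list has exactly one element; if `NEGATIVE_PROMPTS` ever grows, the replication factor would need revisiting.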
diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 7ee2d715a..ab086af13 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -29,6 +29,8 @@
 # Each will produce slightly different but valid results
 from torchmetrics.image.fid import FrechetInceptionDistance
+# Do not force static shapes for parameters under torch.compile
+torch._dynamo.config.force_parameter_static_shapes = False
 
 try:
     from cleanfid import fid as cleanfid
 except:
@@ -57,8 +58,6 @@
 from brevitas_examples.common.svd_quant import ErrorCorrectedModule
 from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager
 from brevitas_examples.llm.llm_quant.svd_quant import apply_svd_quant
-from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid
-from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
 from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
 from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
 from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
@@ -166,23 +165,23 @@ def run_val_inference(
         test_latents=None,
         output_type='latent',
         deterministic=True,
-        is_unet=True):
+        is_unet=True,
+        batch=1):
 
     with torch.no_grad():
 
-        if test_latents is None:
-            test_latents = generate_latents(seeds[0], device, dtype, unet_input_shape(resolution))
-
         extra_kwargs = {}
-        if is_unet:
+        if is_unet and test_latents is not None:
             extra_kwargs['latents'] = test_latents
 
         generator = torch.Generator(device).manual_seed(0) if deterministic else None
 
         neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else []
-        for prompt in tqdm(prompts):
-            # We don't want to generate any image, so we return only the latent encoding pre VAE
+        for index in tqdm(range(0, len(prompts), batch)):
+            curr_prompts = prompts[index:index + batch]
+
             pipe(
-                prompt,
-                negative_prompt=neg_prompts[0],
+                curr_prompts,
+                negative_prompt=neg_prompts * len(curr_prompts),
                 output_type=output_type,
                 guidance_scale=guidance_scale,
                 height=resolution,
@@ -252,16 +251,21 @@ def main(args):
     # Extend seeds based on batch_size
     test_seeds = [TEST_SEED] + [TEST_SEED + i for i in range(1, args.batch_size)]
 
-    is_flux_model = 'flux' in args.model.lower()
+    is_flux = 'flux' in args.model.lower()
 
     # Load model from float checkpoint
     print(f"Loading model from {args.model}...")
-    if is_flux_model:
-        extra_kwargs = {}
-    else:
-        extra_kwargs = {'variant': 'fp16' if dtype == torch.float16 else None}
-    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, use_safetensors=True)
-    is_unet = isinstance(pipe, (StableDiffusionXLPipeline, StableDiffusionPipeline))
+    extra_kwargs = {}
+    if not is_flux:
+        variant_dict = {torch.float16: 'fp16', torch.bfloat16: 'bf16'}
+        extra_kwargs = {'variant': variant_dict.get(dtype, None)}
+
+    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype, **extra_kwargs)
+
+    # Detect whether this is a UNet-based pipeline
+    is_unet = hasattr(pipe, 'unet')
+    # Detect Stable Diffusion XL pipeline
+    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
 
     if is_unet:
         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
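The next hunk builds a few-shot calibration set by hooking the denoising network and recording its inputs at selected timesteps. A minimal sketch of the capture pattern (the helper name is illustrative; the patch inlines this logic in `main`):

```python
import torch


def register_few_shot_capture(network: torch.nn.Module, timesteps, total_steps):
    """Record (args, kwargs) of forward calls at selected timestep indices."""
    captured, counter = [], [0]

    def calib_hook(module, inp, inp_kwargs):
        if counter[0] in timesteps:
            captured.append((inp, inp_kwargs))
        counter[0] += 1
        if counter[0] == total_steps:
            counter[0] = 0  # reset the step counter for the next prompt

    handle = network.register_forward_pre_hook(calib_hook, with_kwargs=True)
    return captured, handle
```

The captured pairs are later replayed directly as `network(*inp_args, **inp_kwargs)` (see `calibration_step` below), which avoids rerunning the full diffusion pipeline for every calibration pass.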
@@ -296,8 +300,38 @@ def main(args):
             use_negative_prompts=args.use_negative_prompts,
             subfolder='float')
 
-    # Detect Stable Diffusion XL pipeline
-    is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)
+    # Compute a few-shot calibration set
+    few_shot_calibration_prompts = None
+    if len(args.few_shot_calibration) > 0:
+        args.few_shot_calibration = list(map(int, args.few_shot_calibration))
+        pipe.set_progress_bar_config(disable=True)
+        few_shot_calibration_prompts = []
+        counter = [0]
+
+        def calib_hook(module, inp, inp_kwargs):
+            if counter[0] in args.few_shot_calibration:
+                few_shot_calibration_prompts.append((inp, inp_kwargs))
+            counter[0] += 1
+            if counter[0] == args.calibration_steps:
+                counter[0] = 0
+
+        h = denoising_network.register_forward_pre_hook(calib_hook, with_kwargs=True)
+
+        run_val_inference(
+            pipe,
+            args.resolution,
+            calibration_prompts,
+            test_seeds,
+            args.device,
+            dtype,
+            deterministic=args.deterministic,
+            total_steps=args.calibration_steps,
+            use_negative_prompts=args.use_negative_prompts,
+            test_latents=latents,
+            guidance_scale=args.guidance_scale,
+            is_unet=is_unet,
+            batch=args.calibration_batch_size)
+        h.remove()
 
     # Enable attention slicing
     if args.attention_slicing:
@@ -324,8 +358,31 @@ def main(args):
         if hasattr(m, 'lora_layer') and m.lora_layer is not None:
             raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
+    def calibration_step(force_full_calibration=False, num_prompts=None):
+        if len(args.few_shot_calibration) > 0 and not force_full_calibration:
+            for i, (inp_args, inp_kwargs) in enumerate(few_shot_calibration_prompts):
+                denoising_network(*inp_args, **inp_kwargs)
+                if num_prompts is not None and i + 1 == num_prompts:
+                    break
+        else:
+            prompts_subset = calibration_prompts[:num_prompts] if num_prompts is not None else calibration_prompts
+            run_val_inference(
+                pipe,
+                args.resolution,
+                prompts_subset,
+                test_seeds,
+                args.device,
+                dtype,
+                deterministic=args.deterministic,
+                total_steps=args.calibration_steps,
+                use_negative_prompts=args.use_negative_prompts,
+                test_latents=latents,
+                guidance_scale=args.guidance_scale,
+                is_unet=is_unet)
+
     if args.activation_equalization:
         pipe.set_progress_bar_config(disable=True)
+        print("Applying Activation Equalization")
         with torch.no_grad(), activation_equalization_mode(
                 denoising_network,
                 alpha=args.act_eq_alpha,
@@ -336,23 +393,11 @@ def main(args):
         for m in denoising_network.modules():
             if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
                 m.in_features = m.module.in_features
-        total_steps = args.calibration_steps
-        if args.dry_run or args.load_checkpoint is not None:
-            calibration_prompts = [calibration_prompts[0]]
-            total_steps = 1
-        run_val_inference(
-            pipe,
-            args.resolution,
-            calibration_prompts,
-            test_seeds,
-            args.device,
-            dtype,
-            deterministic=args.deterministic,
-            total_steps=total_steps,
-            use_negative_prompts=args.use_negative_prompts,
-            test_latents=latents,
-            guidance_scale=args.guidance_scale,
-            is_unet=is_unet)
+        act_eq_num_prompts = 1 if args.dry_run or args.load_checkpoint else len(
+            calibration_prompts)
+
+        # SmoothQuant seems to make better use of all the timesteps
+        calibration_step(force_full_calibration=True, num_prompts=act_eq_num_prompts)
 
     # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
     for m in denoising_network.modules():
@@ -516,7 +561,7 @@ def sdpa_zp_stats_type():
             lambda module: module.cross_attention_dim
             if module.is_cross_attention else None}
 
-        if is_flux_model:
+        if is_flux:
             extra_kwargs['qk_norm'] = 'rms_norm'
             extra_kwargs['bias'] = True
             extra_kwargs['processor'] = FusedFluxAttnProcessor2_0()
@@ -559,18 +604,7 @@ def sdpa_zp_stats_type():
 
         pipe.set_progress_bar_config(disable=True)
         with torch.no_grad():
-            run_val_inference(
-
pipe, - args.resolution, [calibration_prompts[0]], - test_seeds, - args.device, - dtype, - total_steps=1, - deterministic=args.deterministic, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale, - is_unet=is_unet) + calibration_step(num_prompts=1) if args.load_checkpoint is not None: with load_quant_model_mode(denoising_network): @@ -580,27 +614,15 @@ def sdpa_zp_stats_type(): torch.load(args.load_checkpoint, map_location='cpu')) print(f"Checkpoint loaded!") pipe = pipe.to(args.device) - elif not args.dry_run: + elif not args.dry_run: if needs_calibration: print("Applying activation calibration") with torch.no_grad(), calibration_mode(denoising_network): - run_val_inference( - pipe, - args.resolution, - calibration_prompts, - test_seeds, - args.device, - dtype, - deterministic=args.deterministic, - total_steps=args.calibration_steps, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale, - is_unet=is_unet) + calibration_step(force_full_calibration=True) if args.svd_quant: - print("Apply SVDQuant...") + print("Applying SVDQuant...") denoising_network = apply_svd_quant( denoising_network, blacklist=None, @@ -617,46 +639,25 @@ def sdpa_zp_stats_type(): for m in denoising_network.modules(): if hasattr(m, 'compile_quant'): m.compile_quant() - if args.gptq: print("Applying GPTQ. It can take several hours") - with torch.no_grad(), gptq_mode(denoising_network, - create_weight_orig=False, - use_quant_activations=False, - return_forward_output=True, + with torch.no_grad(), quant_inference_mode(denoising_network, compile=args.compile_eval): + with gptq_mode( + denoising_network, + create_weight_orig=args + .bias_correction, # if we use bias_corr, we need weight_orig + use_quant_activations=True, + return_forward_output=False, act_order=True) as gptq: - for _ in tqdm(range(gptq.num_layers)): - run_val_inference( - pipe, - args.resolution, - calibration_prompts, - test_seeds, - args.device, - dtype, - deterministic=args.deterministic, - total_steps=args.calibration_steps, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale, - is_unet=is_unet) - gptq.update() - torch.cuda.empty_cache() + for _ in tqdm(range(gptq.num_layers)): + calibration_step(num_prompts=128) + gptq.update() + if args.bias_correction: print("Applying bias correction") - with torch.no_grad(), bias_correction_mode(denoising_network): - run_val_inference( - pipe, - args.resolution, - calibration_prompts, - test_seeds, - args.device, - dtype, - deterministic=args.deterministic, - total_steps=args.calibration_steps, - use_negative_prompts=args.use_negative_prompts, - test_latents=latents, - guidance_scale=args.guidance_scale, - is_unet=is_unet) + with torch.no_grad(), quant_inference_mode(denoising_network, compile=args.compile_eval): + with bias_correction_mode(denoising_network): + calibration_step(force_full_calibration=True) if args.vae_fp16_fix and is_sd_xl: vae_fix_scale = 128 @@ -677,7 +678,7 @@ def sdpa_zp_stats_type(): print(f"Corrected layers in VAE: {corrected_layers}") if args.vae_quantize: - assert not is_flux_model, "Not supported yet" + assert not is_flux, "Not supported yet" print("Quantizing VAE") vae_calibration = collect_vae_calibration( pipe, calibration_prompts, test_seeds, dtype, latents, args) @@ -804,20 +805,15 @@ def sdpa_zp_stats_type(): device = next(iter(denoising_network.parameters())).device dtype = 
next(iter(denoising_network.parameters())).dtype
 
-        # Define tracing input
-        if is_sd_xl:
-            generate_fn = generate_unet_xl_rand_inputs
-            shape = SD_XL_EMBEDDINGS_SHAPE
-        else:
-            generate_fn = generate_unet_21_rand_inputs
-            shape = SD_2_1_EMBEDDINGS_SHAPE
-        trace_inputs = generate_fn(
-            embedding_shape=shape,
-            unet_input_shape=unet_input_shape(args.resolution),
-            device=device,
-            dtype=dtype)
-
         if args.export_target == 'onnx':
+            assert is_sd_xl, "Only SDXL ONNX export is currently supported. If this impacts you, feel free to open an issue"
+
+            trace_inputs = generate_unet_xl_rand_inputs(
+                embedding_shape=SD_XL_EMBEDDINGS_SHAPE,
+                unet_input_shape=unet_input_shape(args.resolution),
+                device=device,
+                dtype=dtype)
+
             if args.weight_quant_granularity == 'per_group':
                 export_manager = BlockQuantProxyLevelManager
             else:
@@ -839,6 +835,9 @@ def sdpa_zp_stats_type():
     # Perform inference
     if args.prompt > 0 and not args.dry_run:
         if args.inference_pipeline == 'mlperf':
+            from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import \
+                compute_mlperf_fid
+
             print(f"Computing accuracy with MLPerf pipeline")
             with torch.no_grad(), quant_inference_mode(denoising_network, compile=args.compile_eval):
-                # Perform a single forward pass before evenutally compiling
+                # Perform a single forward pass before eventually compiling
@@ -904,10 +903,19 @@ def sdpa_zp_stats_type():
                     fid.update(quant_image.unsqueeze(0).to('cuda'), real=False)
 
-            print(f"Torchmetrics FID: {float(fid.compute())}")
+            torchmetrics_fid = float(fid.compute())
+            print(f"Torchmetrics FID: {torchmetrics_fid}")
+            # Dump args to json
+            with open(os.path.join(output_dir, 'args.json'), 'w') as fp:
+                json.dump(vars(args), fp)
+            clean_fid = 0.
             if cleanfid is not None:
                 score = cleanfid.compute_fid(
                     os.path.join(output_dir, 'float'), os.path.join(output_dir, 'quant'))
                 print(f"Cleanfid FID: {float(score)}")
+                clean_fid = float(score)
+            results = {'torchmetrics_fid': torchmetrics_fid, 'clean_fid': clean_fid}
+            with open(os.path.join(output_dir, 'results.json'), 'w') as fp:
+                json.dump(results, fp)
 
         elif args.inference_pipeline == 'reference_images':
             pipe.set_progress_bar_config(disable=True)
@@ -974,12 +982,7 @@ def sdpa_zp_stats_type():
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Stable Diffusion quantization')
-    parser.add_argument(
-        '-m',
-        '--model',
-        type=str,
-        default='/scratch/hf_models/stable-diffusion-2-1-base',
-        help='Path or name of the model.')
+    parser.add_argument('-m', '--model', type=str, default=None, help='Path or name of the model.')
     parser.add_argument(
         '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.')
     parser.add_argument(
@@ -1025,7 +1028,7 @@ def sdpa_zp_stats_type():
     parser.add_argument(
         '--resolution',
         type=int,
-        default=512,
-        help='Resolution along height and width dimension. Default: 512.')
+        default=1024,
+        help='Resolution along height and width dimension. Default: 1024.')
     parser.add_argument('--svd-quant-rank', type=int, default=32, help='SVDQuant rank. Default: 32')
     parser.add_argument(
@@ -1170,7 +1173,7 @@ def sdpa_zp_stats_type():
         '--input-quant-granularity',
         type=str,
         default='per_tensor',
-        choices=['per_tensor', 'per_group'],
+        choices=['per_tensor', 'per_group', 'per_row'],
         help='Granularity for scales/zero-point of inputs. Default: per_tensor.')
     parser.add_argument(
         '--input-scale-type',
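One possible simplification for the argument added in the next hunk: `--few-shot-calibration` is declared without an element type, so `main` converts the values with `list(map(int, ...))`. Declaring the type on the parser would let argparse do the conversion (a suggestion, not part of the patch):

```python
import argparse

parser = argparse.ArgumentParser()
# Element type declared here, so argparse yields ints and the manual
# list(map(int, args.few_shot_calibration)) in main() becomes unnecessary.
parser.add_argument(
    '--few-shot-calibration',
    type=int,
    nargs='*',
    default=[],
    help='Which timesteps to use for few-shot calibration. Default: %(default)s')
```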
@@ -1291,6 +1294,16 @@ def sdpa_zp_stats_type():
         type=str,
         default=None,
         help='Inference pipeline for evaluation. Default: %(default)s')
+    parser.add_argument(
+        '--few-shot-calibration',
+        default=[],
+        nargs='*',
+        help='Which timesteps to use for few-shot calibration. Default: %(default)s')
+    parser.add_argument(
+        '--calibration-batch-size',
+        type=int,
+        default=1,
+        help='Batch size for few-shot calibration. Default: %(default)s')
     add_bool_arg(
         parser,
         'quantize-weight-zero-point',