
Commit 1feed47

Add quantization for PixArt model (openvinotoolkit#2338)
Hybrid quantization does not improve PixArt's performance (CPU ticket: 141083), but it does reduce the memory footprint. We expect a performance improvement once per-token dynamic quantization is enabled (ref ticket: 143590).
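For context, "hybrid" quantization here follows the usual NNCF recipe for diffusion transformers: weights of weighted ops are compressed, and activation quantization is applied only to the remaining operations. Below is a minimal sketch of that flow; `calibration_data` and `weighted_op_names` stand in for the outputs of the helper functions added in this commit, and the notebook's actual cells (in the un-rendered pixart.ipynb diff) may differ:

```python
import nncf
import openvino as ov

core = ov.Core()
transformer = core.read_model("model/transformer_ir.xml")

# Weight-only compression first, then activation quantization for everything
# except the weighted ops collected by the ignored-scope helper.
transformer = nncf.compress_weights(transformer)
quantized_transformer = nncf.quantize(
    model=transformer,
    calibration_dataset=nncf.Dataset(calibration_data),   # transformer inputs recorded during generation
    model_type=nncf.ModelType.TRANSFORMER,
    ignored_scope=nncf.IgnoredScope(names=weighted_op_names),  # keep these ops weight-only
)
ov.save_model(quantized_transformer, "model/transformer_ir_int8.xml")
```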
1 parent 5471cbf commit 1feed47

File tree

4 files changed: +637 -34 lines changed


notebooks/pixart/README.md

+4-2
@@ -9,12 +9,14 @@
 ## Notebook Contents

-This notebook demonstrates how to convert and run the Paint-by-Example model using OpenVINO.
+This notebook demonstrates how to convert and run the PixArt-α model using OpenVINO. An additional part demonstrates how to run optimization with [NNCF](https://github.com/openvinotoolkit/nncf/) to speed up the pipeline.

 Notebook contains the following steps:
 1. Convert PyTorch models to OpenVINO IR format.
 2. Run PixArt-α pipeline with OpenVINO.
-3. Interactive demo.
+3. Optimize the pipeline with [NNCF](https://github.com/openvinotoolkit/nncf/).
+4. Compare results of the FP16 and optimized pipelines.
+5. Interactive demo.

 ## Installation instructions

notebooks/pixart/pixart.ipynb

+406-32
Large diffs are not rendered by default.

notebooks/pixart/pixart_helper.py

+18
@@ -0,0 +1,18 @@
from pathlib import Path

MODEL_DIR = Path("model")

TEXT_ENCODER_PATH = MODEL_DIR / "text_encoder.xml"
TRANSFORMER_OV_PATH = MODEL_DIR / "transformer_ir.xml"
VAE_DECODER_PATH = MODEL_DIR / "vae_decoder.xml"


def get_pipeline_selection_option(optimized_pipe=None):
    # Checkbox for switching between the FP16 and quantized pipelines in the demo;
    # it is enabled (and checked by default) only when an optimized pipeline exists.
    import ipywidgets as widgets

    model_available = optimized_pipe is not None
    use_quantized_models = widgets.Checkbox(
        value=model_available,
        description="Use quantized models",
        disabled=not model_available,
    )
    return use_quantized_models
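A minimal sketch of how this widget helper would be used in the notebook; `ov_pipe` and `optimized_pipe` are assumed names for the FP16 and quantized pipelines built elsewhere in the notebook:

```python
from pixart_helper import get_pipeline_selection_option

# optimized_pipe is None until the quantization section has been run;
# with None the checkbox stays disabled and unchecked.
use_quantized_models = get_pipeline_selection_option(optimized_pipe)
display(use_quantized_models)

pipe = optimized_pipe if use_quantized_models.value else ov_pipe
```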
@@ -0,0 +1,209 @@
from typing import Any, Dict, List
import datasets
import time
import torch

from collections import deque
from tqdm.notebook import tqdm
from transformers import set_seed
import numpy as np
import openvino as ov
import matplotlib.pyplot as plt
from PIL import Image

from pixart_helper import MODEL_DIR, TEXT_ENCODER_PATH, TRANSFORMER_OV_PATH, VAE_DECODER_PATH

set_seed(42)

NUM_INFERENCE_STEPS = 4
INT8_TRANSFORMER_OV_PATH = MODEL_DIR / "transformer_ir_int8.xml"
INT4_TEXT_ENCODER_PATH = MODEL_DIR / "text_encoder_int4.xml"
INT4_VAE_DECODER_PATH = MODEL_DIR / "vae_decoder_int4.xml"

NEGATIVE_PROMPTS = [
    "blurry unreal occluded",
    "low contrast disfigured uncentered mangled",
    "amateur out of frame low quality nsfw",
    "ugly underexposed jpeg artifacts",
    "low saturation disturbing content",
    "overexposed severe distortion",
    "amateur NSFW",
    "ugly mutilated out of frame disfigured",
]


def disable_progress_bar(pipeline, disable=True):
    if not hasattr(pipeline, "_progress_bar_config"):
        pipeline._progress_bar_config = {'disable': disable}
    else:
        pipeline._progress_bar_config['disable'] = disable


class CompiledModelDecorator(ov.CompiledModel):
    # Wraps a compiled model and caches a random subset of its inputs for calibration.
    def __init__(self, compiled_model: ov.CompiledModel, data_cache: List[Any] = None, keep_prob: float = 0.5):
        super().__init__(compiled_model)
        self.data_cache = data_cache if data_cache is not None else []
        self.keep_prob = keep_prob

    def __call__(self, *args, **kwargs):
        if np.random.rand() <= self.keep_prob:
            self.data_cache.append(*args)
        return super().__call__(*args, **kwargs)


def collect_calibration_data(pipe: 'PixArtAlphaPipeline', subset_size: int) -> List[Dict]:
    # Run the pipeline on Conceptual Captions prompts and record the transformer inputs.
    calibration_data = []
    ov_transformer_model = pipe.transformer.transformer
    pipe.transformer.transformer = CompiledModelDecorator(ov_transformer_model, calibration_data, keep_prob=1.0)
    disable_progress_bar(pipe)

    # Each generation contributes NUM_INFERENCE_STEPS calibration samples.
    size = int(np.ceil(subset_size / NUM_INFERENCE_STEPS))
    dataset = datasets.load_dataset("google-research-datasets/conceptual_captions", split="train", trust_remote_code=True, streaming=True)
    dataset = dataset.shuffle(seed=42).take(size)

    # Run inference for data collection
    pbar = tqdm(total=subset_size)
    for batch in dataset:
        caption = batch["caption"]
        if len(caption) > pipe.tokenizer.model_max_length:
            continue
        negative_prompt = np.random.choice(NEGATIVE_PROMPTS)
        pipe(
            prompt=caption,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=0.0,
            generator=torch.Generator('cpu').manual_seed(42),
            negative_prompt=negative_prompt,
            height=256,
            width=256,
        )
        if len(calibration_data) >= subset_size:
            pbar.update(subset_size - pbar.n)
            break
        pbar.update(len(calibration_data) - pbar.n)

    # Restore the original compiled transformer and the progress bar.
    pipe.transformer.transformer = ov_transformer_model
    disable_progress_bar(pipe, disable=False)

    return calibration_data
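# Illustrative usage (not part of this helper): the notebook would collect
# transformer inputs from the FP16 pipeline and wrap them for NNCF, e.g.
#
#     calibration_data = collect_calibration_data(ov_pipe, subset_size=200)
#     calibration_dataset = nncf.Dataset(calibration_data)
#
# where `ov_pipe` is the FP16 OpenVINO pipeline and 200 is an assumed subset size.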
def get_operation_const_op(operation, const_port_id: int):
    # Find the Constant node feeding the given input port, looking through
    # Convert/FakeQuantize/Reshape nodes along the way.
    node = operation.input_value(const_port_id).get_node()
    queue = deque([node])
    constant_node = None
    allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"]

    while len(queue) != 0:
        curr_node = queue.popleft()
        if curr_node.get_type_name() == "Constant":
            constant_node = curr_node
            break
        if len(curr_node.inputs()) == 0:
            break
        if curr_node.get_type_name() in allowed_propagation_types_list:
            queue.append(curr_node.input_value(0).get_node())

    return constant_node


def is_embedding(node) -> bool:
    # A Gather acts as an embedding lookup if its data input is a floating-point constant.
    allowed_types_list = ["f16", "f32", "f64"]
    const_port_id = 0
    input_tensor = node.input_value(const_port_id)
    if input_tensor.get_element_type().get_type_name() in allowed_types_list:
        const_node = get_operation_const_op(node, const_port_id)
        if const_node is not None:
            return True

    return False


def get_quantization_ignored_scope(model):
    # Collect friendly names of weighted ops (MatMuls with constant inputs and
    # embedding Gathers); in the hybrid scheme these keep weight-only compression
    # instead of activation quantization.
    ops_with_weights = []
    for op in model.get_ops():
        if op.get_type_name() == "MatMul":
            constant_node_0 = get_operation_const_op(op, const_port_id=0)
            constant_node_1 = get_operation_const_op(op, const_port_id=1)
            if constant_node_0 or constant_node_1:
                ops_with_weights.append(op.get_friendly_name())
        if op.get_type_name() == "Gather" and is_embedding(op):
            ops_with_weights.append(op.get_friendly_name())

    return ops_with_weights
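# Illustrative usage (not part of this helper): the returned names typically feed
# NNCF's ignored scope so that these weighted ops stay weight-only while the rest
# of the transformer gets activation quantization, e.g.
#
#     ignored_scope = nncf.IgnoredScope(names=get_quantization_ignored_scope(ov_transformer_model))
#
# where `ov_transformer_model` stands for the ov.Model read from TRANSFORMER_OV_PATH.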
def visualize_results(orig_img: Image, optimized_img: Image):
    """
    Helper function for results visualization

    Parameters:
       orig_img (Image.Image): image generated with the FP16 models
       optimized_img (Image.Image): image generated with the quantized models
    Returns:
       fig (matplotlib.pyplot.Figure): matplotlib figure containing the side-by-side comparison
    """
    orig_title = "FP16 pipeline"
    control_title = "Optimized pipeline"
    figsize = (20, 20)
    fig, axs = plt.subplots(1, 2, figsize=figsize, sharex="all", sharey="all")
    list_axes = list(axs.flat)
    for a in list_axes:
        a.set_xticklabels([])
        a.set_yticklabels([])
        a.get_xaxis().set_visible(False)
        a.get_yaxis().set_visible(False)
        a.grid(False)
    list_axes[0].imshow(np.array(orig_img))
    list_axes[1].imshow(np.array(optimized_img))
    list_axes[0].set_title(orig_title, fontsize=15)
    list_axes[1].set_title(control_title, fontsize=15)

    fig.subplots_adjust(wspace=0.01, hspace=0.01)
    fig.tight_layout()
    return fig


def compare_models_size():
    # Report the compression rate (FP16 size / optimized size) for each converted model.
    fp16_model_paths = [TRANSFORMER_OV_PATH, TEXT_ENCODER_PATH, VAE_DECODER_PATH]
    optimized_models = [INT8_TRANSFORMER_OV_PATH, INT4_TEXT_ENCODER_PATH, INT4_VAE_DECODER_PATH]

    for fp16_path, optimized_path in zip(fp16_model_paths, optimized_models):
        if not fp16_path.exists():
            continue
        fp16_ir_model_size = fp16_path.with_suffix(".bin").stat().st_size
        optimized_model_size = optimized_path.with_suffix(".bin").stat().st_size
        print(f"{fp16_path.stem} compression rate: {fp16_ir_model_size / optimized_model_size:.3f}")


def calculate_inference_time(pipeline, validation_data):
    # Return the median end-to-end generation time over the validation captions.
    inference_time = []
    pipeline.set_progress_bar_config(disable=True)

    for caption in validation_data:
        negative_prompt = np.random.choice(NEGATIVE_PROMPTS)
        start = time.perf_counter()
        pipeline(
            caption,
            negative_prompt=negative_prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=0.0,
            generator=torch.Generator('cpu').manual_seed(42),
        )
        end = time.perf_counter()
        delta = end - start
        inference_time.append(delta)

    pipeline.set_progress_bar_config(disable=False)
    return np.median(inference_time)


def compare_perf(ov_pipe, optimized_pipe, validation_size=3):
    # Compare the median latency of the FP16 and optimized pipelines on a few captions.
    validation_dataset = datasets.load_dataset("google-research-datasets/conceptual_captions", split="train", streaming=True, trust_remote_code=True)
    validation_dataset = validation_dataset.take(validation_size)
    validation_data = [batch["caption"] for batch in validation_dataset]

    fp_latency = calculate_inference_time(ov_pipe, validation_data)
    print(f"FP16 pipeline: {fp_latency:.3f} seconds")
    opt_latency = calculate_inference_time(optimized_pipe, validation_data)
    print(f"Optimized pipeline: {opt_latency:.3f} seconds")
    print(f"Performance speed-up: {fp_latency / opt_latency:.3f}")
