
Commit 24e4397

[GGUF] using quant_nontext_module to control whether quant vision model (#1317)
Signed-off-by: n1ck-guo <heng.guo@intel.com>
1 parent cf37d8e commit 24e4397

7 files changed, +105 -15 lines

auto_round/compressors/base.py

Lines changed: 0 additions & 3 deletions
@@ -1198,11 +1198,8 @@ def _immediate_pack(self, name: str):
             model=self.model,
             device=self.device,
             output_dir=self._get_save_folder_name(self.formats[0]),
-            mllm=self.mllm,
             layer_config=self.layer_config,
             tokenizer=self.tokenizer,
-            processor=self.processor if hasattr(self, "processor") else None,
-            image_processor=self.image_processor if hasattr(self, "image_processor") else None,
         )

     @torch.inference_mode()

auto_round/compressors/mllm/compressor.py

Lines changed: 34 additions & 8 deletions
@@ -183,13 +183,18 @@ def __init__(

         self.model = model
         quant_nontext_module = self._check_quant_nontext(layer_config, quant_nontext_module)
-        if quant_nontext_module:
-            from transformers.utils.versions import require_version
-
-            require_version(
-                "pillow",
-                "pillow is required for quantizing non-text modules, please install it with `pip install pillow`",
-            )
+        if quant_nontext_module and iters > 0:
+            import importlib.util
+
+            missing_libs = []
+            for require_lib in ["pillow", "torchvision"]:
+                if importlib.util.find_spec(require_lib) is None:
+                    missing_libs.append(require_lib)
+            if len(missing_libs) > 0:
+                logger.error(
+                    f"{', '.join(missing_libs)} are required for quantizing non-text modules,"
+                    f" please install them with `pip install {' '.join(missing_libs)}`",
+                )
         all_blocks = get_block_names(model, quant_nontext_module)
         self.quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
         if to_quant_block_names is None:
@@ -453,7 +458,12 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
         if self.processor is not None and not hasattr(self.processor, "chat_template"):
             self.processor.chat_template = None
         compressed_model = super().save_quantized(
-            output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs
+            output_dir=output_dir,
+            format=format,
+            inplace=inplace,
+            processor=self.processor,
+            quant_nontext_module=self.quant_nontext_module if hasattr(self, "quant_nontext_module") else False,
+            **kwargs,
         )
         return compressed_model


@@ -467,3 +477,19 @@ def _check_quant_nontext(self, layer_config, quant_nontext_module):
             if vlm_key in layer_name and check_to_quantized(layer_config[layer_name]):
                 return True
         return quant_nontext_module
+
+    def _immediate_pack(self, name: str):
+        if not self.is_immediate_packing:  # pylint: disable=E1101
+            return
+        self.formats[0].immediate_pack(
+            name=name,
+            model=self.model,
+            device=self.device,
+            output_dir=self._get_save_folder_name(self.formats[0]),
+            mllm=self.mllm,
+            layer_config=self.layer_config,
+            tokenizer=self.tokenizer,
+            processor=self.processor if hasattr(self, "processor") else None,
+            image_processor=self.image_processor if hasattr(self, "image_processor") else None,
+            quant_nontext_module=self.quant_nontext_module if hasattr(self, "quant_nontext_module") else False,
+        )
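
The old hard require_version("pillow", ...) guard is replaced by a softer probe that only runs when non-text modules are quantized with tuning enabled (iters > 0): importlib.util.find_spec checks each optional package without importing it, and all missing ones are reported in a single error message. A minimal standalone sketch of that pattern follows; the function name, logger setup, and the pip-name-to-import-name mapping are illustrative additions, not part of the patch (note that find_spec expects the import name, which for the pillow distribution is PIL).

import importlib.util
import logging

logger = logging.getLogger(__name__)  # stand-in for auto_round's logger

# pip distribution name -> importable module name (assumed mapping for this sketch)
_IMPORT_NAMES = {"pillow": "PIL", "torchvision": "torchvision"}


def check_vision_deps(required=("pillow", "torchvision")):
    """Return the optional packages that are missing for non-text (vision) quantization."""
    missing = [pkg for pkg in required if importlib.util.find_spec(_IMPORT_NAMES[pkg]) is None]
    if missing:
        # Mirror the patch: one error listing every missing package and the pip command.
        logger.error(
            "%s are required for quantizing non-text modules, "
            "please install them with `pip install %s`",
            ", ".join(missing),
            " ".join(missing),
        )
    return missing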

auto_round/export/export_to_gguf/convert.py

Lines changed: 7 additions & 1 deletion
@@ -85,7 +85,9 @@ def download_convert_file(redownload=False):
         f.write(response.text)


-def wrapper_model_instance(model_instance, model, layer_config, low_cpu_mem_usage=False, device=None):
+def wrapper_model_instance(
+    model_instance, model, layer_config, low_cpu_mem_usage=False, device=None, quant_nontext_module=False
+):
     if model_instance.model_arch == gguf.MODEL_ARCH.MMPROJ and model_instance.fname_out.is_dir():
         model_instance.fname_out = model_instance.fname_out / "mmproj-model.gguf"
     model_instance.model = model
@@ -96,6 +98,7 @@ def wrapper_model_instance(model_instance, model, layer_config, low_cpu_mem_usag
     model_instance.prepare_tensors = partial(prepare_tensors, model_instance)

     model_instance.device = device
+    model_instance.quant_nontext_module = quant_nontext_module

     return model_instance

@@ -528,6 +531,9 @@ def prepare_tensors(cls):
             elif data_qtype == gguf.GGMLQuantizationType.Q6_K:
                 data_qtype = gguf.GGMLQuantizationType.Q8_0

+            if cls.model_arch == gguf.MODEL_ARCH.MMPROJ and cls.quant_nontext_module is False:
+                data_qtype = gguf.GGMLQuantizationType.F32
+
             from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES

             if data_qtype.name.lower() in GGML_QUANT_SIZES:
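
The new branch in prepare_tensors is the core of the commit on the export side: when the converter is building the mmproj (vision projector) file and quant_nontext_module is False, the tensor falls back to F32, so the vision tower is written unquantized while the text model keeps the requested GGUF qtype. A simplified, self-contained sketch of that selection gate, using small enums as stand-ins for gguf.MODEL_ARCH and gguf.GGMLQuantizationType:

from enum import Enum, auto


class ModelArch(Enum):  # stand-in for gguf.MODEL_ARCH
    TEXT = auto()
    MMPROJ = auto()


class QType(Enum):  # stand-in for gguf.GGMLQuantizationType
    F32 = auto()
    Q4_0 = auto()
    Q8_0 = auto()


def select_qtype(requested: QType, model_arch: ModelArch, quant_nontext_module: bool) -> QType:
    """Pick the storage type for a tensor, mirroring the patched gate."""
    if model_arch is ModelArch.MMPROJ and not quant_nontext_module:
        # Vision/projector tensors stay in full precision when non-text
        # modules are not being quantized.
        return QType.F32
    return requested


assert select_qtype(QType.Q4_0, ModelArch.MMPROJ, quant_nontext_module=False) is QType.F32
assert select_qtype(QType.Q4_0, ModelArch.MMPROJ, quant_nontext_module=True) is QType.Q4_0
assert select_qtype(QType.Q4_0, ModelArch.TEXT, quant_nontext_module=False) is QType.Q4_0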

auto_round/export/export_to_gguf/export.py

Lines changed: 26 additions & 3 deletions
@@ -74,6 +74,7 @@ def create_model_class(
     low_cpu_mem_usage=False,
     model_type=convert_hf_to_gguf.ModelType.TEXT,
     device="cpu",
+    quant_nontext_module: bool = False,
 ):
     tmp_work_dir = model.name_or_path
     os.makedirs(output_dir, exist_ok=True)
@@ -118,7 +119,12 @@ def create_model_class(
         small_first_shard=False,
     )
     model_instance = wrapper_model_instance(
-        model_instance, model=model, layer_config=layer_config, low_cpu_mem_usage=low_cpu_mem_usage, device=device
+        model_instance,
+        model=model,
+        layer_config=layer_config,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+        device=device,
+        quant_nontext_module=quant_nontext_module,
     )
     model_instance = handle_special_model(model_instance, model_architecture)
     return model_instance
@@ -136,6 +142,7 @@ def pack_gguf_layer(
     image_processor=None,
     model_type=convert_hf_to_gguf.ModelType.TEXT,
     device="cpu",
+    quant_nontext_module=False,
 ):
     """Export the model to gguf format."""
     global gguf_model_instance_global
@@ -153,6 +160,7 @@ def pack_gguf_layer(
                 low_cpu_mem_usage=True,
                 model_type=convert_hf_to_gguf.ModelType.TEXT,
                 device=device,
+                quant_nontext_module=quant_nontext_module,
             )
         ]
         if model_type == convert_hf_to_gguf.ModelType.MMPROJ:
@@ -165,6 +173,7 @@ def pack_gguf_layer(
                     low_cpu_mem_usage=True,
                     model_type=convert_hf_to_gguf.ModelType.MMPROJ,
                     device=device,
+                    quant_nontext_module=quant_nontext_module,
                 )
             )

@@ -215,7 +224,14 @@ def pack_gguf_layer(

 @torch.inference_mode()
 def save_quantized_as_gguf(
-    output_dir, model=None, backend="gguf:q4_0", layer_config=None, mllm=False, device="cpu", **kwargs
+    output_dir,
+    model=None,
+    backend="gguf:q4_0",
+    layer_config=None,
+    mllm=False,
+    device="cpu",
+    quant_nontext_module=False,
+    **kwargs,
 ):
     """Export the model to gguf format."""
     st = time.time()
@@ -224,7 +240,13 @@ def save_quantized_as_gguf(
     if "gguf_model_instance_global" not in globals():
         gguf_model_instance_global = [
             create_model_class(
-                output_dir, model, layer_config, backend, model_type=convert_hf_to_gguf.ModelType.TEXT, device=device
+                output_dir,
+                model,
+                layer_config,
+                backend,
+                model_type=convert_hf_to_gguf.ModelType.TEXT,
+                device=device,
+                quant_nontext_module=quant_nontext_module,
             )
         ]
         if mllm:
@@ -236,6 +258,7 @@ def save_quantized_as_gguf(
                 backend,
                 model_type=convert_hf_to_gguf.ModelType.MMPROJ,
                 device=device,
+                quant_nontext_module=quant_nontext_module,
             )
         )
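
Both entry points keep their converter objects in the module-level gguf_model_instance_global list: a TEXT converter is always created, an MMPROJ converter is appended only when a multimodal projector is needed, and quant_nontext_module is now passed into create_model_class so the flag ends up on each converter instance (via wrapper_model_instance above). A reduced sketch of that lazy, create-once plumbing; the class and helper names here are simplified placeholders rather than the real converter API:

from dataclasses import dataclass
from typing import List


@dataclass
class FakeConverter:
    """Placeholder for the wrapped convert_hf_to_gguf model instance."""
    model_type: str
    quant_nontext_module: bool = False


_converters: List[FakeConverter] = []  # plays the role of gguf_model_instance_global


def get_converters(mllm: bool, quant_nontext_module: bool) -> List[FakeConverter]:
    """Create the converter list once and reuse it on later packing/export calls."""
    if not _converters:
        # The text converter is always built; the flag is attached to the instance.
        _converters.append(FakeConverter("TEXT", quant_nontext_module))
        if mllm:
            # Multimodal models also get an MMPROJ converter for the vision tower.
            _converters.append(FakeConverter("MMPROJ", quant_nontext_module))
    return _converters


# The first call builds the converters; later calls reuse them.
print([c.model_type for c in get_converters(mllm=True, quant_nontext_module=False)])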

auto_round/formats.py

Lines changed: 10 additions & 0 deletions
@@ -43,6 +43,8 @@
     SUPPORTED_FORMATS,
     check_to_quantized,
     copy_python_files_from_model_cache,
+    find_matching_blocks,
+    get_block_names,
     get_module,
     logger,
 )
@@ -647,6 +649,10 @@ def check_and_reset_format(self, ar):
         elif ar.bits >= 8 and ar.iters != 0:
             logger.warning_once("`iters=0` is recommended for bits>=8")

+        if getattr(ar, "quant_nontext_module", False):
+            # for gguf export, leave vl model for gguf itself
+            all_blocks = get_block_names(ar.model, False)
+            ar.quant_block_list = find_matching_blocks(ar.model, all_blocks, None)
         return super().check_and_reset_format(ar)

     def pack_layer(
@@ -661,6 +667,7 @@ def pack_layer(
         image_processor=None,
         model_type=ModelType.TEXT,
         device="cpu",
+        quant_nontext_module=False,
     ):
         from auto_round.export.export_to_gguf.export import pack_gguf_layer

@@ -675,6 +682,7 @@ def pack_layer(
             image_processor,
             model_type,
             device,
+            quant_nontext_module,
         )

     def save_quantized(
@@ -826,6 +834,7 @@ def immediate_pack(
         tokenizer=None,
         processor=None,
         image_processor=None,
+        quant_nontext_module: bool = False,
         **kwargs,
     ):
         m = get_module(model, name)
@@ -843,6 +852,7 @@ def immediate_pack(
             image_processor=image_processor,
             model_type=model_type,
             device=device,
+            quant_nontext_module=quant_nontext_module,
         )

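For GGUF export, check_and_reset_format now narrows the quantization scope whenever quant_nontext_module is set: get_block_names(ar.model, False) gathers only the text blocks (the same call the MLLM compressor makes with quant_nontext_module), and find_matching_blocks resets ar.quant_block_list to them, so AutoRound tunes the text decoder and leaves the vision blocks to the GGUF converter itself, which then decides their precision via the prepare_tensors gate above. A toy sketch of that narrowing step, with a stand-in helper instead of auto_round's get_block_names / find_matching_blocks:

from typing import Dict, List


def text_only_blocks(blocks: Dict[str, List[str]], include_nontext: bool) -> List[List[str]]:
    """Stand-in for get_block_names: optionally drop the non-text (vision) block groups."""
    groups = [blocks["text"]]
    if include_nontext:
        groups.append(blocks["vision"])
    return groups


model_blocks = {
    "text": ["model.layers.0", "model.layers.1"],
    "vision": ["visual.blocks.0", "visual.blocks.1"],
}

# With quant_nontext_module set for a GGUF format, the block list is reset to the
# text-only groups; the vision tower is exported by the GGUF converter instead.
quant_block_list = text_only_blocks(model_blocks, include_nontext=False)
print(quant_block_list)  # [['model.layers.0', 'model.layers.1']]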

test/test_cpu/export/test_gguf_format.py

Lines changed: 27 additions & 0 deletions
@@ -201,6 +201,7 @@ def test_vlm_gguf(self):
             iters=0,
             nsamples=8,
             disable_opt_rtn=True,
+            quant_nontext_module=True,
         )
         quantized_model_path = "./saved"
         autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0")
@@ -214,6 +215,32 @@ def test_vlm_gguf(self):
         shutil.rmtree("./saved", ignore_errors=True)
         shutil.rmtree(tiny_model_path, ignore_errors=True)

+    def test_vlm_gguf_wo_quant_nontext_module(self):
+        from ...helpers import save_tiny_model
+
+        model_name = get_model_path("Qwen/Qwen2-VL-2B-Instruct")
+        tiny_model_path = save_tiny_model(model_name, "./tmp/tiny_qwen_vl_model_path", num_layers=3, is_mllm=True)
+        from auto_round import AutoRoundMLLM
+
+        autoround = AutoRoundMLLM(
+            tiny_model_path,
+            iters=0,
+            nsamples=8,
+            disable_opt_rtn=True,
+            quant_nontext_module=False,
+        )
+        quantized_model_path = "./saved"
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0")
+        assert "mmproj-model.gguf" in os.listdir("./saved")
+        for file_name in os.listdir(quantized_model_path):
+            file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2
+            if file_name == "mmproj-model.gguf":
+                assert abs(file_size - 361) < 5.0
+            else:
+                assert abs(file_size - 264) < 5.0
+        shutil.rmtree("./saved", ignore_errors=True)
+        shutil.rmtree(tiny_model_path, ignore_errors=True)
+
     def test_qtype_setting(self):
         # Qwen2.5-0.5B-Instruct no output, token_embed q6_k fallbakc to q8_0 336M
         # Qwen3-0.6B output q6_k, token_embed q4_0 448M
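
The new CPU test pins down the user-visible behavior: with quant_nontext_module=False the exporter still writes mmproj-model.gguf, but its tensors stay unquantized (the F32 fallback above), which is why the expected mmproj file is much larger (~361 MB) than the q4_0 text model (~264 MB). A hedged usage sketch of the same flow outside the test harness; the checkpoint path is a placeholder and the keyword arguments are the ones used in the test:

# Sketch only: assumes auto_round with GGUF export support and a local or
# downloadable vision-language checkpoint (placeholder path below).
from auto_round import AutoRoundMLLM

model_path = "Qwen/Qwen2-VL-2B-Instruct"  # placeholder VLM checkpoint

autoround = AutoRoundMLLM(
    model_path,
    iters=0,                     # no tuning iterations, as in the test
    nsamples=8,
    disable_opt_rtn=True,
    quant_nontext_module=False,  # keep the vision tower unquantized in mmproj-model.gguf
)
# Writes the q4_0 text model plus an mmproj-model.gguf holding the vision tower.
autoround.quantize_and_save(output_dir="./saved", format="gguf:q4_0")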

test/test_cuda/export/test_gguf.py

Lines changed: 1 addition & 0 deletions
@@ -175,6 +175,7 @@ def test_vlm_gguf(self):
             nsamples=32,
             iters=0,
             disable_opt_rtn=True,
+            quant_nontext_module=True,
         )
         quantized_model_path = "./saved"
         autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m")
