Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
get_lm_head_name,
get_module,
global_state,
hook_ngram_embeddings_on_cpu,
htcore,
is_auto_device_mapping,
is_debug_mode,
Expand Down Expand Up @@ -2212,8 +2213,15 @@ def calib(self, nsamples, bs):
self.model(*data_new, **kwargs)
else:
self.model(**data_new, **kwargs)
except NotImplementedError:
pass
except NotImplementedError as error:
error_msg = str(error)
# Raise NotImplementedError to fallback to CUDA device
if "flash_attn::" in error_msg and "CPU" in error_msg:
raise NotImplementedError(
"Could not run 'flash_attn::_flash_attn_varlen_forward' with arguments from the 'CPU' backend."
)
else:
pass
except RuntimeError as error:
error_msg = str(error)
if "The expanded size of the tensor" in str(error_msg) and "must match the existing size" in error_msg:
Expand Down Expand Up @@ -2261,15 +2269,29 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
layer_names = []
if layer_names is None:
layer_names = []

calibrate_on_cpu = False
cannot_calibrate_on_cpu = False
if self.low_gpu_mem_usage or (
len(block_names) == 1
and len(layer_names) == 0
and not self.has_qlayer_outside_block
and (last_cache_name is None or last_cache_name in block_names)
):
# low_gpu_mem_usage or calibrate only the embedding layer, which is also very fast on CPU
all_inputs = self.cache_inter_data(block_names, nsamples, layer_names=[], last_cache_name=last_cache_name)
else:
calibrate_on_cpu = True
try:
all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
)
except NotImplementedError as error:
error_msg = str(error)
if "flash_attn::" in error_msg and "CPU" in error_msg:
cannot_calibrate_on_cpu = True # fallback to GPU when flash attention is not supported on CPU
else:
raise error

if not calibrate_on_cpu or cannot_calibrate_on_cpu:
try:
if any(p.device.type == "meta" for p in self.model.parameters()):
materialize_model_(self.model)
Expand Down Expand Up @@ -2314,7 +2336,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
"No non-CPU device available in accelerate's reported memory. "
"Falling back to CPU caching."
)

# Keep ngram_embeddings on CPU
has_ngram_embeddings, raw_ngram_embeddings = hook_ngram_embeddings_on_cpu(self.model)
new_max_memory = get_balanced_memory(
self.model,
max_memory=new_max_memory,
Expand All @@ -2331,8 +2354,9 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
)

try:

self.model = dispatch_model(self.model, device_map=device_map)
if has_ngram_embeddings:
self.model.model.ngram_embeddings = raw_ngram_embeddings
except ValueError as e:
if "offload_dir" in e.__str__():
logger.warning(
Expand All @@ -2345,7 +2369,6 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
else:
raise
else:

self.model = self.model.to(self.device)

all_inputs = self.cache_inter_data(
Expand All @@ -2354,7 +2377,9 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model)

except torch.OutOfMemoryError:
except torch.OutOfMemoryError as e:
if cannot_calibrate_on_cpu:
raise e
cuda_error_msg = traceback.format_exc()
try:
logger.info("switch to cpu to cache block inputs")
Expand Down Expand Up @@ -2415,13 +2440,16 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n
calib_bs = self.batch_size
self.hook_handles = []
self._replace_forward()
self.calib(nsamples, calib_bs)
self._recover_forward()
try:
self.calib(nsamples, calib_bs)
finally:
# Use finally so that _recover_forward and the delattr cleanup still run
# when self.calib raises NotImplementedError (e.g. flash_attn on CPU).
self._recover_forward()
for attr in ("last_cache_name", "_cache_target_set", "_cache_seen_targets", "to_cached_layers"):
if hasattr(self, attr):
delattr(self, attr)
res = self.inputs
del self.last_cache_name
del self._cache_target_set
del self._cache_seen_targets
del self.to_cached_layers
if tmp_dtype is not None:
self.model = self.model.to(tmp_dtype)

Expand Down Expand Up @@ -2613,7 +2641,6 @@ def _recover_forward(self):

def _replace_forward(self):
"""Replaces the forward function."""

for n, m in self.model.named_modules():
if n in self.to_cached_layers and type(m) not in self.supported_types: ##block
m.orig_forward = m.forward
Expand Down
106 changes: 103 additions & 3 deletions auto_round/compressors/mllm/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"""

import os
import sys
from datetime import datetime, timedelta

import torch
Expand Down Expand Up @@ -138,7 +139,7 @@ def _process_v2(self, messages, image):
conversation[-1]["content"].append({"image": image, "type": "image"})
else:
conversation.append({"role": content["role"], "content": content["content"]})
if hasattr(self.processor, "chat_template"):
if hasattr(self.processor, "chat_template") and self.processor.chat_template is not None:
text = self.processor.apply_chat_template(
conversation, add_generation_prompt=True, tokenize=False, return_dict=False
)
Expand All @@ -165,7 +166,7 @@ def get_input(
max_length=None,
truncation=False,
truncation_strategy="text",
**kwargs
**kwargs,
):

if isinstance(text, list):
Expand Down Expand Up @@ -197,6 +198,105 @@ def squeeze_result(ret):
return ret


@register_processor("longcat_next")
class LongCatNextProcessor(BasicProcessor):
    """Processor for meituan-longcat/LongCat-Next multimodal models.

    LongCat-Next supports text, image, and audio inputs. Images are referenced
    in the conversation text via ``<longcat_img_start>URL<longcat_img_end>``
    tags. The HuggingFace ``processor`` returns a tuple of
    ``(text_inputs, visual_inputs, audio_inputs)`` instead of a single dict,
    so this class unpacks them into a flat dict suitable for ``model.forward()``.
    """

    # Generic placeholder used in calibration conversations; rewritten into the
    # LongCat-specific start/end tags below.
    IMAGE_TOKEN = "<image>"
    LONGCAT_IMG_START = "<longcat_img_start>"
    LONGCAT_IMG_END = "<longcat_img_end>"

    def post_init(self, model, tokenizer, processor=None, image_processor=None, use_rtn=False, **kwargs):
        """Bind model/tokenizer/processor and construct the generation configs.

        The generation helper classes ship with the model's hub (remote) code,
        so they are resolved dynamically from the module the model class was
        loaded from via ``sys.modules``.
        """
        assert tokenizer is not None, "tokenizer should not be None"
        assert processor is not None, "processor should not be None"
        self.model = model
        self.tokenizer = tokenizer
        # Work around a tokenizer regex issue in the LongCat tokenizer.
        self.tokenizer.fix_mistral_regex = True
        self.processor = processor
        self.image_processor = image_processor if image_processor is not None else self.default_image_processor
        self.use_rtn = use_rtn
        # LongCat-Next get_input() relies on the HF processor output directly and
        # does not use self.image_processor in the current input path. Do not
        # enforce image_processor availability here so text-only calibration can
        # still proceed when AutoImageProcessor loading is unavailable.

        # Build generation_config from the model's own module because the code
        # lives on the hub (trust_remote_code).
        hub_module = sys.modules[self.model.__module__]
        generation_config_cls = hub_module.GenerationConfig
        generation_status_cls = hub_module.LongcatNextForCausalLMGenerationStatus

        self.visual_generation_config = generation_config_cls(**self.model.generation_config.visual_generation_config)
        self.audio_generation_config = generation_config_cls(**self.model.generation_config.audio_generation_config)
        self.multimodal_generation_status = generation_status_cls(
            self.visual_generation_config, self.audio_generation_config
        )

    def get_input(self, text, images, squeeze=True, max_length=None, truncation=False, **kwargs):
        """Build a flat, ``model.forward()``-ready dict from the raw sample.

        ``text`` is either a conversation (list of role/content dicts) or a
        plain string; ``images`` is the image reference substituted for the
        generic ``<image>`` token when present.
        """
        if isinstance(text, list):
            # Conversation format: rewrite the generic image token into
            # LongCat's explicit image tags, then render the chat template.
            messages = []
            for turn in text:
                entry = {"role": turn["role"], "content": turn["content"]}
                if self.IMAGE_TOKEN in turn["content"] and images is not None:
                    entry["content"] = turn["content"].replace(
                        self.IMAGE_TOKEN,
                        f"{self.LONGCAT_IMG_START}{images}{self.LONGCAT_IMG_END}",
                    )
                messages.append(entry)
            text_input = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        elif max_length is not None:
            # Plain string input, trimmed to max_length tokens by
            # round-tripping through the tokenizer.
            text_input = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
        else:
            text_input = text

        # LongCat processor returns (text_inputs, visual_inputs, audio_inputs).
        text_inputs, visual_inputs, audio_inputs = self.processor(text=text_input, return_tensors="pt")

        ret = {key: text_inputs[key] for key in text_inputs}
        if visual_inputs is not None:
            ret["visual_inputs"] = visual_inputs
        if audio_inputs is not None:
            ret["audio_inputs"] = audio_inputs
        ret["multimodal_generation_status"] = self.multimodal_generation_status
        ret["visual_generation_config"] = self.visual_generation_config

        return ret

    @staticmethod
    def data_collator(batch):
        """Collate samples; stack tensors when shapes allow, else keep lists."""
        if len(batch) == 1:
            return batch[0]

        collated = {}
        for key in batch[0].keys():
            values = [sample[key] for sample in batch]
            if not isinstance(values[0], torch.Tensor):
                collated[key] = values
                continue
            try:
                collated[key] = torch.stack(values)
            except (RuntimeError, TypeError):
                # Ragged shapes or non-stackable tensors: fall back to a list.
                collated[key] = values

        return collated


@register_processor("qwen2_5_omni")
class Qwen2_5OmniProcessor(HFProcessor):
"""Processor for Qwen2.5-Omni multimodal models.
Expand Down Expand Up @@ -438,7 +538,7 @@ def get_input(
max_length=None,
truncation=False,
truncation_strategy="text",
**kwargs
**kwargs,
):
from mistral_common.protocol.instruct.request import ChatCompletionRequest # pylint: disable=E0401

Expand Down
1 change: 1 addition & 0 deletions auto_round/compressors/mllm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _register_template(
_register_template("mistral3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"])
_register_template("mistral3_2", default_dataset="liuhaotian/llava", processor=PROCESSORS["mistral3_2"])
_register_template("gemma3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"])
_register_template("longcat_next", default_dataset="liuhaotian/llava", processor=PROCESSORS["longcat_next"])


def load_template(path: str):
Expand Down
22 changes: 22 additions & 0 deletions auto_round/compressors/shard_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,35 @@ def _add_tensor(self, name: str, tensor: torch.Tensor):
self.current_shard_tensors[name] = tensor
self.current_shard_size += t_size

def _handle_tied_weights(self):
    """Deduplicate tied weights in the current shard so each is saved once.

    Tied parameters (e.g. an LM head tied to the input embedding) share one
    underlying storage; serializing every alias would duplicate the data on
    disk. Track each tensor's storage pointer and keep only the first tensor
    seen per storage, rebinding ``self.current_shard_tensors`` to the
    filtered dict.

    NOTE(review): keying on ``untyped_storage().data_ptr()`` also collapses
    distinct views into one storage, not just exact ties — confirm that only
    fully tied weights can share storage in a shard here.
    """
    seen_storages = set()
    filtered_tensors = {}

    for name, tensor in self.current_shard_tensors.items():
        # Non-tensor entries cannot alias a storage; keep them unconditionally.
        if not isinstance(tensor, torch.Tensor):
            filtered_tensors[name] = tensor
            continue

        ptr = tensor.untyped_storage().data_ptr()
        if ptr not in seen_storages:
            seen_storages.add(ptr)
            filtered_tensors[name] = tensor
    self.current_shard_tensors = filtered_tensors

def _flush_shard(self):
if not self.current_shard_tensors:
return

self.shard_counter += 1
tmp_name = f"model-shard-{self.shard_counter:05d}.{self.shard_suffix}"
tmp_path = os.path.join(self.output_dir, tmp_name)
self._handle_tied_weights()

if self.use_safetensors:
from safetensors.torch import save_file
Expand Down
28 changes: 27 additions & 1 deletion auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
is_quantized_input_module,
)

FIX_MISTRAL_REGEX_MODEL_TYPE_LIST = ["longcat_next"]


def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None:
"""This function is recommended to be used instead of module.weight = None.
Expand Down Expand Up @@ -599,6 +601,7 @@ def mllm_load_model(
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
fix_mistral_regex=True if model_type in FIX_MISTRAL_REGEX_MODEL_TYPE_LIST else False,
**processor_load_kwargs,
)
processor = AutoProcessor.from_pretrained(
Expand Down Expand Up @@ -874,7 +877,8 @@ def _get_llm_block_names(model):
block_names[i].append(target_m[0] + "." + n)
return block_names

def _get_vlm_block_names(model, quant_vision=False):
def _get_vlm_block_names(model, quant_vision=False, ignore_audio=True):
# Since calibration dataset doesn't contain audio data, audio-related blocks will be ignored by default.
if (
hasattr(model, "config")
and hasattr(model.config, "model_type")
Expand All @@ -884,10 +888,13 @@ def _get_vlm_block_names(model, quant_vision=False):
block_names = []
target_modules = []
vision_blocks_tuple = ("vision", "visual", "image", "img")
audio_blocks_tuple = ("audio", "speech", "wav", "waveform")
target_modules = _search_block("", model)

for i, target_m in enumerate(target_modules):
if quant_vision or all(key not in target_m[0].lower() for key in (vision_blocks_tuple)):
if ignore_audio and any(key in target_m[0].lower() for key in audio_blocks_tuple):
continue
block_names.append([])
for n, m in target_m[1].named_children():
block_names[-1].append(target_m[0] + "." + n)
Expand Down Expand Up @@ -1882,3 +1889,22 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs):
return base_hook(m, hidden_states, *positional_inputs, **kwargs)

return forward


def hook_ngram_embeddings_on_cpu(model):
    """Pin ``model.model.ngram_embeddings`` to CPU execution via an accelerate hook.

    When the submodule exists, attach an ``AlignDevicesHook`` so the n-gram
    embedding table executes on CPU while its outputs are moved back to the
    caller's device (``io_same_device=True``), letting the rest of the model be
    dispatched across accelerators without relocating this table.

    Args:
        model: top-level model; the submodule is looked up at
            ``model.model.ngram_embeddings``.

    Returns:
        tuple: ``(True, submodule)`` when the submodule exists and was hooked,
        otherwise ``(False, None)``.
    """
    # Early return makes the tuple explicit; the original single-line
    # `a, b if cond else None` return was easy to misread and depended on
    # short-circuiting to avoid referencing a branch-local variable.
    if not (hasattr(model, "model") and hasattr(model.model, "ngram_embeddings")):
        return False, None

    raw_ngram_embeddings = model.model.ngram_embeddings

    # Imported lazily so this module does not hard-require accelerate.
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module

    hook = AlignDevicesHook(
        io_same_device=True,
        execution_device="cpu",
    )
    add_hook_to_module(raw_ngram_embeddings, hook)
    return True, raw_ngram_embeddings
Loading