diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index a8735407e..e755f9944 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -92,6 +92,7 @@ get_lm_head_name, get_module, global_state, + hook_ngram_embeddings_on_cpu, htcore, is_auto_device_mapping, is_debug_mode, @@ -2212,8 +2213,15 @@ def calib(self, nsamples, bs): self.model(*data_new, **kwargs) else: self.model(**data_new, **kwargs) - except NotImplementedError: - pass + except NotImplementedError as error: + error_msg = str(error) + # Raise NotImplementedError to fallback to CUDA device + if "flash_attn::" in error_msg and "CPU" in error_msg: + raise NotImplementedError( + "Could not run 'flash_attn::_flash_attn_varlen_forward' with arguments from the 'CPU' backend." + ) + else: + pass except RuntimeError as error: error_msg = str(error) if "The expanded size of the tensor" in str(error_msg) and "must match the existing size" in error_msg: @@ -2261,6 +2269,9 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l layer_names = [] if layer_names is None: layer_names = [] + + calibrate_on_cpu = False + cannot_calibrate_on_cpu = False if self.low_gpu_mem_usage or ( len(block_names) == 1 and len(layer_names) == 0 @@ -2268,8 +2279,19 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l and (last_cache_name is None or last_cache_name in block_names) ): # low_gpu_mem_usage or calibrate only the embedding layer, which is also very fast on CPU - all_inputs = self.cache_inter_data(block_names, nsamples, layer_names=[], last_cache_name=last_cache_name) - else: + calibrate_on_cpu = True + try: + all_inputs = self.cache_inter_data( + block_names, nsamples, layer_names=[], last_cache_name=last_cache_name + ) + except NotImplementedError as error: + error_msg = str(error) + if "flash_attn::" in error_msg and "CPU" in error_msg: + cannot_calibrate_on_cpu = True # fallback to GPU when flash attention is not supported on CPU + else: + raise error + + if not calibrate_on_cpu or cannot_calibrate_on_cpu: try: if any(p.device.type == "meta" for p in self.model.parameters()): materialize_model_(self.model) @@ -2314,7 +2336,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l "No non-CPU device available in accelerate's reported memory. " "Falling back to CPU caching." ) - + # Keep ngram_embeddings on CPU + has_ngram_embeddings, raw_ngram_embeddings = hook_ngram_embeddings_on_cpu(self.model) new_max_memory = get_balanced_memory( self.model, max_memory=new_max_memory, @@ -2331,8 +2354,9 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l ) try: - self.model = dispatch_model(self.model, device_map=device_map) + if has_ngram_embeddings: + self.model.model.ngram_embeddings = raw_ngram_embeddings except ValueError as e: if "offload_dir" in e.__str__(): logger.warning( @@ -2345,7 +2369,6 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l else: raise else: - self.model = self.model.to(self.device) all_inputs = self.cache_inter_data( @@ -2354,7 +2377,9 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: accelerate.hooks.remove_hook_from_submodules(self.model) - except torch.OutOfMemoryError: + except torch.OutOfMemoryError as e: + if cannot_calibrate_on_cpu: + raise e cuda_error_msg = traceback.format_exc() try: logger.info("switch to cpu to cache block inputs") @@ -2415,13 +2440,16 @@ def cache_inter_data(self, block_names, nsamples, layer_names=None, last_cache_n calib_bs = self.batch_size self.hook_handles = [] self._replace_forward() - self.calib(nsamples, calib_bs) - self._recover_forward() + try: + self.calib(nsamples, calib_bs) + finally: + # Use finally to recover_forward and delattr in case of that + # self.calib raises NotImplementedError, such as: flash_attn on CPU. + self._recover_forward() + for attr in ("last_cache_name", "_cache_target_set", "_cache_seen_targets", "to_cached_layers"): + if hasattr(self, attr): + delattr(self, attr) res = self.inputs - del self.last_cache_name - del self._cache_target_set - del self._cache_seen_targets - del self.to_cached_layers if tmp_dtype is not None: self.model = self.model.to(tmp_dtype) @@ -2613,7 +2641,6 @@ def _recover_forward(self): def _replace_forward(self): """Replaces the forward function.""" - for n, m in self.model.named_modules(): if n in self.to_cached_layers and type(m) not in self.supported_types: ##block m.orig_forward = m.forward diff --git a/auto_round/compressors/mllm/processor.py b/auto_round/compressors/mllm/processor.py index b42962cbb..ffd6ae387 100644 --- a/auto_round/compressors/mllm/processor.py +++ b/auto_round/compressors/mllm/processor.py @@ -29,6 +29,7 @@ """ import os +import sys from datetime import datetime, timedelta import torch @@ -138,7 +139,7 @@ def _process_v2(self, messages, image): conversation[-1]["content"].append({"image": image, "type": "image"}) else: conversation.append({"role": content["role"], "content": content["content"]}) - if hasattr(self.processor, "chat_template"): + if hasattr(self.processor, "chat_template") and self.processor.chat_template is not None: text = self.processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False, return_dict=False ) @@ -165,7 +166,7 @@ def get_input( max_length=None, truncation=False, truncation_strategy="text", - **kwargs + **kwargs, ): if isinstance(text, list): @@ -197,6 +198,105 @@ def squeeze_result(ret): return ret +@register_processor("longcat_next") +class LongCatNextProcessor(BasicProcessor): + """Processor for meituan-longcat/LongCat-Next multimodal models. + + LongCat-Next supports text, image, and audio inputs. Images are referenced + in the conversation text via ``URL`` + tags. The HuggingFace ``processor`` returns a tuple of + ``(text_inputs, visual_inputs, audio_inputs)`` instead of a single dict, + so this class unpacks them into a flat dict suitable for ``model.forward()``. + """ + + IMAGE_TOKEN = "" + LONGCAT_IMG_START = "" + LONGCAT_IMG_END = "" + + def post_init(self, model, tokenizer, processor=None, image_processor=None, use_rtn=False, **kwargs): + assert tokenizer is not None, "tokenizer should not be None" + assert processor is not None, "processor should not be None" + self.model = model + self.tokenizer = tokenizer + self.tokenizer.fix_mistral_regex = True + self.processor = processor + if image_processor is not None: + self.image_processor = image_processor + else: + self.image_processor = self.default_image_processor + self.use_rtn = use_rtn + # LongCat-Next get_input() relies on the HF processor output directly and + # does not use self.image_processor in the current input path. Do not + # enforce image_processor availability here so text-only calibration can + # still proceed when AutoImageProcessor loading is unavailable. + + # build generation_config from model file because the code is on hub. + model_module = sys.modules[self.model.__module__] + GenerationConfig = model_module.GenerationConfig + LongcatNextForCausalLMGenerationStatus = model_module.LongcatNextForCausalLMGenerationStatus + + self.visual_generation_config = GenerationConfig(**self.model.generation_config.visual_generation_config) + self.audio_generation_config = GenerationConfig(**self.model.generation_config.audio_generation_config) + self.multimodal_generation_status = LongcatNextForCausalLMGenerationStatus( + self.visual_generation_config, self.audio_generation_config + ) + + def get_input(self, text, images, squeeze=True, max_length=None, truncation=False, **kwargs): + if isinstance(text, list): + # text is a list of message dicts (conversation format) + messages = [] + for content in text: + msg = {"role": content["role"], "content": content["content"]} + if self.IMAGE_TOKEN in content["content"] and images is not None: + # Replace generic token with LongCat image tags wrapping the URL + msg["content"] = content["content"].replace( + self.IMAGE_TOKEN, + f"{self.LONGCAT_IMG_START}{images}{self.LONGCAT_IMG_END}", + ) + messages.append(msg) + + text_input = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + else: + # Plain string input + if max_length is not None: + text_input = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length]) + else: + text_input = text + + # LongCat processor returns (text_inputs, visual_inputs, audio_inputs) + text_inputs, visual_inputs, audio_inputs = self.processor(text=text_input, return_tensors="pt") + + ret = {} + for key in text_inputs: + ret[key] = text_inputs[key] + if visual_inputs is not None: + ret["visual_inputs"] = visual_inputs + if audio_inputs is not None: + ret["audio_inputs"] = audio_inputs + ret["multimodal_generation_status"] = self.multimodal_generation_status + ret["visual_generation_config"] = self.visual_generation_config + + return ret + + @staticmethod + def data_collator(batch): + if len(batch) == 1: + return batch[0] + + batched_data = {} + for key in batch[0].keys(): + values = [item[key] for item in batch] + if isinstance(values[0], torch.Tensor): + try: + batched_data[key] = torch.stack(values) + except (RuntimeError, TypeError): + batched_data[key] = values + else: + batched_data[key] = values + + return batched_data + + @register_processor("qwen2_5_omni") class Qwen2_5OmniProcessor(HFProcessor): """Processor for Qwen2.5-Omni multimodal models. @@ -438,7 +538,7 @@ def get_input( max_length=None, truncation=False, truncation_strategy="text", - **kwargs + **kwargs, ): from mistral_common.protocol.instruct.request import ChatCompletionRequest # pylint: disable=E0401 diff --git a/auto_round/compressors/mllm/template.py b/auto_round/compressors/mllm/template.py index 541dfcef9..78abc6424 100644 --- a/auto_round/compressors/mllm/template.py +++ b/auto_round/compressors/mllm/template.py @@ -127,6 +127,7 @@ def _register_template( _register_template("mistral3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"]) _register_template("mistral3_2", default_dataset="liuhaotian/llava", processor=PROCESSORS["mistral3_2"]) _register_template("gemma3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"]) +_register_template("longcat_next", default_dataset="liuhaotian/llava", processor=PROCESSORS["longcat_next"]) def load_template(path: str): diff --git a/auto_round/compressors/shard_writer.py b/auto_round/compressors/shard_writer.py index 3e0f63c47..f061a92b9 100644 --- a/auto_round/compressors/shard_writer.py +++ b/auto_round/compressors/shard_writer.py @@ -131,6 +131,27 @@ def _add_tensor(self, name: str, tensor: torch.Tensor): self.current_shard_tensors[name] = tensor self.current_shard_size += t_size + def _handle_tied_weights(self): + """ + Detects tied weights in the current shard and ensures they are only saved once. + This is done by tracking storage pointers of tensors and skipping duplicates. + """ + from collections import defaultdict + + storage_map = set() + filtered_tensors = {} + + for name, tensor in self.current_shard_tensors.items(): + if not isinstance(tensor, torch.Tensor): + filtered_tensors[name] = tensor + continue + + ptr = tensor.untyped_storage().data_ptr() + if ptr not in storage_map: + storage_map.add(ptr) + filtered_tensors[name] = tensor + self.current_shard_tensors = filtered_tensors + def _flush_shard(self): if not self.current_shard_tensors: return @@ -138,6 +159,7 @@ def _flush_shard(self): self.shard_counter += 1 tmp_name = f"model-shard-{self.shard_counter:05d}.{self.shard_suffix}" tmp_path = os.path.join(self.output_dir, tmp_name) + self._handle_tied_weights() if self.use_safetensors: from safetensors.torch import save_file diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index f0aec180a..eb4ddcec7 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -36,6 +36,8 @@ is_quantized_input_module, ) +FIX_MISTRAL_REGEX_MODEL_TYPE_LIST = ["longcat_next"] + def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None: """This function is recommended to be used instead of module.weight = None. @@ -599,6 +601,7 @@ def mllm_load_model( tokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path, trust_remote_code=trust_remote_code, + fix_mistral_regex=True if model_type in FIX_MISTRAL_REGEX_MODEL_TYPE_LIST else False, **processor_load_kwargs, ) processor = AutoProcessor.from_pretrained( @@ -874,7 +877,8 @@ def _get_llm_block_names(model): block_names[i].append(target_m[0] + "." + n) return block_names - def _get_vlm_block_names(model, quant_vision=False): + def _get_vlm_block_names(model, quant_vision=False, ignore_audio=True): + # Since calibration dataset doesn't contain audio data, audio-related blocks will be ignored by default. if ( hasattr(model, "config") and hasattr(model.config, "model_type") @@ -884,10 +888,13 @@ def _get_vlm_block_names(model, quant_vision=False): block_names = [] target_modules = [] vision_blocks_tuple = ("vision", "visual", "image", "img") + audio_blocks_tuple = ("audio", "speech", "wav", "waveform") target_modules = _search_block("", model) for i, target_m in enumerate(target_modules): if quant_vision or all(key not in target_m[0].lower() for key in (vision_blocks_tuple)): + if ignore_audio and any(key in target_m[0].lower() for key in audio_blocks_tuple): + continue block_names.append([]) for n, m in target_m[1].named_children(): block_names[-1].append(target_m[0] + "." + n) @@ -1882,3 +1889,22 @@ def forward(m, hidden_states=None, *positional_inputs, **kwargs): return base_hook(m, hidden_states, *positional_inputs, **kwargs) return forward + + +def hook_ngram_embeddings_on_cpu(model): + has_ngram_embeddings = hasattr(model, "model") and hasattr(model.model, "ngram_embeddings") + if has_ngram_embeddings: + raw_ngram_embeddings = model.model.ngram_embeddings + + def hook_input_output_device_for_cpu_module(module): + from accelerate.hooks import AlignDevicesHook, add_hook_to_module + + hook = AlignDevicesHook( + io_same_device=True, + execution_device="cpu", + ) + + add_hook_to_module(module, hook) + + hook_input_output_device_for_cpu_module(raw_ngram_embeddings) + return has_ngram_embeddings, raw_ngram_embeddings if has_ngram_embeddings else None