Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2261,11 +2261,14 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
layer_names = []
if layer_names is None:
layer_names = []
if self.low_gpu_mem_usage or (
len(block_names) == 1
and len(layer_names) == 0
and not self.has_qlayer_outside_block
and (last_cache_name is None or last_cache_name in block_names)
if not os.environ.get("AR_CALIB_FORCE_CUDA", False) and (
self.low_gpu_mem_usage
or (
len(block_names) == 1
and len(layer_names) == 0
and not self.has_qlayer_outside_block
and (last_cache_name is None or last_cache_name in block_names)
)
):
# low_gpu_mem_usage or calibrate only the embedding layer, which is also very fast on CPU
all_inputs = self.cache_inter_data(block_names, nsamples, layer_names=[], last_cache_name=last_cache_name)
Expand Down
103 changes: 100 additions & 3 deletions auto_round/compressors/mllm/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"""

import os
import sys
from datetime import datetime, timedelta

import torch
Expand Down Expand Up @@ -138,7 +139,7 @@ def _process_v2(self, messages, image):
conversation[-1]["content"].append({"image": image, "type": "image"})
else:
conversation.append({"role": content["role"], "content": content["content"]})
if hasattr(self.processor, "chat_template"):
if hasattr(self.processor, "chat_template") and self.processor.chat_template is not None:
text = self.processor.apply_chat_template(
conversation, add_generation_prompt=True, tokenize=False, return_dict=False
)
Expand All @@ -165,7 +166,7 @@ def get_input(
max_length=None,
truncation=False,
truncation_strategy="text",
**kwargs
**kwargs,
):

if isinstance(text, list):
Expand Down Expand Up @@ -197,6 +198,102 @@ def squeeze_result(ret):
return ret


@register_processor("longcat_next")
class LongCatNextProcessor(BasicProcessor):
    """Processor for the meituan-longcat/LongCat-Next family of multimodal models.

    The model accepts text, image, and audio inputs; images are referenced in
    the conversation text as ``<longcat_img_start>URL<longcat_img_end>`` spans.
    The HuggingFace ``processor`` for this model returns a tuple of
    ``(text_inputs, visual_inputs, audio_inputs)`` rather than a single dict,
    so :meth:`get_input` flattens that tuple into one dict that can be fed
    directly to ``model.forward()``.
    """

    IMAGE_TOKEN = "<image>"
    LONGCAT_IMG_START = "<longcat_img_start>"
    LONGCAT_IMG_END = "<longcat_img_end>"

    def post_init(self, model, tokenizer, processor=None, image_processor=None, use_rtn=False, **kwargs):
        """Attach the model/tokenizer/processor handles and build generation configs.

        Raises:
            AssertionError: if ``tokenizer`` or ``processor`` is ``None``.
        """
        assert tokenizer is not None, "tokenizer should not be None"
        assert processor is not None, "processor should not be None"
        self.model = model
        self.tokenizer = tokenizer
        # Enable the mistral-regex workaround for this tokenizer family.
        self.tokenizer.fix_mistral_regex = True
        self.processor = processor
        self.image_processor = image_processor if image_processor is not None else self.default_image_processor
        self.use_rtn = use_rtn
        self.check_image_processor()

        # The model code lives on the hub (trust_remote_code), so the
        # generation-config classes must be resolved from the module that
        # actually defines the model rather than imported statically.
        hub_module = sys.modules[self.model.__module__]
        generation_config_cls = hub_module.GenerationConfig
        generation_status_cls = hub_module.LongcatNextForCausalLMGenerationStatus

        self.visual_generation_config = generation_config_cls(**self.model.generation_config.visual_generation_config)
        self.audio_generation_config = generation_config_cls(**self.model.generation_config.audio_generation_config)
        self.multimodal_generation_status = generation_status_cls(
            self.visual_generation_config, self.audio_generation_config
        )

    def get_input(self, text, images, squeeze=True, max_length=None, truncation=False, **kwargs):
        """Build a flat, ``model.forward()``-ready dict from text (and images).

        ``text`` may be a conversation (list of ``{"role", "content"}`` dicts)
        or a plain string. Returns the tokenized text inputs plus optional
        ``visual_inputs``/``audio_inputs`` and the generation-status objects
        prepared in :meth:`post_init`.
        """
        if isinstance(text, list):
            # Conversation format: swap the generic <image> token for the
            # LongCat image tags wrapping the image URL, then render the
            # chat template to a single prompt string.
            rendered = []
            for turn in text:
                body = turn["content"]
                if images is not None and self.IMAGE_TOKEN in body:
                    body = body.replace(
                        self.IMAGE_TOKEN,
                        f"{self.LONGCAT_IMG_START}{images}{self.LONGCAT_IMG_END}",
                    )
                rendered.append({"role": turn["role"], "content": body})
            text_input = self.tokenizer.apply_chat_template(rendered, tokenize=False, add_generation_prompt=True)
        elif max_length is not None:
            # Plain string with a token budget: round-trip through the
            # tokenizer to truncate by token count.
            text_input = self.tokenizer.decode(self.tokenizer(text).input_ids[:max_length])
        else:
            text_input = text

        # The LongCat processor returns (text_inputs, visual_inputs, audio_inputs).
        text_inputs, visual_inputs, audio_inputs = self.processor(text=text_input, return_tensors="pt")

        ret = {key: text_inputs[key] for key in text_inputs}
        if visual_inputs is not None:
            ret["visual_inputs"] = visual_inputs
        if audio_inputs is not None:
            ret["audio_inputs"] = audio_inputs
        ret["multimodal_generation_status"] = self.multimodal_generation_status
        ret["visual_generation_config"] = self.visual_generation_config
        return ret

    @staticmethod
    def data_collator(batch):
        """Collate per-sample dicts: stack tensors when shapes allow, else keep lists."""
        if len(batch) == 1:
            return batch[0]

        collated = {}
        for key in batch[0].keys():
            column = [sample[key] for sample in batch]
            if not isinstance(column[0], torch.Tensor):
                collated[key] = column
                continue
            try:
                collated[key] = torch.stack(column)
            except (RuntimeError, TypeError):
                # Ragged shapes/dtypes cannot be stacked; fall back to a plain list.
                collated[key] = column
        return collated


@register_processor("qwen2_5_omni")
class Qwen2_5OmniProcessor(HFProcessor):
"""Processor for Qwen2.5-Omni multimodal models.
Expand Down Expand Up @@ -438,7 +535,7 @@ def get_input(
max_length=None,
truncation=False,
truncation_strategy="text",
**kwargs
**kwargs,
):
from mistral_common.protocol.instruct.request import ChatCompletionRequest # pylint: disable=E0401

Expand Down
1 change: 1 addition & 0 deletions auto_round/compressors/mllm/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _register_template(
_register_template("mistral3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"])
_register_template("mistral3_2", default_dataset="liuhaotian/llava", processor=PROCESSORS["mistral3_2"])
_register_template("gemma3", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["hf"])
_register_template("longcat_next", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["longcat_next"])

Comment thread
xin3he marked this conversation as resolved.

def load_template(path: str):
Expand Down
1 change: 1 addition & 0 deletions auto_round/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"AR_DISABLE_DATASET_SUBPROCESS": lambda: os.getenv("AR_DISABLE_DATASET_SUBPROCESS", "0").lower() in ("1", "true"),
"AR_DISABLE_COPY_MTP_WEIGHTS": lambda: os.getenv("AR_DISABLE_COPY_MTP_WEIGHTS", "0").lower()
in ("1", "true", "yes"),
"AR_CALIB_FORCE_CUDA": lambda: os.getenv("AR_CALIB_FORCE_CUDA", "0").lower() in ("1", "true", "yes"),
}


Expand Down
9 changes: 8 additions & 1 deletion auto_round/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
is_quantized_input_module,
)

FIX_MISTRAL_REGEX_MODEL_TYPE_LIST = ["longcat_next"]


def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None:
"""This function is recommended to be used instead of module.weight = None.
Expand Down Expand Up @@ -599,6 +601,7 @@ def mllm_load_model(
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
fix_mistral_regex=True if model_type in FIX_MISTRAL_REGEX_MODEL_TYPE_LIST else False,
**processor_load_kwargs,
)
processor = AutoProcessor.from_pretrained(
Expand Down Expand Up @@ -874,7 +877,8 @@ def _get_llm_block_names(model):
block_names[i].append(target_m[0] + "." + n)
return block_names

def _get_vlm_block_names(model, quant_vision=False):
def _get_vlm_block_names(model, quant_vision=False, ignore_audio=True):
# Since calibration dataset doesn't contain audio data, audio-related blocks will be ignored by default.
if (
hasattr(model, "config")
and hasattr(model.config, "model_type")
Expand All @@ -884,10 +888,13 @@ def _get_vlm_block_names(model, quant_vision=False):
block_names = []
target_modules = []
vision_blocks_tuple = ("vision", "visual", "image", "img")
audio_blocks_tuple = ("audio", "speech", "wav", "waveform")
target_modules = _search_block("", model)

for i, target_m in enumerate(target_modules):
if quant_vision or all(key not in target_m[0].lower() for key in (vision_blocks_tuple)):
if ignore_audio and any(key in target_m[0].lower() for key in audio_blocks_tuple):
continue
block_names.append([])
for n, m in target_m[1].named_children():
block_names[-1].append(target_m[0] + "." + n)
Expand Down
Loading