examples/multimodal_vision/qwen3_omni_example.py (new file)
import requests
import soundfile as sf
from PIL import Image
from qwen3_omni_patch import fast_pos_embed_interpolate
from transformers import (
    AutoProcessor,
    Qwen3OmniMoeForConditionalGeneration,
    default_data_collator,
)

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers.compression.compressed_tensors_utils import (
    modify_save_pretrained,
)
from llmcompressor.utils import dispatch_for_generation

# Load model.
model_id = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
    model_id, torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Apply patch: bind the offload-safe position embedding interpolation from
# qwen3_omni_patch.py as a method on the vision tower.
model.thinker.visual.fast_pos_embed_interpolate = fast_pos_embed_interpolate.__get__(
    model.thinker.visual
)

# Oneshot arguments
BATCH_SIZE = 1
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
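# flickr30k provides paired images and captions, so calibration batches exercise the
# vision-conditioned path through the thinker.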

# Recipe: quantize Linear weights to W4A16 with GPTQ, skipping the LM head and the
# vision / code2wav modules.
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "lm_head",
            r"re:.*visual.*",
            r"re:.*code2wav.*",
        ],
    ),
]


def data_collator(features):
    batch = default_data_collator(features)
    # collation adds a leading batch dim; the model expects image_grid_thw of shape
    # (num_images, 3), so drop it
    batch["image_grid_thw"] = batch["image_grid_thw"].squeeze(0)
    return batch


# Perform oneshot
oneshot(
    model=model.thinker,  # base model does not define forward: pass `thinker` instead
    processor=processor,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    batch_size=BATCH_SIZE,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    data_collator=data_collator,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please describe the animal in this image\n"},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image_url = "http://images.cocodataset.org/train2017/000000231895.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device)
text_ids, audio = model.generate(**inputs, max_new_tokens=100, disable_compile=True)
text = processor.batch_decode(
    text_ids.sequences[:, inputs["input_ids"].shape[1] :],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print(text)
if audio is not None:
    sf.write(
        "sample_output.wav",
        audio.reshape(-1).detach().cpu().numpy(),
        samplerate=24000,
    )
print("==========================================")

# Save to disk compressed.
modify_save_pretrained(model)
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
processor.save_pretrained(SAVE_DIR)
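# The saved checkpoint uses the compressed-tensors format; it can be reloaded by
# runtimes that support that format (subject to their Qwen3-Omni support).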
examples/multimodal_vision/qwen3_omni_patch.py (new file)
# flake8: noqa
# ruff: noqa

import torch
from compressed_tensors import get_execution_device
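# Patched replacement for the vision tower's `fast_pos_embed_interpolate`. The upstream
# method relies on `pos_embed.weight` (which may be offloaded during calibration);
# this version queries the module's execution device instead (see the PATCH comment
# below).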


def fast_pos_embed_interpolate(self, grid_thw):
    grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]

    idx_list = [[] for _ in range(4)]
    weight_list = [[] for _ in range(4)]

    # Bilinearly interpolate the fixed num_grid_per_side x num_grid_per_side position
    # embedding grid onto each image's (h, w) patch grid: four corner indices per
    # position, each with its bilinear weight.
    for t, h, w in zip(grid_ts, grid_hs, grid_ws):
        h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)
        w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w)

        h_idxs_floor = h_idxs.int()
        w_idxs_floor = w_idxs.int()
        h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)
        w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1)

        dh = h_idxs - h_idxs_floor
        dw = w_idxs - w_idxs_floor

        base_h = h_idxs_floor * self.num_grid_per_side
        base_h_ceil = h_idxs_ceil * self.num_grid_per_side

        indices = [
            (base_h[None].T + w_idxs_floor[None]).flatten(),
            (base_h[None].T + w_idxs_ceil[None]).flatten(),
            (base_h_ceil[None].T + w_idxs_floor[None]).flatten(),
            (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(),
        ]

        weights = [
            ((1 - dh)[None].T * (1 - dw)[None]).flatten(),
            ((1 - dh)[None].T * dw[None]).flatten(),
            (dh[None].T * (1 - dw)[None]).flatten(),
            (dh[None].T * dw[None]).flatten(),
        ]

        for i in range(4):
            idx_list[i].extend(indices[i].tolist())
            weight_list[i].extend(weights[i].tolist())

    # PATCH: do not rely on `pos_embed.weight`, which may be offloaded
    device = get_execution_device(self.pos_embed)

    idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device)
    weight_tensor = torch.tensor(
        weight_list, dtype=self.pos_embed.weight.dtype, device=device
    )
    pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
    patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]

    patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)])

    # Reorder embeddings into spatial merge windows so they match the token order
    # produced by the vision patch merger.
    patch_pos_embeds_permute = []
    merge_size = self.config.spatial_merge_size
    for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws):
        pos_embed = pos_embed.repeat(t, 1)
        pos_embed = (
            pos_embed.view(
                t, h // merge_size, merge_size, w // merge_size, merge_size, -1
            )
            .permute(0, 1, 3, 2, 4, 5)
            .flatten(0, 4)
        )
        patch_pos_embeds_permute.append(pos_embed)
    patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
    return patch_pos_embeds
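# Usage (mirrors qwen3_omni_example.py): bind this function onto the loaded model's
# vision tower before calibration so it is used in place of the original method:
#
#     model.thinker.visual.fast_pos_embed_interpolate = (
#         fast_pos_embed_interpolate.__get__(model.thinker.visual)
#     )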