-
Notifications
You must be signed in to change notification settings - Fork 320
[Examples] QwenOmni Example #2125
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kylesayrs
wants to merge
3
commits into
main
Choose a base branch
from
kylesayrs/qwen_omni
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+178
−0
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| import requests | ||
| import soundfile as sf | ||
| from PIL import Image | ||
| from qwen3_omni_patch import fast_pos_embed_interpolate | ||
| from transformers import ( | ||
| AutoProcessor, | ||
| Qwen3OmniMoeForConditionalGeneration, | ||
| default_data_collator, | ||
| ) | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.quantization import GPTQModifier | ||
| from llmcompressor.transformers.compression.compressed_tensors_utils import ( | ||
| modify_save_pretrained, | ||
| ) | ||
| from llmcompressor.utils import dispatch_for_generation | ||
|
|
||
| # Load model. | ||
| model_id = "Qwen/Qwen3-Omni-30B-A3B-Instruct" | ||
| model = Qwen3OmniMoeForConditionalGeneration.from_pretrained( | ||
| model_id, torch_dtype="auto" | ||
| ) | ||
| processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) | ||
|
|
||
| # Apply patch | ||
| model.thinker.visual.fast_pos_embed_interpolate = fast_pos_embed_interpolate.__get__( | ||
| model.thinker.visual | ||
| ) | ||
|
|
||
| # Oneshot arguments | ||
| BATCH_SIZE = 1 | ||
| NUM_CALIBRATION_SAMPLES = 512 | ||
| MAX_SEQUENCE_LENGTH = 2048 | ||
| DATASET_ID = "flickr30k" | ||
| DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"} | ||
|
|
||
| # Recipe | ||
| recipe = [ | ||
| GPTQModifier( | ||
| targets="Linear", | ||
| scheme="W4A16", | ||
| ignore=[ | ||
| "lm_head", | ||
| r"re:.*visual.*", | ||
| r"re:.*code2wav.*", | ||
| ], | ||
| ), | ||
| ] | ||
|
|
||
|
|
||
| def data_collator(features): | ||
| batch = default_data_collator(features) | ||
| batch["image_grid_thw"] = batch["image_grid_thw"].squeeze(0) | ||
| return batch | ||
|
|
||
|
|
||
| # Perform oneshot | ||
| oneshot( | ||
| model=model.thinker, # base model does not define forward: pass `thinker` instead | ||
| processor=processor, | ||
| dataset=DATASET_ID, | ||
| splits=DATASET_SPLIT, | ||
| recipe=recipe, | ||
| batch_size=BATCH_SIZE, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| data_collator=data_collator, | ||
| ) | ||
|
|
||
| # Confirm generations of the quantized model look sane. | ||
| print("========== SAMPLE GENERATION ==============") | ||
| dispatch_for_generation(model) | ||
| messages = [ | ||
| { | ||
| "role": "user", | ||
| "content": [ | ||
| {"type": "text", "text": "Please describe the animal in this image\n"}, | ||
| {"type": "image"}, | ||
| ], | ||
| }, | ||
| ] | ||
| prompt = processor.apply_chat_template(messages, add_generation_prompt=True) | ||
| image_url = "http://images.cocodataset.org/train2017/000000231895.jpg" | ||
| raw_image = Image.open(requests.get(image_url, stream=True).raw) | ||
|
|
||
| inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(model.device) | ||
| text_ids, audio = model.generate(**inputs, max_new_tokens=100, disable_compile=True) | ||
| text = processor.batch_decode( | ||
| text_ids.sequences[:, inputs["input_ids"].shape[1] :], | ||
| skip_special_tokens=True, | ||
| clean_up_tokenization_spaces=False, | ||
| ) | ||
| print(text) | ||
| if audio is not None: | ||
| sf.write( | ||
| "sample_output.wav", | ||
| audio.reshape(-1).detach().cpu().numpy(), | ||
| samplerate=24000, | ||
| ) | ||
| print("==========================================") | ||
|
|
||
| # Save to disk compressed. | ||
| modify_save_pretrained(model) | ||
| SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| processor.save_pretrained(SAVE_DIR) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| # flake8: noqa | ||
| # ruff: noqa | ||
|
|
||
| import torch | ||
| from compressed_tensors import get_execution_device | ||
|
|
||
|
|
||
| def fast_pos_embed_interpolate(self, grid_thw): | ||
| grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] | ||
|
|
||
| idx_list = [[] for _ in range(4)] | ||
| weight_list = [[] for _ in range(4)] | ||
|
|
||
| for t, h, w in zip(grid_ts, grid_hs, grid_ws): | ||
| h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) | ||
| w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) | ||
|
|
||
| h_idxs_floor = h_idxs.int() | ||
| w_idxs_floor = w_idxs.int() | ||
| h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) | ||
| w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) | ||
|
|
||
| dh = h_idxs - h_idxs_floor | ||
| dw = w_idxs - w_idxs_floor | ||
|
|
||
| base_h = h_idxs_floor * self.num_grid_per_side | ||
| base_h_ceil = h_idxs_ceil * self.num_grid_per_side | ||
|
|
||
| indices = [ | ||
| (base_h[None].T + w_idxs_floor[None]).flatten(), | ||
| (base_h[None].T + w_idxs_ceil[None]).flatten(), | ||
| (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), | ||
| (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), | ||
| ] | ||
|
|
||
| weights = [ | ||
| ((1 - dh)[None].T * (1 - dw)[None]).flatten(), | ||
| ((1 - dh)[None].T * dw[None]).flatten(), | ||
| (dh[None].T * (1 - dw)[None]).flatten(), | ||
| (dh[None].T * dw[None]).flatten(), | ||
| ] | ||
|
|
||
| for i in range(4): | ||
| idx_list[i].extend(indices[i].tolist()) | ||
| weight_list[i].extend(weights[i].tolist()) | ||
|
|
||
| # PATCH: do not rely on `pos_embed.weight`, which may be offloaded | ||
| device = get_execution_device(self.pos_embed) | ||
|
|
||
| idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) | ||
| weight_tensor = torch.tensor( | ||
| weight_list, dtype=self.pos_embed.weight.dtype, device=device | ||
| ) | ||
kylesayrs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] | ||
| patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] | ||
|
|
||
| patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) | ||
|
|
||
| patch_pos_embeds_permute = [] | ||
| merge_size = self.config.spatial_merge_size | ||
| for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): | ||
| pos_embed = pos_embed.repeat(t, 1) | ||
| pos_embed = ( | ||
| pos_embed.view( | ||
| t, h // merge_size, merge_size, w // merge_size, merge_size, -1 | ||
| ) | ||
| .permute(0, 1, 3, 2, 4, 5) | ||
| .flatten(0, 4) | ||
| ) | ||
| patch_pos_embeds_permute.append(pos_embed) | ||
| patch_pos_embeds = torch.cat(patch_pos_embeds_permute) | ||
| return patch_pos_embeds | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.