-
Notifications
You must be signed in to change notification settings - Fork 77
Adding multi_specializations_frames #909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
387e6fa
Adding multi_specialization
55375aa
Adding qwen3vl multi specialization
eb141c0
Adding qwen3_vl_moe_multispecs
cfc9adf
Renaming Folder Name
d7c4a82
Comments Addressed
ddb3934
Minor fix
6113355
Minor Fixes
9da2c79
qwen3vl_moe_changes
c90cacf
Adding quickcheck
d50cf6d
Adding qwen-vl-utils in project.toml
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast | ||
|
|
||
| from QEfficient.generation.cloud_infer import QAICInferenceSession | ||
|
|
@@ -33,6 +34,7 @@ | |
| QEffTextGenerationBase, | ||
| TextGeneration, | ||
| calculate_latency, | ||
| get_compilation_dims, | ||
| write_io_files, | ||
| ) | ||
| from QEfficient.utils import LRUCache | ||
|
|
@@ -467,7 +469,15 @@ def _prepare_vision_language_prompt(self, text_prompt, image_path): | |
| return text_prompt | ||
|
|
||
| def generate( | ||
| self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs | ||
| self, | ||
| images: List[str], | ||
| prompts: List[str], | ||
| inputs: torch.Tensor = None, | ||
| num_frames: Optional[int] = None, | ||
| multi_specs: Optional[bool] = None, | ||
| generation_len: Optional[int] = None, | ||
| stream: bool = True, | ||
| **kwargs, | ||
| ) -> CloudAI100ExecInfo: | ||
| """ | ||
| Main generation method maintaining API compatibility with VisionLanguageGeneration | ||
|
|
@@ -485,6 +495,9 @@ def generate( | |
| Raises: | ||
| ValueError: If images and prompts lengths don't match | ||
| """ | ||
|
|
||
| if num_frames or multi_specs: | ||
| return self._generate_multi_frame_specialization(inputs, num_frames, generation_len) | ||
| if len(images) != len(prompts): | ||
| raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") | ||
|
|
||
|
|
@@ -504,6 +517,214 @@ def generate( | |
| # Regular batching mode | ||
| return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) | ||
|
|
||
| def run_prefill_multi_frame_specialization( | ||
| self, inputs: Optional[torch.Tensor], num_frames: Optional[int] = 1, generation_len: int = None | ||
| ): | ||
|
|
||
| if not self._qpc_path: | ||
| raise TypeError("Please run compile API for language model first!") | ||
|
|
||
| self._session.deactivate() | ||
| self._vision_session.activate() | ||
|
|
||
| if not num_frames: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: better to specific it as if num_frames ==0: num_frames = 1 and add a warning if that was the check. Since default value is set to 1 ideally we need not need this condition unless somewhere we are passing the value as none. |
||
| num_frames = 1 | ||
|
|
||
| batch_size, ctx_len, fbs = get_compilation_dims(self._qpc_path) | ||
|
|
||
| pad_token_id = 1 | ||
|
|
||
| # Skip inputs/outputs | ||
| self._session.skip_buffers( | ||
| [ | ||
| x | ||
| for x in self._session.input_names + self._session.output_names | ||
| if x.startswith("past_") or x.endswith("_RetainedState") | ||
| ] | ||
| ) | ||
|
|
||
| prefill_seq_len = max( | ||
| [x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes] | ||
| + [self._session.bindings[self._session.binding_index_map["input_ids"]].dims[1]] | ||
| ) | ||
|
|
||
| input_len = inputs["attention_mask"].sum(1, keepdims=True) | ||
| input_ids_length = inputs["input_ids"].shape[1] | ||
| num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float | ||
| padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len | ||
|
|
||
| if generation_len is None: | ||
| generation_len = ctx_len - input_len.max() | ||
| assert generation_len > 0, "generation length should be greater than zero" | ||
|
|
||
| inputs["input_ids"] = torch.nn.functional.pad( | ||
| inputs["input_ids"], | ||
| (0, padded_len - input_ids_length), | ||
| "constant", | ||
| pad_token_id, | ||
| ) | ||
| inputs["attention_mask"] = torch.nn.functional.pad( | ||
| inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 | ||
| ) | ||
|
|
||
| for k, v in inputs.items(): | ||
| inputs[k] = np.array(v) | ||
|
|
||
| vision_inputs = { | ||
| k: v | ||
| for k, v in inputs.items() | ||
| if k | ||
| in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} | ||
| } | ||
|
|
||
| vision_inputs_fp16 = {"pixel_values", "image_masks"} | ||
| vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs}) | ||
|
|
||
| vision_outputs = {} | ||
| if vision_inputs: | ||
| vision_size = vision_inputs["pixel_values"].shape[0] // num_frames | ||
|
|
||
| pixel_values_shape = list(vision_inputs["pixel_values"][:vision_size].shape) | ||
|
|
||
| idx = next( | ||
| i for i, inner in enumerate(self._vision_session.allowed_shapes) if (2, pixel_values_shape) in inner | ||
| ) | ||
|
|
||
| buffer_set = { | ||
| "vision_embeds": np.zeros( | ||
| self._vision_session.allowed_shapes[idx][self._vision_session.binding_index_map["vision_embeds"]][ | ||
| 1 | ||
| ], | ||
| dtype=np.float16, | ||
| ), | ||
| "image_grid_thw": np.zeros( | ||
| self._vision_session.allowed_shapes[idx][self._vision_session.binding_index_map["image_grid_thw"]][ | ||
| 1 | ||
| ], | ||
| dtype=np.int64, | ||
| ), | ||
| } | ||
| if "deepstack_features" in self._vision_session.binding_index_map: | ||
| buffer_set["deepstack_features"] = np.zeros( | ||
| self._vision_session.allowed_shapes[idx][ | ||
| self._vision_session.binding_index_map["deepstack_features"] | ||
| ][1], | ||
| dtype=np.float16, | ||
| ) | ||
|
|
||
| self._vision_session.set_buffers(buffer_set) | ||
|
|
||
| chunk_inputs = vision_inputs.copy() | ||
|
|
||
| for i in range(num_frames): | ||
| chunk_inputs["pixel_values"] = vision_inputs["pixel_values"][i * vision_size : (i + 1) * vision_size] | ||
| chunk_outputs = self._vision_session.run(chunk_inputs) | ||
| if i == 0: | ||
| vision_outputs = chunk_outputs | ||
| else: | ||
| vision_outputs["vision_embeds"] = np.concatenate( | ||
| (vision_outputs["vision_embeds"], chunk_outputs["vision_embeds"]), axis=1 | ||
| ) | ||
|
|
||
| vision_outputs["vision_embeds"] = np.pad( | ||
| vision_outputs["vision_embeds"], | ||
| pad_width=( | ||
| (0, 0), | ||
| (0, self._session.allowed_shapes[0][1][1][1] - vision_outputs["vision_embeds"].shape[-2]), | ||
| (0, 0), | ||
| ), # pad axis=1 only | ||
| mode="constant", | ||
| constant_values=0, | ||
| ) | ||
| if "deepstack_features" in vision_outputs: | ||
| vision_outputs["deepstack_features"] = np.pad( | ||
| vision_outputs["deepstack_features"], | ||
| pad_width=( | ||
| (0, 0), | ||
| (0, 0), | ||
| (0, self._session.allowed_shapes[0][1][1][1] - vision_outputs["deepstack_features"].shape[-2]), | ||
| (0, 0), | ||
| ), # pad axis=1 only | ||
| mode="constant", | ||
| constant_values=0, | ||
| ) | ||
|
|
||
| lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} | ||
| lang_inputs.pop("attention_mask") | ||
|
|
||
| if self._vision_qpc_path: | ||
| self._vision_session.deactivate() | ||
|
|
||
| self._session.activate() | ||
|
|
||
| self._session.set_buffers(vision_outputs) | ||
| logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") | ||
| self._vision_processed = True | ||
| self._vision_outputs = vision_outputs | ||
|
|
||
| # Calculate generation_len consistent with ctx_len | ||
| max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() | ||
| generation_len = self._fetch_generation_len(generation_len, max_gen_len) | ||
|
|
||
| # Execute chunked prefill | ||
| outputs = self._execute_chunked_prefill(lang_inputs, num_chunks) | ||
|
|
||
| self._session.skip_buffers(vision_outputs) | ||
|
|
||
| # Prepare position_ids for decode phase (next position after prefill) | ||
| position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 | ||
|
|
||
| return outputs, position_ids_decode, generation_len | ||
|
|
||
| def _generate_multi_frame_specialization( | ||
| self, | ||
| inputs: Optional[torch.Tensor], | ||
| num_frames: Optional[int] = 1, | ||
| generation_len: int = None, | ||
| stream: List[str] = None, | ||
| ): | ||
|
|
||
| exec_batch_size = self.batch_size | ||
| max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) | ||
| self.initialize_decode_inputs( | ||
| num_prompts=1, execution_batch_size=exec_batch_size, max_gen_length=max_gen_length | ||
| ) | ||
|
|
||
| if self.is_qwen_vl: | ||
| self.decode_pos_ids = np.zeros((4, exec_batch_size, 1), np.int64) | ||
|
|
||
| # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) | ||
| start = perf_counter() | ||
| outputs, position_ids, generation_len_final = self.run_prefill_multi_frame_specialization( | ||
| inputs, num_frames, generation_len | ||
| ) | ||
| self.update_decode_input(outputs, position_ids, generation_len_final) | ||
|
|
||
| # Prepare decode | ||
| decode_inputs = self.prepare_decode_inputs() | ||
|
|
||
| # Decode loop | ||
| loop_start = perf_counter() | ||
| num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) | ||
| end = perf_counter() | ||
|
|
||
| # Decode generated texts | ||
| generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) | ||
|
|
||
| # Latency metrics | ||
| total_decode_tokens = num_token | ||
| prefill_time, decode_perf, total_perf, total_time = calculate_latency( | ||
| total_decode_tokens, loop_start, start, end | ||
| ) | ||
| perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) | ||
|
|
||
| return CloudAI100ExecInfo( | ||
| batch_size=self.batch_size, | ||
| generated_texts=generated_texts, | ||
| generated_ids=self.generated_ids, | ||
| perf_metrics=perf_metrics, | ||
| ) | ||
|
|
||
| def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): | ||
| """Handle regular batching for vision-language generation without creating a second language session""" | ||
| batch_results = [] | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please add a doc string for the method