1 change: 1 addition & 0 deletions .github/workflows/quickcheck.yml
@@ -27,6 +27,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install -e .[test]
python -m pip install pytest-xdist
python -m pip install qwen-vl-utils
- name: Run Quickcheck
run: python -m pytest -q tests/unit_test/models/test_model_quickcheck.py -n auto
223 changes: 222 additions & 1 deletion QEfficient/generation/vlm_generation.py
@@ -23,6 +23,7 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -33,6 +34,7 @@
QEffTextGenerationBase,
TextGeneration,
calculate_latency,
get_compilation_dims,
write_io_files,
)
from QEfficient.utils import LRUCache
@@ -467,7 +469,15 @@ def _prepare_vision_language_prompt(self, text_prompt, image_path):
return text_prompt

def generate(
self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs
self,
images: List[str],
prompts: List[str],
inputs: Optional[Dict[str, torch.Tensor]] = None,
num_frames: Optional[int] = None,
multi_specs: Optional[bool] = None,
generation_len: Optional[int] = None,
stream: bool = True,
**kwargs,
) -> CloudAI100ExecInfo:
"""
Main generation method maintaining API compatibility with VisionLanguageGeneration
@@ -485,6 +495,9 @@ def generate(
Raises:
ValueError: If images and prompts lengths don't match
"""

if num_frames or multi_specs:
return self._generate_multi_frame_specialization(inputs, num_frames, generation_len)
if len(images) != len(prompts):
raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})")

@@ -504,6 +517,214 @@ def generate(
# Regular batching mode
return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs)

def run_prefill_multi_frame_specialization(
self, inputs: Optional[Dict[str, torch.Tensor]], num_frames: Optional[int] = 1, generation_len: Optional[int] = None
Contributor:
Please add a docstring for the method.

):

if not self._qpc_path:
raise TypeError("Please run compile API for language model first!")

self._session.deactivate()
self._vision_session.activate()

if not num_frames:
Contributor:
nit: better to spell this as `if num_frames == 0: num_frames = 1` and add a warning if that is the check. Since the default value is set to 1, we should not need this condition at all unless the value is passed as None somewhere.

num_frames = 1
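For illustration, a minimal sketch of the guard the reviewer is suggesting (hypothetical, not part of this diff; it reuses the module-level `logger` already used elsewhere in this file):

```python
# Hypothetical sketch of the reviewer's suggestion: coerce only an explicit
# falsy value and warn, so the fallback to 1 is visible rather than silent.
if num_frames is None or num_frames == 0:
    logger.warning("num_frames is %r; falling back to 1", num_frames)
    num_frames = 1
```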

batch_size, ctx_len, fbs = get_compilation_dims(self._qpc_path)

pad_token_id = 1

# Skip inputs/outputs
self._session.skip_buffers(
[
x
for x in self._session.input_names + self._session.output_names
if x.startswith("past_") or x.endswith("_RetainedState")
]
)

prefill_seq_len = max(
[x[self._session.binding_index_map["input_ids"]][1][1] for x in self._session.allowed_shapes]
+ [self._session.bindings[self._session.binding_index_map["input_ids"]].dims[1]]
)

input_len = inputs["attention_mask"].sum(1, keepdims=True)
input_ids_length = inputs["input_ids"].shape[1]
num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float
padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len

if generation_len is None:
generation_len = ctx_len - input_len.max()
assert generation_len > 0, "generation length should be greater than zero"

inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"],
(0, padded_len - input_ids_length),
"constant",
pad_token_id,
)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0
)

for k, v in inputs.items():
inputs[k] = np.array(v)

vision_inputs = {
k: v
for k, v in inputs.items()
if k
in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"}
}

vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})

vision_outputs = {}
if vision_inputs:
vision_size = vision_inputs["pixel_values"].shape[0] // num_frames

pixel_values_shape = list(vision_inputs["pixel_values"][:vision_size].shape)

idx = next(
i for i, inner in enumerate(self._vision_session.allowed_shapes) if (2, pixel_values_shape) in inner
)

buffer_set = {
"vision_embeds": np.zeros(
self._vision_session.allowed_shapes[idx][self._vision_session.binding_index_map["vision_embeds"]][
1
],
dtype=np.float16,
),
"image_grid_thw": np.zeros(
self._vision_session.allowed_shapes[idx][self._vision_session.binding_index_map["image_grid_thw"]][
1
],
dtype=np.int64,
),
}
if "deepstack_features" in self._vision_session.binding_index_map:
buffer_set["deepstack_features"] = np.zeros(
self._vision_session.allowed_shapes[idx][
self._vision_session.binding_index_map["deepstack_features"]
][1],
dtype=np.float16,
)

self._vision_session.set_buffers(buffer_set)

chunk_inputs = vision_inputs.copy()

for i in range(num_frames):
chunk_inputs["pixel_values"] = vision_inputs["pixel_values"][i * vision_size : (i + 1) * vision_size]
chunk_outputs = self._vision_session.run(chunk_inputs)
if i == 0:
vision_outputs = chunk_outputs
else:
vision_outputs["vision_embeds"] = np.concatenate(
(vision_outputs["vision_embeds"], chunk_outputs["vision_embeds"]), axis=1
)

vision_outputs["vision_embeds"] = np.pad(
vision_outputs["vision_embeds"],
pad_width=(
(0, 0),
(0, self._session.allowed_shapes[0][1][1][1] - vision_outputs["vision_embeds"].shape[-2]),
(0, 0),
), # pad axis=1 only
mode="constant",
constant_values=0,
)
if "deepstack_features" in vision_outputs:
vision_outputs["deepstack_features"] = np.pad(
vision_outputs["deepstack_features"],
pad_width=(
(0, 0),
(0, 0),
(0, self._session.allowed_shapes[0][1][1][1] - vision_outputs["deepstack_features"].shape[-2]),
(0, 0),
), # pad axis=1 only
mode="constant",
constant_values=0,
)

lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
lang_inputs.pop("attention_mask")

if self._vision_qpc_path:
self._vision_session.deactivate()

self._session.activate()

self._session.set_buffers(vision_outputs)
logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}")
self._vision_processed = True
self._vision_outputs = vision_outputs

# Calculate generation_len consistent with ctx_len
max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max()
generation_len = self._fetch_generation_len(generation_len, max_gen_len)

# Execute chunked prefill
outputs = self._execute_chunked_prefill(lang_inputs, num_chunks)

self._session.skip_buffers(vision_outputs)

# Prepare position_ids for decode phase (next position after prefill)
position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1

return outputs, position_ids_decode, generation_len
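For reference, a small self-contained sketch of the chunking and padding arithmetic used by the prefill above, with illustrative numbers only (the real values come from the compiled QPC):

```python
# Illustrative only: ceil-divide without floats, then pad input_ids up to a
# multiple of the prefill sequence length, as done in the method above.
prefill_seq_len = 128
input_ids_length = 300
num_chunks = -(input_ids_length // -prefill_seq_len)  # ceil(300 / 128) == 3
padded_len = num_chunks * prefill_seq_len              # 384
right_padding = padded_len - input_ids_length          # 84 pad tokens appended
```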

def _generate_multi_frame_specialization(
self,
inputs: Optional[Dict[str, torch.Tensor]],
num_frames: Optional[int] = 1,
generation_len: Optional[int] = None,
stream: Optional[bool] = None,
):

exec_batch_size = self.batch_size
max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len)
self.initialize_decode_inputs(
num_prompts=1, execution_batch_size=exec_batch_size, max_gen_length=max_gen_length
)

if self.is_qwen_vl:
self.decode_pos_ids = np.zeros((4, exec_batch_size, 1), np.int64)

# Prefill using the multi-frame specialization prefill on the already-processed inputs
start = perf_counter()
outputs, position_ids, generation_len_final = self.run_prefill_multi_frame_specialization(
inputs, num_frames, generation_len
)
self.update_decode_input(outputs, position_ids, generation_len_final)

# Prepare decode
decode_inputs = self.prepare_decode_inputs()

# Decode loop
loop_start = perf_counter()
num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None)
end = perf_counter()

# Decode generated texts
generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True)

# Latency metrics
total_decode_tokens = num_token
prefill_time, decode_perf, total_perf, total_time = calculate_latency(
total_decode_tokens, loop_start, start, end
)
perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time)

return CloudAI100ExecInfo(
batch_size=self.batch_size,
generated_texts=generated_texts,
generated_ids=self.generated_ids,
perf_metrics=perf_metrics,
)

def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs):
"""Handle regular batching for vision-language generation without creating a second language session"""
batch_results = []
7 changes: 6 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -1603,6 +1603,8 @@ def generate(
generation_len: Optional[int] = None,
image_height: Optional[int] = None,
image_width: Optional[int] = None,
multi_specs: Optional[bool] = None,
num_frames: Optional[int] = None,
**kwargs,
) -> Union[torch.Tensor, np.ndarray]:
"""
@@ -1651,7 +1653,7 @@
self._write_io_dir = os.path.join(os.path.dirname(self.onnx_path[1]), "io_dir") if write_io else None

# Use VisionLanguageGeneration for image-prompt pairs
if (processor and images) or (tokenizer and prompts):
if (processor and images) or (tokenizer and prompts) or multi_specs or num_frames:
# Create VisionLanguageGeneration instance
batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path)
vlm_gen = VisionLanguageGeneration(
@@ -1673,6 +1675,9 @@

# Call generate method
return vlm_gen.generate(
inputs=inputs,
num_frames=num_frames,
multi_specs=multi_specs,
images=images,
prompts=prompts,
generation_len=generation_len,
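For context, a hypothetical end-to-end call into the path added by this PR. The checkpoint name, the processor call, and the assumption that the model-level `generate` accepts a preprocessed `inputs` dict are illustrative only and not verified against the repository:

```python
# Hypothetical usage sketch of the new multi-frame specialization path.
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # assumed checkpoint
inputs = processor(text=prompt, videos=video_frames, return_tensors="pt")  # dict of tensors

exec_info = model.generate(              # `model` is an already-compiled QEfficient VLM (assumed)
    inputs=inputs,                       # preprocessed tensors, bypassing image/prompt pairing
    num_frames=8,                        # frames run through the vision session in chunks
    multi_specs=True,                    # route generation through the multi-frame specialization
    generation_len=128,
)
print(exec_info.generated_texts)
```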