Description: example inference script for the GRIT-20-Qwen2.5-VL-3B model
(grounded reasoning with bounding-box output), as shared in a GRIT repository
issue discussion.
import re
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info # From the GRIT repo
# Load model and processor.
# device_map={"": 0} pins the entire model to GPU 0; flash_attention_2
# requires a compatible GPU and the flash-attn package to be installed.
model_id = "yfan1997/GRIT-20-Qwen2.5-VL-3B"  # or your local checkpoint path
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype="bfloat16",
    device_map={"": 0},
    attn_implementation="flash_attention_2",
)
model.eval()  # .eval() returns the model, so this matches the chained form
processor = AutoProcessor.from_pretrained(model_id)
# Prepare input: the image to ground on and the question to ask about it.
image_path = "path/to/your/image.jpg"  # TODO: point at a real image file
query = "Ask a question here."  # TODO: replace with your actual question
# Format the prompt with the GRIT "grounded reasoning" thinking structure.
# NOTE(review): the special angle-bracket tags (<think>, </think>, <rethink>,
# </rethink>, <answer>) appear to have been stripped when this snippet was
# rendered as HTML ("think between and while ...", "answer the question
# after ."); restored below from the GRIT repo's prompt template — confirm
# against the official repository before relying on the exact wording.
prompt_suffix = (
    " First, think between <think> and </think> while output necessary "
    "coordinates needed to answer the question in JSON with key 'bbox_2d'. "
    "Then, based on the thinking contents and coordinates, rethink between "
    "<rethink> and </rethink> and then answer the question after <answer>.\n"
)
# Create the conversation in the Qwen-VL chat format: a single user turn
# carrying the image plus the formatted question text.
user_content = [
    {"type": "image", "image": image_path},
    {"type": "text", "text": f"Question: {query}{prompt_suffix}"},
]
messages = [{"role": "user", "content": user_content}]
# Render the conversation through the model's chat template, leaving it
# open for the assistant's reply (add_generation_prompt=True).
chat_text = processor.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
# Process inputs: pull the vision payloads out of the messages, then
# tokenize text + vision together and move the batch to the model's device.
img_inputs, vid_inputs = process_vision_info(messages)
batch = processor(
    text=[chat_text],
    images=img_inputs,
    videos=vid_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = batch.to(model.device)
# Run inference with deterministic greedy decoding.
# NOTE(review): the original mutated the shared model.generation_config in
# place and set sampling knobs (temperature=0.001, top_k=1, top_p=0.0) that
# are ignored unless do_sample=True — with sampling off, decoding is greedy
# either way. Passing the overrides directly to generate() keeps the same
# greedy behavior without clobbering the model's stored config.
with torch.inference_mode():
    gen_ids = model.generate(
        **inputs,
        max_new_tokens=1024,  # cap on newly generated tokens
        do_sample=False,      # explicit greedy decode (what top_k=1 emulated)
    )
# Decode only the newly generated portion by slicing off the prompt tokens.
prompt_len = inputs.input_ids.shape[1]
decoded = processor.batch_decode(
    gen_ids[:, prompt_len:],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
output = decoded[0]
# Parse integer bounding boxes ("x1, y1, x2, y2") out of the model output.
# Capture groups make the original split/try-except unnecessary: every
# findall hit is already a tuple of four digit strings, so the ValueError
# branch was unreachable dead code.
bbox_regex = re.compile(r"\b(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\b")
bboxes = [tuple(map(int, groups)) for groups in bbox_regex.findall(output)]
print(f"Output: {output}")
print(f"Detected bounding boxes: {bboxes}")