update: Update VLMs Data Collator #55

Open · wants to merge 1 commit into main
68 changes: 44 additions & 24 deletions unsloth_zoo/vision_utils.py
```diff
@@ -262,6 +262,7 @@ def __call__(self, examples):
         # The issue is batch = self.processor( forces tensors to be returned and not None.
         texts  = []
         images = []
+        has_images = False

         if self.formatting_func is not None:
             examples = [self.formatting_func(example) for example in examples]
@@ -274,38 +275,57 @@ def __call__(self, examples):
                 add_generation_prompt = False,
             )
             # Dataset with 2 columns messages / images
-            if "images" in example:
-                image = example["images"][0]
-            else:
-                image, video = process_vision_info(messages)
-            texts .append(message)
+            image = None
+            if "images" in example and example["images"]:
+                image = example["images"][0]
+                has_images = True
+            elif isinstance(messages, (list, dict)):
+                try:
+                    image_list, _ = process_vision_info([messages])
+                    if image_list:
+                        image = image_list[0]
+                        has_images = True
+                except:
+                    pass
+            texts.append(message)
             images.append(image)
         pass

         # Tokenize the texts and process the images
-        batch = self.processor(
-            text = texts,
-            images = images,
-            padding = True,
-            # [TODO] Truncating to max_seq_length does NOT work for VLMs
-            # truncation = True,
-            return_tensors = "pt",
-        )
+        if has_images:
+            batch = self.processor(
+                text = texts,
+                images = images,
+                padding = True,
+                # [TODO] Truncating to max_seq_length does NOT work for VLMs
+                # truncation = True,
+                return_tensors = "pt",
+            )
+        else:
+            # Text-only processing - more efficient as it skips image processing
+            batch = self.processor(
+                text = texts,
+                padding = True,
+                # [TODO] Truncating to max_seq_length does NOT work for VLMs
+                # truncation = True,
+                return_tensors = "pt",
+            )
         batch.pop("token_type_ids", None)

         # Pixtral accepts multiple images, so we have to cast it individually
-        pixel_values = batch["pixel_values"]
-        if type(pixel_values) is list:
-            for j, pixel_value_j in enumerate(pixel_values):
-                if type(pixel_value_j) is list:
-                    for k, pixel_value_k in enumerate(pixel_value_j):
-                        pixel_value_j[k] = pixel_value_k.to(self.dtype)
-                else:
-                    pixel_values[j] = pixel_value_j.to(self.dtype)
-            pass
-            batch["pixel_values"] = pixel_values
-        else:
-            batch["pixel_values"] = batch["pixel_values"].to(self.dtype)
+        if "pixel_values" in batch:
+            pixel_values = batch["pixel_values"]
+            if type(pixel_values) is list:
+                for j, pixel_value_j in enumerate(pixel_values):
+                    if type(pixel_value_j) is list:
+                        for k, pixel_value_k in enumerate(pixel_value_j):
+                            pixel_value_j[k] = pixel_value_k.to(self.dtype)
+                    else:
+                        pixel_values[j] = pixel_value_j.to(self.dtype)
+                pass
+                batch["pixel_values"] = pixel_values
+            else:
+                batch["pixel_values"] = batch["pixel_values"].to(self.dtype)
         pass

         # Mask image tokens and pad tokens
```
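For reviewers, here is a standalone sketch of the per-example image-detection flow this diff introduces. It only mirrors the control flow above; the helper name `detect_image` and the injected `process_vision_info` parameter are illustrative, not part of the PR (in the real collator, `process_vision_info` comes from the module's imports):

```python
# Illustrative mirror (not part of this PR) of the image-detection logic added
# above. `process_vision_info` is passed in so the sketch stays self-contained.

def detect_image(example, messages, process_vision_info):
    """Return (image, has_image) following the updated collator logic."""
    image = None
    has_image = False
    if "images" in example and example["images"]:
        # Dedicated "images" column with at least one entry: use the first.
        image = example["images"][0]
        has_image = True
    elif isinstance(messages, (list, dict)):
        # No usable "images" column: try to extract images embedded in the
        # chat messages themselves, tolerating extractors that raise.
        try:
            image_list, _ = process_vision_info([messages])
            if image_list:
                image = image_list[0]
                has_image = True
        except Exception:
            pass
    return image, has_image


# Tiny usage check with a stub extractor that finds no images:
if __name__ == "__main__":
    example = {"messages": [{"role": "user", "content": "Hello"}]}
    image, has_image = detect_image(
        example, example["messages"], lambda m: ([], None)
    )
    assert image is None and has_image is False  # text-only -> skip image path
```

When every example in a batch resolves this way, `has_images` stays `False`, so the collator calls the processor without `images=`, and the new `if "pixel_values" in batch:` guard keeps the dtype cast from failing on text-only output that has no `pixel_values` key.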