src/axolotl/processing_strategies.py (135 changes: 85 additions & 50 deletions)
@@ -1,5 +1,6 @@
"""Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types"""

import ast
from copy import deepcopy
from typing import Optional

@@ -75,6 +76,49 @@ def convert_legacy_format(example: dict) -> dict:
result["messages"] = messages
return result

def convert_multiple_choice_to_multimedia_messages(
    messages: dict,
) -> list[dict]:
    """Convert a multiple-choice example (question/options/answer) to Messages format with content type"""

    def construct_prompt(sample):
        question = sample["question"]
        options = sample["options"]
        if isinstance(options, str):
            options = ast.literal_eval(options)

        example = ""
        start_chr = "A"
        prediction_range = []
        index2ans = {}
        for option in options:
            prediction_range.append(start_chr)
            example += f"({start_chr}) {option}\n"
            index2ans[start_chr] = option
            start_chr = chr(ord(start_chr) + 1)

        empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly."
        empty_prompt = empty_prompt_sample_structure.format(question, example)

        return empty_prompt

    new_messages = []

    user_content = construct_prompt(messages)
    assistant_response = messages["answer"]

    new_messages.append(
        {"role": "user", "content": [{"type": "text", "text": user_content}]}
    )

    new_messages.append(
        {
            "role": "assistant",
            "content": [{"type": "text", "text": assistant_response}],
        }
    )

    return new_messages

def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
"""Convert regular messages format to Messages format with content type"""

@@ -106,39 +150,51 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:

processed_examples = []
for example in examples:
if not ("messages" in example or "conversations" in example):
if not (
"messages" in example
or "conversations" in example
or "question" in example
):
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
"Only `messages`, `conversations`, and `question` message keys are currently supported."
)

    processed_example = None
    if "messages" in example:  # OpenAI format
        processed_example = example
        # convert regular messages format to Messages format with content type
        # for compatibility with apply_chat_template
        processed_example["messages"] = convert_messages_to_multimedia_messages(
            processed_example["messages"]
        )
    elif "question" in example:  # Multiple choice format
        processed_example = {}
        processed_example["messages"] = (
            convert_multiple_choice_to_multimedia_messages(example)
        )
    else:  # Legacy format
        processed_example = convert_legacy_format(example)

        # convert regular messages format to Messages format with content type
        # for compatibility with apply_chat_template
        processed_example["messages"] = convert_messages_to_multimedia_messages(
            processed_example["messages"]
        )
    processed_example["messages"] = convert_messages_to_multimedia_messages(
        processed_example["messages"]
    )

    # find the image key if it exists
    possible_image_keys = ["images", "image"]
    image_key = None
    for key in possible_image_keys:
        if key in processed_example:
            image_key = key
            break

    # if the image key exists, add the image to the first message
    if image_key is not None:
        # TODO: check if it's normal to be single image only for common datasets
        # From observation, it's usually a list of single image but some datasets may have several columns for images
        # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages
        image_value = processed_example[image_key][0]

    # Handle image loading (Image, url, path, base64)

    image_keys = []
    for key in example.keys():
        if "image" in key:
            image_keys.append(key)

    for im_key in image_keys:
        if example[im_key] is None:
            continue
        if isinstance(example[im_key], list):
            if len(example[im_key]) == 0:
                continue
            image_value = example[im_key][0]
        else:
            image_value = example[im_key]

        image_value = load_image(image_value)

        if self.image_size is not None:
@@ -163,33 +219,12 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]:
                color=padding_color,
            )

        # Look for any image type in the first message
        # some dataset have an {type: "image"} in the first message
        ind_to_add = None

        for i, content in enumerate(
            processed_example["messages"][0]["content"]
        ):
            # Usually datasets created with image columns, don't have it in the messages itself
            if content["type"] == "image" and all(
                k not in content for k in ["image", "url", "path", "base64"]
            ):
                ind_to_add = i
                break

        # If an image type is found, add the image to that index
        if ind_to_add is not None:
            processed_example["messages"][0]["content"][ind_to_add][
                "image"
            ] = image_value
        else:
            # if no image type is found, add it to end of the first message
            processed_example["messages"][0]["content"].append(
                {
                    "type": "image",
                    "image": image_value,
                }
            )
        processed_example["messages"][0]["content"].append(
            {
                "type": "image",
                "image": image_value,
            }
        )

    processed_examples.append(processed_example)

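For reference, a minimal sketch (not part of the diff) of what the new multiple-choice path is expected to produce. The row dict below is hypothetical; the column names follow what the new `question` branch expects, and the sketch assumes convert_multiple_choice_to_multimedia_messages from the diff above is in scope (in the PR it is a helper inside the processing strategy, so this is illustrative only).

row = {
    "question": "What color is the sky?",
    "options": ["Red", "Blue", "Green"],  # may also arrive as a stringified list, handled via ast.literal_eval
    "answer": "B",
}

messages = convert_multiple_choice_to_multimedia_messages(row)
# Expected shape (whitespace abbreviated):
# [
#     {"role": "user", "content": [{"type": "text", "text":
#         "What color is the sky?\n\n(A) Red\n(B) Blue\n(C) Green\n..."
#         "Answer with the option's letter from the given choices directly."}]},
#     {"role": "assistant", "content": [{"type": "text", "text": "B"}]},
# ]
# Any image columns on the row (keys containing "image") are then appended to the
# first message's content by the per-example loop shown in the diff.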