diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index 0b854af8dc..a7fa533f58 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -1,5 +1,6 @@ """Module containing ProcessingStrategy classes and its derivative for different MultiModal Model types""" +import ast from copy import deepcopy from typing import Optional @@ -75,6 +76,49 @@ def convert_legacy_format(example: dict) -> dict: result["messages"] = messages return result + def convert_multiple_choice_to_multimedia_messages( + messages: dict, + ) -> list[dict]: + + def construct_prompt(sample): + question = sample["question"] + options = sample["options"] + if isinstance(options, str): + options = ast.literal_eval(options) + + example = "" + start_chr = "A" + prediction_range = [] + index2ans = {} + for option in options: + prediction_range.append(start_chr) + example += f"({start_chr}) {option}\n" + index2ans[start_chr] = option + start_chr = chr(ord(start_chr) + 1) + + empty_prompt_sample_structure = "{}\n\n{}\n\nAnswer with the option's letter from the given choices directly." + empty_prompt = empty_prompt_sample_structure.format(question, example) + + return empty_prompt + + new_messages = [] + + user_content = construct_prompt(messages) + assistant_response = messages["answer"] + + new_messages.append( + {"role": "user", "content": [{"type": "text", "text": user_content}]} + ) + + new_messages.append( + { + "role": "assistant", + "content": [{"type": "text", "text": assistant_response}], + } + ) + + return new_messages + def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: """Convert regular messages format to Messages format with content type""" @@ -106,39 +150,51 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: processed_examples = [] for example in examples: - if not ("messages" in example or "conversations" in example): + if not ( + "messages" in example + or "conversations" in example + or "question" in example + ): raise ValueError( - "Only `messages` and `conversations` message keys are currently supported." + "Only `messages`, `conversations`, and `question` message keys are currently supported." ) processed_example = None if "messages" in example: # OpenAI format processed_example = example + # convert regular messages format to Messages format with content type + # for compatibility with apply_chat_template + processed_example["messages"] = convert_messages_to_multimedia_messages( + processed_example["messages"] + ) + elif "question" in example: # Multiple choice format + processed_example = {} + processed_example["messages"] = ( + convert_multiple_choice_to_multimedia_messages(example) + ) else: # Legacy format processed_example = convert_legacy_format(example) - - # convert regular messages format to Messages format with content type - # for compatibility with apply_chat_template - processed_example["messages"] = convert_messages_to_multimedia_messages( - processed_example["messages"] - ) + processed_example["messages"] = convert_messages_to_multimedia_messages( + processed_example["messages"] + ) # find the image key if it exists - possible_image_keys = ["images", "image"] - image_key = None - for key in possible_image_keys: - if key in processed_example: - image_key = key - break - - # if the image key exists, add the image to the first message - if image_key is not None: - # TODO: check if it's normal to be single image only for common datasets - # From observation, it's usually a list of single image but some datasets may have several columns for images - # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages - image_value = processed_example[image_key][0] - - # Handle image loading (Image, url, path, base64) + + image_keys = [] + for key in example.keys(): + if "image" in key: + image_keys.append(key) + + for im_key in image_keys: + if example[im_key] is None: + continue + if isinstance(example[im_key], list): + if len(example[im_key]) == 0: + continue + image_value = example[im_key][0] + else: + image_value = example[im_key] + image_value = load_image(image_value) if self.image_size is not None: @@ -163,33 +219,12 @@ def convert_messages_to_multimedia_messages(messages: list[dict]) -> list[dict]: color=padding_color, ) - # Look for any image type in the first message - # some dataset have an {type: "image"} in the first message - ind_to_add = None - - for i, content in enumerate( - processed_example["messages"][0]["content"] - ): - # Usually datasets created with image columns, don't have it in the messages itself - if content["type"] == "image" and all( - k not in content for k in ["image", "url", "path", "base64"] - ): - ind_to_add = i - break - - # If an image type is found, add the image to that index - if ind_to_add is not None: - processed_example["messages"][0]["content"][ind_to_add][ - "image" - ] = image_value - else: - # if no image type is found, add it to end of the first message - processed_example["messages"][0]["content"].append( - { - "type": "image", - "image": image_value, - } - ) + processed_example["messages"][0]["content"].append( + { + "type": "image", + "image": image_value, + } + ) processed_examples.append(processed_example)