diff --git a/eureka_ml_insights/data_utils/__init__.py b/eureka_ml_insights/data_utils/__init__.py
index d8cbe50..0d30871 100644
--- a/eureka_ml_insights/data_utils/__init__.py
+++ b/eureka_ml_insights/data_utils/__init__.py
@@ -15,8 +15,8 @@
 from .prompt_processing import JinjaPromptTemplate
 from .spatial_utils import (
     ExtractAnswerGrid,
-    ExtractAnswerMaze,
-    ExtractAnswerSpatialMap,
+    ExtractAnswerSpatialMapAndMaze,
+    ExtractQuestionOptions,
 )
 from .transform import (
     AddColumn,
@@ -69,8 +69,8 @@
     ASTEvalTransform,
     PrependStringTransform,
     ExtractAnswerGrid,
-    ExtractAnswerSpatialMap,
-    ExtractAnswerMaze,
+    ExtractAnswerSpatialMapAndMaze,
+    ExtractQuestionOptions,
     ShuffleColumnsTransform,
     ColumnMatchMapTransform,
     TokenCounterTransform,
diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py
index 8a579bf..f2907d2 100644
--- a/eureka_ml_insights/data_utils/spatial_utils.py
+++ b/eureka_ml_insights/data_utils/spatial_utils.py
@@ -157,25 +157,34 @@ def extract_answer_from_text_grid(text, question_type):
     return None  # Return None if no numbers are found
 
 
-def extract_answer_from_text_map(text, question_type, model_name):
+def extract_answer_from_text_map_and_maze(model_output_raw, options):
     """
-    Extracts the answer from the text based on specific patterns,
-    and as a fallback, extracts the first number if no patterns match.
-    The code is from: https://github.com/alvinmingwisc/spatial_reason_vlm/tree/main/eval,
-    and included with minimal modifications.
+    Extracts the answer from the text based on known model output patterns.
+    Searches for both a letter answer and a whole-word answer and returns both,
+    since model outputs are not always consistent between the two.
 
     Args:
-    - text (str): The text containing the model's answer.
-    - question_type (str): The text containing the question type.
-    - model_name (str): The model name.
+    - model_output_raw (str): The text containing the model's answer.
+    - options (list): The list of multiple-choice option strings.
 
     Returns:
-    - str or None: The extracted answer, or None if no answer could be extracted.
+    - str: The extracted word and letter answers joined by " or ", or an empty
+      string if no answer could be extracted.
     """
-    # Mapping of textual numbers to their numeric equivalents
+
+    model_output_parsed_letter = ""
+    model_output_parsed = ""
+
+    if not model_output_raw:
+        # nothing to parse; return an empty string so the return type stays
+        # consistent with the successful-parse path below
+        return ""
+
+    # normalize common phrasings in model outputs before pattern matching
+    model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, flags=re.IGNORECASE)
+    model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, flags=re.IGNORECASE)
+    model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, flags=re.IGNORECASE)
+
     number_mapping = {
-        "zero": 0,
-        "no": 0,
+        "zero": 0,
         "one": 1,
         "two": 2,
         "three": 3,
@@ -187,127 +196,71 @@ def extract_answer_from_text_map(text, question_type, model_name):
         "nine": 9,
     }
 
-    dirs = ["southeast", "northeast", "northwest", "southwest"]
-    dir_pattern = rf"\b({'|'.join(dirs)})\b"
-
-    if text is None:
-        return None
-
-    question_id = int(re.search("[0-9]", re.search("Q[0-9]", question_type).group()).group())
-
-    if question_id == 0:
-        direction_match = re.search(r"\b[A-D]\.\s*(" + "|".join(dirs) + r")\b", text, re.IGNORECASE)
-        if direction_match:
-            return direction_match.group(1).lower()
-
-        match = re.search(dir_pattern, text, re.IGNORECASE)
-        if match:
-            return match.group(1)
-        return None
-
-    elif question_id == 1:
-        match = re.search(
-            rf"^([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|located\s+to\s+the\s+)({dir_pattern})",
-            text,
-            re.IGNORECASE,
-        )
-
-        if match:
-            string = match.group(1)
-            return string
-
-        match = re.search(r"\b[A-D]\.\s*(.*)", text)  # problem with extracting .
-
-        if match:
-            string = match.group(1)
-            string = remove_redundancy(string)
-            string = extract_before_is(string)
-            return string
-
-        match = re.search(r"\b([ABCD][.,]|[(][abcdABCD][)])\s*(.*?)(?=\sis\b|\.|,|<|$)", text)
-        if match:
-            answer = match.group(1).strip()
-            # Remove trailing punctuation if any
-            answer = re.sub(r"[\.,\?!<]+$", "", answer)
-            return answer
-
-        match = re.search(
-            rf"Therefore, the object in the {dir_pattern} of [\w\s\'\']+ is ([\w\s\'\']+)", text, re.IGNORECASE
-        )
-        if match:
-            string = match.group(2)
-            return string
-
-        if "claude" in model_name.lower():
-            match = re.search(rf"^([\w\s\'\']+?)\s+is\s+(to\s+the\s+)({dir_pattern})", text, re.IGNORECASE)
-            if match:
-                string = match.group(1)
-                return string
-
-        if "gemini" in model_name.lower():
-            patterns = [
-                rf"\*\*Concise Answer:\*\*\n([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})",
-                rf"\*\*Answer:\*\*\s+([\w\s\'\']+?)\s+is\s+in\s+the\s+({dir_pattern})\s+of\s+([\w\s\'\']+)",
-                r"\*\*Answer:\*\*\n([\w\s\'\']+)",
-                r"\*\*Answer\*\*:\s+([\w\s\'\']+)",
-                r"\*\*Answer:\*\*\s+([\w\s\'\']+)",
-            ]
-
-            for pattern in patterns:
-                match = re.search(pattern, text, re.IGNORECASE)
-                if match:
-                    return match.group(1)
-
-        if "gpt-4o" in model_name.lower() or "gpt4o" in model_name.lower():
-            match = re.search(
-                rf"Concise Answer:\s+([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})",
-                text,
-                re.IGNORECASE,
-            )
-            if match:
-                string = match.group(1)
-                return string
-
-        # If no match, check for an answer following "is", with specific end markers defined
-        match = re.search(r"\bis\b\s+(.*?)(?=\.|,|<|$)", text)
-        if match:
-            answer = match.group(1).strip()
-            # Remove trailing punctuation if any
-            answer = re.sub(r"[\.,\?!<]+$", "", answer)
-            return answer
-
-        return None  # Return None if no match is found
-
-    elif question_id == 2:
-        match = re.search(r"\b[A-D]\.\s*(\d+)", text)  # match number only
-        if match:
-            return match.group(1)
-        # Create a list to store all found numbers along with their positions
-        found_numbers = []
-
-        # Check for textual numbers and their positions
-        for text_num, num in number_mapping.items():
-            for match in re.finditer(rf"\b{text_num}\b", text, re.IGNORECASE):
-                found_numbers.append((match.start(), num))
-
-        # Check for digit sequences and their positions, specifically ignoring list markers at the start
-        # Exclude numbers following "\n\n" and directly followed by ". "
-        text = re.sub(r"^\n\n\d+\.\s", "", text)  # Remove the leading list marker if it exists
-
-        for match in re.finditer(r"\d+", text):
-            found_numbers.append((match.start(), int(match.group(0))))
-
-        # Sort found numbers by their positions (smallest position first)
-        if found_numbers:
-            found_numbers.sort(key=lambda x: x[0])
-            # Return the number associated with the earliest position
-            return str(found_numbers[0][1])
-        return None
-
-    else:
-        raise ValueError(f"Question ID {question_id} is not supported.")
-
-    return None  # Return None if no numbers are found
+    for k, v in number_mapping.items():
+        model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, flags=re.IGNORECASE)
+
+    # build a dict mapping option letter to option text from the options list
+    options_dict = {x.split(".")[0].strip().lower(): x.split(".")[1].strip().lower() for x in options}
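+    # e.g. ["A. Northeast", "B. Northwest", ...] -> {"a": "northeast", "b": "northwest", ...}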
+
+    answers = [v for k, v in options_dict.items()]
+    answers_pattern = rf"\b({'|'.join(answers)})\b"
+
+    if "answer:" in model_output_raw.lower():
+        pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)"
+        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
+        if matches:
+            match_option = matches.group(1).lower()
+            if match_option in options_dict:
+                model_output_parsed_letter = options_dict[match_option]
+            else:
+                model_output_parsed_letter = match_option
+
+        pattern_phrase = r"Answer:\**\s+([^\n]+)"
+        matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE)
+        if matches:
+            model_output_answer_line = matches.group(1)
+
+            answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)
+
+            if answers_match:
+                model_output_parsed = answers_match.group(1)
+            else:
+                letters = [k for k, v in options_dict.items()]
+                letters_pattern = rf"\b({'|'.join(letters)})\b"
+                letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE)
+
+                if letters_pattern_match:
+                    match_option = letters_pattern_match.group(1).lower()
+                    model_output_parsed_letter = options_dict[match_option]
+
+    elif "answer is" in model_output_raw.lower():
+        # first look for a single letter answer
+        pattern_letter = r"answer is:*\s*\**([\w\d]+)[\s:.]*\**"
+        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
+        if matches:
+            match_option = matches.group(1).lower()
+            if match_option in options_dict:
+                model_output_parsed_letter = options_dict[match_option]
+            else:
+                model_output_parsed_letter = match_option
+
+        # next look whether any of the option names are present in the first line
+        model_output_answer_line = model_output_raw.splitlines()[0]
+
+        answers = [v for k, v in options_dict.items()]
+        answers_pattern = rf"\b({'|'.join(answers)})\b"
+        answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)
+
+        if answers_match:
+            model_output_parsed = answers_match.group(1)
+
+    # return both extractions; either one can satisfy a substring match downstream
+    return model_output_parsed + " or " + model_output_parsed_letter
 
 
 def extract_answer_from_text_maze(text, question_type):
@@ -443,43 +396,59 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
         )
         return df
 
 
 @dataclass
-class ExtractAnswerGrid(ExtractAnswer):
-    """This class is an answer extractor for the GRID benchmark."""
+class ExtractQuestionOptions(DFTransformBase):
+    """This class is for extracting the option list from a prompt."""
 
-    answer_column_name: str
-    extracted_answer_column_name: str
-    question_type_column_name: str
-    mode: str
+    prompt_column_name: str
+    extracted_options_column_name: str
 
-    @abstractmethod
-    def _parse_answer_function(self, answer_text, question_type):
-        return extract_answer_from_text_grid(answer_text, question_type)
+    def _extract_options_from_text_map(self, prompt):
+        """
+        Extracts the multiple-choice options list from the text.
+
+        Args:
+        - prompt (str): The text containing the prompt.
+
+        Returns:
+        - list: The extracted list of options.
+        """
+        # get the list of options: the four lines following "Available options:"
+        prompt_lines = prompt.splitlines()
+        matches = [i for i, x in enumerate(prompt_lines) if "Available options:" in x]
+        options = prompt_lines[matches[0]+1:matches[0]+5]
+
+        return options
+
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
+        df[self.extracted_options_column_name] = df[self.prompt_column_name].apply(self._extract_options_from_text_map)
+        return df
 
 
 @dataclass
-class ExtractAnswerSpatialMap(ExtractAnswer):
-    """This class is an answer extractor for the SPATIAL_MAP benchmark."""
+class ExtractAnswerGrid(ExtractAnswer):
+    """This class is an answer extractor for the GRID benchmark."""
 
     answer_column_name: str
     extracted_answer_column_name: str
     question_type_column_name: str
-    model_name: str
+    mode: str
 
     @abstractmethod
     def _parse_answer_function(self, answer_text, question_type):
-        return extract_answer_from_text_map(answer_text, question_type, self.model_name)
+        return extract_answer_from_text_grid(answer_text, question_type)
 
 
 @dataclass
-class ExtractAnswerMaze(ExtractAnswer):
-    """This class is an answer extractor for the MAZE benchmark."""
+class ExtractAnswerSpatialMapAndMaze(DFTransformBase):
+    """This class is an answer extractor for the SPATIAL_MAP and MAZE benchmarks."""
 
     answer_column_name: str
     extracted_answer_column_name: str
-    question_type_column_name: str
+    extracted_options_column_name: str
 
-    @abstractmethod
-    def _parse_answer_function(self, answer_text, question_type):
-        return extract_answer_from_text_maze(answer_text, question_type)
+    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
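+        # each row carries its own extracted option list, so parse row-wise with
+        # apply rather than assuming one shared option set for the whole column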
+        df[self.extracted_answer_column_name] = df.apply(
+            lambda x: extract_answer_from_text_map_and_maze(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1
+        )
+        return df
diff --git a/eureka_ml_insights/user_configs/vision_language/maze.py b/eureka_ml_insights/user_configs/vision_language/maze.py
index 7294a65..b461a93 100644
--- a/eureka_ml_insights/user_configs/vision_language/maze.py
+++ b/eureka_ml_insights/user_configs/vision_language/maze.py
@@ -8,11 +8,12 @@
     ColumnRename,
     DataLoader,
     DataReader,
-    ExtractAnswerMaze,
+    ExtractQuestionOptions,
+    ExtractAnswerSpatialMapAndMaze,
     PrependStringTransform,
     SequenceTransform,
 )
-from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator
+from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator
 
 from eureka_ml_insights.configs import (
     AggregatorConfig,
@@ -81,23 +82,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
                 "format": ".jsonl",
                 "transform": SequenceTransform(
                     [
+                        ExtractQuestionOptions(
+                            prompt_column_name="prompt",
+                            extracted_options_column_name="target_options_answers",
+                        ),
                         ColumnRename(name_mapping={"model_output": "model_output_raw"}),
-                        ExtractAnswerMaze(
+                        ExtractAnswerSpatialMapAndMaze(
                             answer_column_name="model_output_raw",
                             extracted_answer_column_name="model_output",
-                            question_type_column_name="question_type",
+                            extracted_options_column_name="target_options_answers",
                         ),
                     ],
                 ),
             },
         ),
-        metric_config=MetricConfig(CaseInsensitiveMatch),
+        metric_config=MetricConfig(SubstringExistsMatch),
         aggregator_configs=[
-            AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}),
+            AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
             AggregatorConfig(
                 CountAggregator,
                 {
-                    "column_names": ["CaseInsensitiveMatch_result"],
+                    "column_names": ["SubstringExistsMatch_result"],
                     "group_by": "task",
                     "normalize": True,
                 },
diff --git a/eureka_ml_insights/user_configs/vision_language/spatial_map.py b/eureka_ml_insights/user_configs/vision_language/spatial_map.py
index 1453335..4bb2134 100644
--- a/eureka_ml_insights/user_configs/vision_language/spatial_map.py
+++ b/eureka_ml_insights/user_configs/vision_language/spatial_map.py
@@ -8,11 +8,12 @@
     ColumnRename,
     DataLoader,
     DataReader,
-    ExtractAnswerSpatialMap,
+    ExtractAnswerSpatialMapAndMaze,
+    ExtractQuestionOptions,
     PrependStringTransform,
     SequenceTransform,
 )
-from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator
+from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator
 
 from eureka_ml_insights.configs import (
     AggregatorConfig,
@@ -82,24 +83,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
                 "format": ".jsonl",
                 "transform": SequenceTransform(
                     [
+                        ExtractQuestionOptions(
+                            prompt_column_name="prompt",
+                            extracted_options_column_name="target_options_answers",
+                        ),
                         ColumnRename(name_mapping={"model_output": "model_output_raw"}),
-                        ExtractAnswerSpatialMap(
+                        ExtractAnswerSpatialMapAndMaze(
                             answer_column_name="model_output_raw",
                             extracted_answer_column_name="model_output",
-                            question_type_column_name="question_type",
-                            model_name=model_config.init_args['model_name'],  # passing the model name for model-specific answer extraction
+                            extracted_options_column_name="target_options_answers",
                         ),
                     ],
                 ),
             },
         ),
-        metric_config=MetricConfig(CaseInsensitiveMatch),
+        metric_config=MetricConfig(SubstringExistsMatch),
         aggregator_configs=[
-            AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}),
+            AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
             AggregatorConfig(
                 CountAggregator,
                 {
-                    "column_names": ["CaseInsensitiveMatch_result"],
+                    "column_names": ["SubstringExistsMatch_result"],
                     "group_by": "task",
                     "normalize": True,
                 },
diff --git a/tests/data_utils_tests/vision_language_data_utils_tests.py b/tests/data_utils_tests/vision_language_data_utils_tests.py
new file mode 100644
index 0000000..bb30c93
--- /dev/null
+++ b/tests/data_utils_tests/vision_language_data_utils_tests.py
@@ -0,0 +1,63 @@
+# Unit tests for extract_answer_from_text_map_and_maze in data_utils/spatial_utils.py
+
+import logging
+import unittest
+
+from eureka_ml_insights.data_utils.spatial_utils import extract_answer_from_text_map_and_maze
+
+log = logging.getLogger("VISION_LANGUAGE_extract_answer_from_text_map_and_maze_tests")
+
+
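+# Each test case below pairs a raw model output with the prompt's option list;
+# extraction should recover the expected answer from any of these phrasings.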
+class TestExtractAnswer(unittest.TestCase):
+    def test_extract_answer_from_text_map_and_maze(self):
+
+        ## Direction answer extraction tests
+        testcases_dir = [
+            "The correct answer is A: Planetarium Prints is located Northeast of Police Supply Store.",
+
+            "Planetarium Prints is located to the Northeast of Police Supply Store, so the answer is A.\n\nTherefore, A (Northeast) is the correct answer as it accurately describes the position of Planetarium Prints relative to Police Supply Store.",
+
+            "A. Northeast\n\nBased on the information provided, Planetarium Prints is located to the Northeast of Police Supply Store. This can be determined through the following steps:\n\n1. Police Supply Store is used as the reference point.\n2. Coral Crafts is stated to be Northwest of Police Supply Store.\n3. Planetarium Prints is then described as being Southeast of Coral Crafts.\n4. Since Coral Crafts is Northwest of Police Supply Store, and Planetarium Prints is Southeast of Coral Crafts, this places Planetarium Prints to the Northeast of Police Supply Store.\n\nAdditionally, the description directly states \"Planetarium Prints is to the Northeast of Police Supply Store\", confirming the answer.",
+
+            "Planetarium Prints is **Northeast** of Police Supply Store. \n\nHere's the step-by-step explanation:\n\n1. **Establish a reference point:** The question asks for the direction of Planetarium Prints *relative* to Police Supply Store. This means we consider Police Supply Store as our central point.\n\n2. **Use the given information:** We are told \"Planetarium Prints is to the Northeast of Police Supply Store.\"\n\n3. **Direct relationship:** Since we have a direct statement about the relative position of Planetarium Prints to Police Supply Store, we don't need to analyze the other clues.\n\n4. **Final answer:** Therefore, Planetarium Prints is located **Northeast** of Police Supply Store. \n",
+
+            "Planetarium Prints is Northeast of Police Supply Store.\n\n**Reasoning:**\n\n1. **Identify the relative positions:**\n - Police Supply Store is the reference point.\n - Planetarium Prints is one of the objects we need to locate relative to the Police Supply Store.\n\n2. **Analyze the given directions:**\n - Planetarium Prints is described as being to the Northeast of Police Supply Store.\n\n3. **Verify with the map:**\n - On the map, Planetarium Prints is positioned to the upper right (Northeast) of the Police Supply Store.\n\n4. **Conclusion:**\n - Based on the description and the map, Planetarium Prints is indeed Northeast of Police Supply Store.\n\nThus, the correct answer is:\nA. Northeast",
+
+            "Planetarium Prints is to the **Northeast** of Police Supply Store (Option A).\n\n - Based on the above placements, **Planetarium Prints is to the Northeast of Police Supply Store**.\n\nTherefore, the correct answer is **Option A: Northeast**.",
+
+            "The correct answer is A. Northeast.\n\nTo determine the direction of Planetarium Prints relative to Police Supply Store, we can follow the given descriptions:\n\n1. Police Supply Store is in the map.\n2. Narwhal's Novelties is to the Northwest of Police Supply Store.\n3. Coral Crafts is to the Northwest of Narwhal's Novelties.\n4. Coral Crafts is to the Northwest of Police Supply Store.\n5. Planetarium Prints is to the Southeast of Coral Crafts.\n6. Planetarium Prints is to the Northeast of Police Supply Store.\n7. Oz Oddities is to the Southwest of Planetarium Prints.\n8. Oz Oddities is to the Southwest of Police Supply Store.\n9. Ice Queen Ice Cream is to the Northwest of Planetarium Prints.\n10. Ice Queen Ice Cream is to the Southeast of Coral Crafts.\n\nFrom these descriptions, we can see that Planetarium Prints is located to the Northeast of Police Supply Store. This is directly stated in point 6. Therefore, the correct answer is A. Northeast.",
+
+            "A. Northeast.\n\nPlanetarium Prints is to the Northeast of Police Supply Store. The reasoning behind this answer is based on the provided directions and locations of the objects on the map. Narwhal's Novelties is to the Northwest of Police Supply Store, and Coral Crafts is to the Northwest of Narwhal's Novelties. Planetarium Prints is to the Southeast of Coral Crafts, which means it is also to the Northeast of Police Supply Store.",
+        ]
+
+        target_options_dir = [["A. Northeast", "B. Northwest", "C. Southwest", "D. Southeast."]] * len(testcases_dir)
+        correct_answers_dir = ["northeast"] * len(testcases_dir)
+
+        ## Numerical extraction tests
+
+        testcases_numerical = [
+            "A. 1\n\nTo determine how many objects are in the Southeast of Oz Oddities, we need to look at the relative positions of the objects on the map:\n\n1. Oz Oddities is located at the bottom of the map.\n2. Directly to the Northeast of Oz Oddities is the Police Supply Store.\n3. To the Southeast of Oz Oddities, there is only one object, which is Planetarium Prints.\n4. All other objects are either to the North or Northwest of Oz Oddities and therefore not in the Southeast direction.\n\nBased on the map, only Planetarium Prints is in the Southeast of Oz Oddities, which means the correct answer is A. 1.",
+
+            "There are zero objects",
+
+            "There are no objects",
+        ]
+
+        target_options_numerical = [["A. 1", "B. 0", "C. 2", "D. 3."]] * len(testcases_numerical)
+        correct_answers_numerical = ["1", "0", "0"]
+
+        target_options = target_options_dir + target_options_numerical
+        testcases = testcases_dir + testcases_numerical
+        correct_answers = correct_answers_dir + correct_answers_numerical
+
+        results = []
+        for i, test in enumerate(testcases):
+            extracted_answer = extract_answer_from_text_map_and_maze(test, target_options[i])
+            results.append(correct_answers[i].lower() in extracted_answer.lower())
+
+        self.assertTrue(all(results))
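+
+    def test_extract_answer_from_text_map_and_maze_empty_output(self):
+        # Minimal extra edge-case check (a sketch, assuming the early-return
+        # behavior above): an empty model output should yield an empty string
+        # rather than raising.
+        extracted_answer = extract_answer_from_text_map_and_maze("", ["A. 1", "B. 0", "C. 2", "D. 3."])
+        self.assertEqual(extracted_answer, "")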
+
+
+if __name__ == "__main__":
+    unittest.main()