Skip to content

Commit bbe1f72

Browse files
neelsjneel
and
neel
authored
Neel/spatial map answer extraction (#57)
This PR is a set of changes to update and simplify the answer extraction for Spatial Map and Maze. Removes model-specific answer extraction --------- Co-authored-by: neel <[email protected]>
1 parent cfc5144 commit bbe1f72

File tree

5 files changed

+213
-172
lines changed

5 files changed

+213
-172
lines changed

eureka_ml_insights/data_utils/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from .prompt_processing import JinjaPromptTemplate
1717
from .spatial_utils import (
1818
ExtractAnswerGrid,
19-
ExtractAnswerMaze,
20-
ExtractAnswerSpatialMap,
19+
ExtractAnswerSpatialMapAndMaze,
20+
ExtractQuestionOptions,
2121
)
2222
from .transform import (
2323
AddColumn,
@@ -71,8 +71,8 @@
7171
ASTEvalTransform,
7272
PrependStringTransform,
7373
ExtractAnswerGrid,
74-
ExtractAnswerSpatialMap,
75-
ExtractAnswerMaze,
74+
ExtractAnswerSpatialMapAndMaze,
75+
ExtractQuestionOptions,
7676
ShuffleColumnsTransform,
7777
ColumnMatchMapTransform,
7878
TokenCounterTransform,

eureka_ml_insights/data_utils/spatial_utils.py

+122-153
Original file line numberDiff line numberDiff line change
@@ -154,25 +154,34 @@ def extract_answer_from_text_grid(text, question_type):
154154
return None # Return None if no numbers are found
155155

156156

157-
def extract_answer_from_text_map_and_maze(model_output_raw, options):
    """
    Extracts the answer from model output text based on known output patterns.

    Searches for both a letter answer (e.g. "B") and a whole-word answer
    (e.g. "northeast") and returns both joined by " or ", because model
    outputs are not always consistent between the two forms.

    Args:
    - model_output_raw (str): The text containing the model's answer.
    - options (list of str): The option lines, e.g. ["A. northeast", ...].
      Each line is expected to be "<letter>. <answer text>".

    Returns:
    - str: "<word answer> or <letter answer>", with empty strings where no
      answer could be extracted (" or " when nothing matched at all).
    """
    model_output_parsed_letter = ""
    model_output_parsed = ""

    # FIX: return a string here (was a 2-element list), so the return type is
    # consistent with the normal path and downstream substring metrics.
    if not model_output_raw:
        return model_output_parsed + " or " + model_output_parsed_letter

    # Normalize common substitutions in model outputs before matching.
    # FIX: the flag was previously passed as the positional `count` argument
    # of re.sub (re.IGNORECASE == 2 meant "replace at most 2, case-sensitive");
    # it must be passed as `flags=`.
    model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, flags=re.IGNORECASE)
    model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, flags=re.IGNORECASE)
    model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, flags=re.IGNORECASE)

    # Mapping of textual numbers to their numeric equivalents.
    number_mapping = {
        "zero": 0,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
    }

    for word, digit in number_mapping.items():
        model_output_raw = re.sub(rf"\b{word}\b", str(digit), model_output_raw, flags=re.IGNORECASE)

    # Build {letter -> answer text} from the option lines, e.g.
    # "A. northeast" -> {"a": "northeast"}.
    # FIX: split on the first "." only, so option text containing a period
    # is not truncated.
    options_dict = {
        opt.split(".", 1)[0].strip().lower(): opt.split(".", 1)[1].strip().lower()
        for opt in options
    }

    answers = list(options_dict.values())
    # FIX: escape option text so regex metacharacters in options cannot
    # corrupt the pattern.
    answers_pattern = rf"\b({'|'.join(re.escape(a) for a in answers)})\b"

    if "answer:" in model_output_raw.lower():
        # e.g. "**Answer:** B. northeast" — capture the letter (and word).
        pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)"
        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
        if matches:
            match_option = matches.group(1).lower()
            # Map the letter to its option text; fall back to the raw match
            # when it is not a known option letter.
            model_output_parsed_letter = options_dict.get(match_option, match_option)

        # Whatever follows "Answer:" on the same line.
        pattern_phrase = r"Answer:\**\s+([^\n]+)"
        matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE)
        if matches:
            model_output_answer_line = matches.group(1)

            answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)
            if answers_match:
                model_output_parsed = answers_match.group(1)
            else:
                # No option text on the answer line; try a bare option letter.
                letters_pattern = rf"\b({'|'.join(options_dict)})\b"
                letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE)
                if letters_pattern_match:
                    match_option = letters_pattern_match.group(1).lower()
                    model_output_parsed_letter = options_dict[match_option]

    elif "answer is" in model_output_raw.lower():
        # First look for a single-token answer after "answer is".
        pattern_letter = r"answer is:*\s*\**([\w\d]+)[\s:.]*\**"
        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
        if matches:
            match_option = matches.group(1).lower()
            model_output_parsed_letter = options_dict.get(match_option, match_option)

        # Next look for any option text in the first line of the output.
        model_output_answer_line = model_output_raw.splitlines()[0]
        answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)
        if answers_match:
            model_output_parsed = answers_match.group(1)

    return model_output_parsed + " or " + model_output_parsed_letter
308261

309262

310263
def extract_answer_from_text_maze(text, question_type):
@@ -440,43 +393,59 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
440393
)
441394
return df
442395

443-
444396
@dataclass
class ExtractQuestionOptions(DFTransformBase):
    """Extracts the multiple-choice option list from each prompt in a dataframe."""

    # Column holding the full prompt text.
    prompt_column_name: str
    # Column to write the extracted option lines into.
    extracted_options_column_name: str

    def _extract_options_from_text_map(self, prompt):
        """
        Extracts the multiple-choice options list from the prompt text.

        Args:
        - prompt (str): The text containing the prompt.

        Returns:
        - list of str: The four option lines that follow "Available options:".
          NOTE(review): assumes the marker line is present and exactly four
          options follow it — raises IndexError otherwise; confirm prompts
          always have this shape.
        """
        lines = prompt.splitlines()
        marker_indices = [idx for idx, line in enumerate(lines) if "Available options:" in line]
        first = marker_indices[0]
        return lines[first + 1:first + 5]

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Adds a column with the option list extracted from each prompt."""
        extracted = df[self.prompt_column_name].apply(self._extract_options_from_text_map)
        df[self.extracted_options_column_name] = extracted
        return df
457424

458425
@dataclass
class ExtractAnswerGrid(ExtractAnswer):
    """Answer extractor for the GRID benchmark."""

    answer_column_name: str
    extracted_answer_column_name: str
    question_type_column_name: str
    mode: str

    # NOTE(review): @abstractmethod on an implemented override looks
    # unintentional, but removing it would change instantiability — confirm
    # against the ExtractAnswer base class before touching it.
    @abstractmethod
    def _parse_answer_function(self, answer_text, question_type):
        """Delegates parsing to the grid-specific text extractor."""
        return extract_answer_from_text_grid(answer_text, question_type)
470437

471438

472439
@dataclass
class ExtractAnswerSpatialMapAndMaze(DFTransformBase):
    """Answer extractor for the SPATIAL_MAP and MAZE benchmarks."""

    # Column holding the raw model output.
    answer_column_name: str
    # Column to write the parsed answer into.
    extracted_answer_column_name: str
    # Column holding the option list previously extracted from the prompt.
    extracted_options_column_name: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Parses each row's raw output against its option list and stores the result."""
        def _parse_row(row):
            return extract_answer_from_text_map_and_maze(
                row[self.answer_column_name], row[self.extracted_options_column_name]
            )

        df[self.extracted_answer_column_name] = df.apply(_parse_row, axis=1)
        return df

eureka_ml_insights/user_configs/vision_language/maze.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
ColumnRename,
99
DataLoader,
1010
DataReader,
11-
ExtractAnswerMaze,
11+
ExtractQuestionOptions,
12+
ExtractAnswerSpatialMapAndMaze,
1213
PrependStringTransform,
1314
SequenceTransform,
1415
)
15-
from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator
16+
from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator
1617

1718
from eureka_ml_insights.configs import (
1819
AggregatorConfig,
@@ -81,23 +82,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
8182
"format": ".jsonl",
8283
"transform": SequenceTransform(
8384
[
85+
ExtractQuestionOptions(
86+
prompt_column_name="prompt",
87+
extracted_options_column_name="target_options_answers",
88+
),
8489
ColumnRename(name_mapping={"model_output": "model_output_raw"}),
85-
ExtractAnswerMaze(
90+
ExtractAnswerSpatialMapAndMaze(
8691
answer_column_name="model_output_raw",
8792
extracted_answer_column_name="model_output",
88-
question_type_column_name="question_type",
93+
extracted_options_column_name="target_options_answers",
8994
),
9095
],
9196
),
9297
},
9398
),
94-
metric_config=MetricConfig(CaseInsensitiveMatch),
99+
metric_config=MetricConfig(SubstringExistsMatch),
95100
aggregator_configs=[
96-
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}),
101+
AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
97102
AggregatorConfig(
98103
CountAggregator,
99104
{
100-
"column_names": ["CaseInsensitiveMatch_result"],
105+
"column_names": ["SubstringExistsMatch_result"],
101106
"group_by": "task",
102107
"normalize": True,
103108
},

0 commit comments

Comments
 (0)