@@ -154,25 +154,34 @@ def extract_answer_from_text_grid(text, question_type):
154
154
return None # Return None if no numbers are found
155
155
156
156
157
def _parse_options(options):
    """Builds a lowercase {letter: answer text} mapping from option lines such as "A. north"."""
    if options is None:
        return {}
    if isinstance(options, str):
        # A single newline-separated string is accepted as well as a list of lines.
        options = options.splitlines()

    parsed = {}
    for option in options:
        parts = str(option).split(".")
        if len(parts) < 2:
            # Not in "<letter>. <answer>" form; skip the line instead of raising IndexError.
            continue
        parsed[parts[0].strip().lower()] = parts[1].strip().lower()
    return parsed


def _alternation_pattern(choices):
    """Returns a word-bounded regex matching any non-empty choice literally, or None if there are none."""
    escaped = [re.escape(choice) for choice in choices if choice]
    if not escaped:
        return None
    return rf"\b({'|'.join(escaped)})\b"


def extract_answer_from_text_map_and_maze(model_output_raw, options):
    """
    Extracts the answer from the text based on known model output patterns.
    Searches for both a letter and whole word answer and returns both as they are not
    always consistent.

    Args:
    - model_output_raw (str): The text containing the model's answer.
    - options (list of str): The option lines, each in the form "<letter>. <answer text>".
      Lines without a period are ignored.

    Returns:
    - str: "<answer text> or <answer derived from the option letter>". Either side is an
      empty string when it could not be extracted, so the result is " or " when nothing
      is found or when the model output is empty or not a string.
    """

    model_output_parsed_letter = ""
    model_output_parsed = ""

    # Missing cells may arrive as None or NaN; return the same shape as the
    # "nothing extracted" case so callers always receive a string.
    if not model_output_raw or not isinstance(model_output_raw, str):
        return model_output_parsed + " or " + model_output_parsed_letter

    # Replace common substitutions in model outputs.
    # flags must be passed by keyword: the fourth positional argument of re.sub is `count`.
    model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, flags=re.IGNORECASE)
    model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, flags=re.IGNORECASE)
    model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, flags=re.IGNORECASE)

    number_mapping = {
        "zero": 0,
        "one": 1,
        "two": 2,
        "three": 3,
        "four": 4,
        "five": 5,
        "six": 6,
        "seven": 7,
        "eight": 8,
        "nine": 9,
    }

    # Spelled-out numbers become digits in a single pass.
    number_words_pattern = rf"\b({'|'.join(number_mapping)})\b"
    model_output_raw = re.sub(
        number_words_pattern,
        lambda m: str(number_mapping[m.group(1).lower()]),
        model_output_raw,
        flags=re.IGNORECASE,
    )

    # Get dict of options (letter -> answer text) from the option lines.
    options_dict = _parse_options(options)

    # Option texts are escaped so that characters such as "(" or "+" are matched literally.
    answers_pattern = _alternation_pattern(options_dict.values())
    letters_pattern = _alternation_pattern(options_dict.keys())

    lowered_output = model_output_raw.lower()

    if "answer:" in lowered_output:
        # "Answer: B. northeast" at the very start of the output (optionally bold).
        pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)"
        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
        if matches:
            match_option = matches.group(1).lower()
            # Fall back to the raw letter when it is not one of the known options.
            model_output_parsed_letter = options_dict.get(match_option, match_option)

        # Everything after "Answer:" up to the end of that line.
        pattern_phrase = r"Answer:\**\s+([^\n]+)"
        matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE)
        if matches:
            model_output_answer_line = matches.group(1)

            answers_match = None
            if answers_pattern:
                answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)

            if answers_match:
                model_output_parsed = answers_match.group(1)
            elif letters_pattern:
                # No option text on the answer line; look for a bare option letter instead.
                letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE)

                if letters_pattern_match:
                    match_option = letters_pattern_match.group(1).lower()
                    model_output_parsed_letter = options_dict[match_option]

    elif "answer is" in lowered_output:
        pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*\**'

        # First look for a single letter answer.
        matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE)
        if matches:
            match_option = matches.group(1).lower()
            model_output_parsed_letter = options_dict.get(match_option, match_option)

        # Next look if any of the option texts are present in the first line.
        model_output_answer_line = model_output_raw.splitlines()[0]

        if answers_pattern:
            answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE)
            if answers_match:
                model_output_parsed = answers_match.group(1)

    return model_output_parsed + " or " + model_output_parsed_letter
308
261
309
262
310
263
def extract_answer_from_text_maze (text , question_type ):
@@ -440,43 +393,59 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame:
440
393
)
441
394
return df
442
395
443
-
444
396
@dataclass
class ExtractQuestionOptions(DFTransformBase):
    """This class is for extracting the option list from a prompt.

    The options are read from the column named by `prompt_column_name` and written,
    as a list of lines per row, to the column named by `extracted_options_column_name`.
    """

    prompt_column_name: str
    extracted_options_column_name: str

    def _extract_options_from_text_map(self, prompt):
        """
        Extracts the multiple-choice options list from the prompt text.

        Args:
        - prompt (str): The text containing the prompt.

        Returns:
        - list of str: Up to four lines that directly follow the first line containing
          "Available options:". An empty list is returned when the prompt is not a
          string or contains no such line.
        """

        # Missing cells (None / NaN) have no options to extract.
        if not isinstance(prompt, str):
            return []

        # Locate the first line that announces the option list.
        prompt_lines = prompt.splitlines()
        marker_index = next(
            (i for i, line in enumerate(prompt_lines) if "Available options:" in line),
            None,
        )
        if marker_index is None:
            # No option list in this prompt; avoid the IndexError a blind [0] lookup would raise.
            return []

        # The options are the (at most) four lines right after the marker line.
        return prompt_lines[marker_index + 1 : marker_index + 5]

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Adds the extracted options column to `df` (modified in place) and returns it."""
        df[self.extracted_options_column_name] = df[self.prompt_column_name].apply(
            self._extract_options_from_text_map
        )
        return df
457
424
458
425
@dataclass
class ExtractAnswerGrid(ExtractAnswer):
    """Answer extractor for the GRID benchmark.

    Parsing is delegated to `extract_answer_from_text_grid`; the fields below only
    configure which columns are involved.
    """

    # Column names used by the extractor, plus a mode setting.
    # `mode` is not read anywhere in this class body.
    answer_column_name: str
    extracted_answer_column_name: str
    question_type_column_name: str
    mode: str

    # NOTE(review): the method is decorated as abstract although it has a concrete
    # body; whether that blocks instantiation depends on the base class, confirm.
    @abstractmethod
    def _parse_answer_function(self, answer_text, question_type):
        """Parses one GRID answer for the given question type and returns the result."""
        parsed_answer = extract_answer_from_text_grid(answer_text, question_type)
        return parsed_answer
470
437
471
438
472
439
@dataclass
class ExtractAnswerSpatialMapAndMaze(DFTransformBase):
    """Answer extractor shared by the SPATIAL_MAP and MAZE benchmarks.

    For every row, the raw model output and the previously extracted option list are
    passed to `extract_answer_from_text_map_and_maze`, and the result is stored in the
    extracted-answer column.
    """

    answer_column_name: str
    extracted_answer_column_name: str
    extracted_options_column_name: str

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Adds the extracted answer column to `df` (modified in place) and returns it."""
        answer_column = self.answer_column_name
        options_column = self.extracted_options_column_name

        def _extract_for_row(row):
            # One row in, one parsed answer out.
            return extract_answer_from_text_map_and_maze(row[answer_column], row[options_column])

        df[self.extracted_answer_column_name] = df.apply(_extract_for_row, axis=1)
        return df
0 commit comments