Skip to content

Commit c7b848e

Browse files
author
neel
committed
added to comments, renamed metric classes
1 parent 74107ad commit c7b848e

File tree

5 files changed

+23
-9
lines changed

5 files changed

+23
-9
lines changed

eureka_ml_insights/data_utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
PrependStringTransform,
7070
ExtractAnswerGrid,
7171
ExtractAnswerSpatialMapAndMaze,
72+
ExtractQuestionOptions,
7273
ShuffleColumnsTransform,
7374
ColumnMatchMapTransform,
7475
TokenCounterTransform,

eureka_ml_insights/metrics/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from .geomtric_reasoning_metrics import GeoMCQMetric
33
from .metrics_base import (
44
CaseInsensitiveMatch,
5-
CaseInsensitiveOrMatch,
65
ClassicMetric,
76
CompositeMetric,
87
ExactMatch,
98
IdentityMetric,
109
Metric,
10+
MultiCandidateAnyExactMatch,
11+
MultiCandidateAnyCaseInsensitiveMatch,
1112
SubstringExistsMatch,
1213
)
1314
from .mmmu_metrics import MMMUMetric

eureka_ml_insights/metrics/metrics_base.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -144,8 +144,14 @@ def __evaluate__(self, answer_text, target_text, is_valid):
144144
else:
145145
return "incorrect"
146146

147-
class ExactOrMatch(ExactMatch):
148-
"""This class checks for a case-sensitive, but otherwise exact match, and returns the or of them."""
147+
class MultiCandidateAnyExactMatch(ExactMatch):
148+
"""
149+
This class checks for a case-sensitive match for a list of answers from the model output,
150+
and returns the logical OR of the metric results over that list.
151+
152+
This is required for answers to multiple-choice questions. As many models sometimes give the letter answer
153+
and sometimes the full word answer. This allows one to consider the answer correct if either one was correct.
154+
"""
149155

150156
def __evaluate__(self, answer_texts, target_text, is_valid):
151157

@@ -170,8 +176,14 @@ class CaseInsensitiveMatch(ExactMatch):
170176
def __evaluate__(self, answer_text, target_text, is_valid):
171177
return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid)
172178

173-
class CaseInsensitiveOrMatch(ExactOrMatch):
174-
"""This class checks for a case-insensitive, but otherwise exact or match."""
179+
class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch):
180+
"""
181+
This class checks for a case-insensitive match for a list of answers from the model output,
182+
and returns the logical OR of the metric results over that list.
183+
184+
This is required for answers to multiple-choice questions. As many models sometimes give the letter answer
185+
and sometimes the full word answer. This allows one to consider the answer correct if either one was correct.
186+
"""
175187

176188
def __evaluate__(self, answer_texts, target_text, is_valid):
177189
answer_texts = [str(answer_text).lower() for answer_text in answer_texts]

eureka_ml_insights/user_configs/vision_language/maze.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
PrependStringTransform,
1414
SequenceTransform,
1515
)
16-
from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator
16+
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator
1717

1818
from eureka_ml_insights.configs import (
1919
AggregatorConfig,
@@ -96,7 +96,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
9696
),
9797
},
9898
),
99-
metric_config=MetricConfig(CaseInsensitiveOrMatch),
99+
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
100100
aggregator_configs=[
101101
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
102102
AggregatorConfig(

eureka_ml_insights/user_configs/vision_language/spatial_map.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
PrependStringTransform,
1414
SequenceTransform,
1515
)
16-
from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator
16+
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator
1717

1818
from eureka_ml_insights.configs import (
1919
AggregatorConfig,
@@ -97,7 +97,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
9797
),
9898
},
9999
),
100-
metric_config=MetricConfig(CaseInsensitiveOrMatch),
100+
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
101101
aggregator_configs=[
102102
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
103103
AggregatorConfig(

0 commit comments

Comments
 (0)