Skip to content

Commit 88f06e8

Browse files
author
Besmira Nushi
committed
modify gpqa config so it uses the new aggregators
1 parent 8b8114e commit 88f06e8

File tree

1 file changed

+24
-38
lines changed
  • eureka_ml_insights/user_configs

1 file changed

+24
-38
lines changed

eureka_ml_insights/user_configs/gpqa.py

+24 -38
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,8 @@
2323
from eureka_ml_insights.metrics import (
2424
CountAggregator,
2525
ExactMatch,
26-
BiLevelMaxAggregator,
27-
BiLevelCountAggregator,
28-
BiLevelAverageAggregator,
29-
BiLevelSumAggregator
26+
BiLevelAggregator,
27+
BiLevelCountAggregator
3028
)
3129

3230
from eureka_ml_insights.configs import(
@@ -113,17 +111,6 @@ def configure_pipeline(
113111
column_name_src="model_output",
114112
column_name_dst="raw_model_output",
115113
),
116-
# these columns are currently copied so that they can be used in the bilevel aggregators
117-
# if they are not copied, then the bilevel aggregator used in sub category reports (e.g. subdomain)
118-
# will find the column name ambiguous
119-
CopyColumn(
120-
column_name_src="Subdomain",
121-
column_name_dst="Subdomain_copy",
122-
),
123-
CopyColumn(
124-
column_name_src="High-level domain",
125-
column_name_dst="High-level domain_copy",
126-
),
127114
RegexTransform(
128115
columns="model_output",
129116
prompt_pattern=r"Final Answer: (\w)(?=\s|\W|$)",
@@ -171,39 +158,42 @@ def configure_pipeline(
171158
AggregatorConfig(BiLevelCountAggregator,
172159
{
173160
"column_names": ["ExactMatch_result"],
174-
"first_groupby": ["data_repeat_id", "Subdomain_copy"],
161+
"first_groupby": ["data_repeat_id", "Subdomain"],
175162
"second_groupby": "Subdomain",
176163
"filename_base": "ExactMatch_GroupBy_Subdomain_AllRuns",
177164
"normalize": True
178165
}),
179166
AggregatorConfig(BiLevelCountAggregator,
180167
{
181168
"column_names": ["ExactMatch_result"],
182-
"first_groupby": ["data_repeat_id", "High-level domain_copy"],
169+
"first_groupby": ["data_repeat_id", "High-level domain"],
183170
"second_groupby": "High-level domain",
184171
"filename_base": "ExactMatch_GroupBy_High-level_domain_AllRuns",
185172
"normalize": True
186173
}),
187174
# three similar reports for average completion usage
188-
AggregatorConfig(BiLevelAverageAggregator,
175+
AggregatorConfig(BiLevelAggregator,
189176
{
190177
"column_names": ["usage_completion"],
191178
"first_groupby": "data_point_id",
192179
"filename_base": "UsageCompletion_AllRuns",
180+
"agg_fn": "mean"
193181
}),
194-
AggregatorConfig(BiLevelAverageAggregator,
182+
AggregatorConfig(BiLevelAggregator,
195183
{
196184
"column_names": ["usage_completion"],
197-
"first_groupby": ["data_point_id", "Subdomain_copy"],
185+
"first_groupby": ["data_point_id", "Subdomain"],
198186
"second_groupby": "Subdomain",
199187
"filename_base": "UsageCompletion_GroupBy_Subdomain_AllRuns",
188+
"agg_fn": "mean"
200189
}),
201-
AggregatorConfig(BiLevelAverageAggregator,
190+
AggregatorConfig(BiLevelAggregator,
202191
{
203192
"column_names": ["usage_completion"],
204-
"first_groupby": ["data_point_id", "High-level domain_copy"],
193+
"first_groupby": ["data_point_id", "High-level domain"],
205194
"second_groupby": "High-level domain",
206-
"filename_base": "UsageCompletion_GroupBy_High-level_domain_AllRuns",
195+
"filename_base": "UsageCompletion_GroupBy_High-level_domain_AllRuns",
196+
"agg_fn": "mean"
207197
}),
208198
],
209199
output_dir=os.path.join(self.log_dir, "eval_report"),
@@ -241,60 +231,56 @@ def configure_pipeline(
241231
DataReader,
242232
{
243233
"path": os.path.join(self.posteval_data_post_processing_comp.output_dir, "transformed_data.jsonl"),
244-
"format": ".jsonl",
245-
"transform": SequenceTransform(
246-
[
247-
CopyColumn(
248-
column_name_src="data_point_id",
249-
column_name_dst="data_point_id_copy",
250-
),
251-
]
252-
)
234+
"format": ".jsonl"
253235
},
254236
),
255237
aggregator_configs=[
256238
# the first three reports aggregate results by data_point_id and take the best out of N
257239
AggregatorConfig(
258-
BiLevelMaxAggregator,
240+
BiLevelAggregator,
259241
{
260242
"column_names": [
261243
"ExactMatch_result_numeric"
262244
],
263245
"first_groupby": "data_point_id",
264246
"filename_base": "ExactMatch_BestOfN",
247+
"agg_fn": "max"
265248
},
266249
),
267250
AggregatorConfig(
268-
BiLevelMaxAggregator,
251+
BiLevelAggregator,
269252
{
270253
"column_names": [
271254
"ExactMatch_result_numeric"
272255
],
273-
"first_groupby": "data_point_id_copy",
256+
"first_groupby": "data_point_id",
274257
"second_groupby": "Subdomain",
275258
"filename_base": "ExactMatch_BestOfN_GroupBy_Subdomain",
259+
"agg_fn": "max"
276260
},
277261
),
278262
AggregatorConfig(
279-
BiLevelMaxAggregator,
263+
BiLevelAggregator,
280264
{
281265
"column_names": [
282266
"ExactMatch_result_numeric"
283267
],
284-
"first_groupby": "data_point_id_copy",
268+
"first_groupby": "data_point_id",
285269
"second_groupby": "High-level domain",
286270
"filename_base": "ExactMatch_BestOfN_GroupBy_High-level_domain",
271+
"agg_fn": "max"
287272
},
288273
),
289274
# aggregates results by data_point_id and takes the sum of usage for completion tokens
290275
AggregatorConfig(
291-
BiLevelSumAggregator,
276+
BiLevelAggregator,
292277
{
293278
"column_names": [
294279
"usage_completion"
295280
],
296281
"first_groupby": "data_point_id",
297282
"filename_base": "UsageCompletion_BestOfN",
283+
"agg_fn": "sum"
298284
},
299285
),
300286
],

0 commit comments

Comments
 (0)