|
23 | 23 | from eureka_ml_insights.metrics import (
|
24 | 24 | CountAggregator,
|
25 | 25 | ExactMatch,
|
26 |
| - BiLevelMaxAggregator, |
27 |
| - BiLevelCountAggregator, |
28 |
| - BiLevelAverageAggregator, |
29 |
| - BiLevelSumAggregator |
| 26 | + BiLevelAggregator, |
| 27 | + BiLevelCountAggregator |
30 | 28 | )
|
31 | 29 |
|
32 | 30 | from eureka_ml_insights.configs import(
|
@@ -113,17 +111,6 @@ def configure_pipeline(
|
113 | 111 | column_name_src="model_output",
|
114 | 112 | column_name_dst="raw_model_output",
|
115 | 113 | ),
|
116 |
| - # these columns are currently copied so that they can be used in the bilevel aggregators |
117 |
| - # if they are not copied, then the bilevel aggregator used in sub category reports (e.g. subdomain) |
118 |
| - # will find the column name ambiguous |
119 |
| - CopyColumn( |
120 |
| - column_name_src="Subdomain", |
121 |
| - column_name_dst="Subdomain_copy", |
122 |
| - ), |
123 |
| - CopyColumn( |
124 |
| - column_name_src="High-level domain", |
125 |
| - column_name_dst="High-level domain_copy", |
126 |
| - ), |
127 | 114 | RegexTransform(
|
128 | 115 | columns="model_output",
|
129 | 116 | prompt_pattern=r"Final Answer: (\w)(?=\s|\W|$)",
|
@@ -171,39 +158,42 @@ def configure_pipeline(
|
171 | 158 | AggregatorConfig(BiLevelCountAggregator,
|
172 | 159 | {
|
173 | 160 | "column_names": ["ExactMatch_result"],
|
174 |
| - "first_groupby": ["data_repeat_id", "Subdomain_copy"], |
| 161 | + "first_groupby": ["data_repeat_id", "Subdomain"], |
175 | 162 | "second_groupby": "Subdomain",
|
176 | 163 | "filename_base": "ExactMatch_GroupBy_Subdomain_AllRuns",
|
177 | 164 | "normalize": True
|
178 | 165 | }),
|
179 | 166 | AggregatorConfig(BiLevelCountAggregator,
|
180 | 167 | {
|
181 | 168 | "column_names": ["ExactMatch_result"],
|
182 |
| - "first_groupby": ["data_repeat_id", "High-level domain_copy"], |
| 169 | + "first_groupby": ["data_repeat_id", "High-level domain"], |
183 | 170 | "second_groupby": "High-level domain",
|
184 | 171 | "filename_base": "ExactMatch_GroupBy_High-level_domain_AllRuns",
|
185 | 172 | "normalize": True
|
186 | 173 | }),
|
187 | 174 | # three similar reports for average completion usage
|
188 |
| - AggregatorConfig(BiLevelAverageAggregator, |
| 175 | + AggregatorConfig(BiLevelAggregator, |
189 | 176 | {
|
190 | 177 | "column_names": ["usage_completion"],
|
191 | 178 | "first_groupby": "data_point_id",
|
192 | 179 | "filename_base": "UsageCompletion_AllRuns",
|
| 180 | + "agg_fn": "mean" |
193 | 181 | }),
|
194 |
| - AggregatorConfig(BiLevelAverageAggregator, |
| 182 | + AggregatorConfig(BiLevelAggregator, |
195 | 183 | {
|
196 | 184 | "column_names": ["usage_completion"],
|
197 |
| - "first_groupby": ["data_point_id", "Subdomain_copy"], |
| 185 | + "first_groupby": ["data_point_id", "Subdomain"], |
198 | 186 | "second_groupby": "Subdomain",
|
199 | 187 | "filename_base": "UsageCompletion_GroupBy_Subdomain_AllRuns",
|
| 188 | + "agg_fn": "mean" |
200 | 189 | }),
|
201 |
| - AggregatorConfig(BiLevelAverageAggregator, |
| 190 | + AggregatorConfig(BiLevelAggregator, |
202 | 191 | {
|
203 | 192 | "column_names": ["usage_completion"],
|
204 |
| - "first_groupby": ["data_point_id", "High-level domain_copy"], |
| 193 | + "first_groupby": ["data_point_id", "High-level domain"], |
205 | 194 | "second_groupby": "High-level domain",
|
206 |
| - "filename_base": "UsageCompletion_GroupBy_High-level_domain_AllRuns", |
| 195 | + "filename_base": "UsageCompletion_GroupBy_High-level_domain_AllRuns", |
| 196 | + "agg_fn": "mean" |
207 | 197 | }),
|
208 | 198 | ],
|
209 | 199 | output_dir=os.path.join(self.log_dir, "eval_report"),
|
@@ -241,60 +231,56 @@ def configure_pipeline(
|
241 | 231 | DataReader,
|
242 | 232 | {
|
243 | 233 | "path": os.path.join(self.posteval_data_post_processing_comp.output_dir, "transformed_data.jsonl"),
|
244 |
| - "format": ".jsonl", |
245 |
| - "transform": SequenceTransform( |
246 |
| - [ |
247 |
| - CopyColumn( |
248 |
| - column_name_src="data_point_id", |
249 |
| - column_name_dst="data_point_id_copy", |
250 |
| - ), |
251 |
| - ] |
252 |
| - ) |
| 234 | + "format": ".jsonl" |
253 | 235 | },
|
254 | 236 | ),
|
255 | 237 | aggregator_configs=[
|
256 | 238 | # the first three reports aggregate results by data_point_id and take the best out of N
|
257 | 239 | AggregatorConfig(
|
258 |
| - BiLevelMaxAggregator, |
| 240 | + BiLevelAggregator, |
259 | 241 | {
|
260 | 242 | "column_names": [
|
261 | 243 | "ExactMatch_result_numeric"
|
262 | 244 | ],
|
263 | 245 | "first_groupby": "data_point_id",
|
264 | 246 | "filename_base": "ExactMatch_BestOfN",
|
| 247 | + "agg_fn": "max" |
265 | 248 | },
|
266 | 249 | ),
|
267 | 250 | AggregatorConfig(
|
268 |
| - BiLevelMaxAggregator, |
| 251 | + BiLevelAggregator, |
269 | 252 | {
|
270 | 253 | "column_names": [
|
271 | 254 | "ExactMatch_result_numeric"
|
272 | 255 | ],
|
273 |
| - "first_groupby": "data_point_id_copy", |
| 256 | + "first_groupby": "data_point_id", |
274 | 257 | "second_groupby": "Subdomain",
|
275 | 258 | "filename_base": "ExactMatch_BestOfN_GroupBy_Subdomain",
|
| 259 | + "agg_fn": "max" |
276 | 260 | },
|
277 | 261 | ),
|
278 | 262 | AggregatorConfig(
|
279 |
| - BiLevelMaxAggregator, |
| 263 | + BiLevelAggregator, |
280 | 264 | {
|
281 | 265 | "column_names": [
|
282 | 266 | "ExactMatch_result_numeric"
|
283 | 267 | ],
|
284 |
| - "first_groupby": "data_point_id_copy", |
| 268 | + "first_groupby": "data_point_id", |
285 | 269 | "second_groupby": "High-level domain",
|
286 | 270 | "filename_base": "ExactMatch_BestOfN_GroupBy_High-level_domain",
|
| 271 | + "agg_fn": "max" |
287 | 272 | },
|
288 | 273 | ),
|
289 | 274 | # aggregates results by data_point_id and takes the sum of usage for completion tokens
|
290 | 275 | AggregatorConfig(
|
291 |
| - BiLevelSumAggregator, |
| 276 | + BiLevelAggregator, |
292 | 277 | {
|
293 | 278 | "column_names": [
|
294 | 279 | "usage_completion"
|
295 | 280 | ],
|
296 | 281 | "first_groupby": "data_point_id",
|
297 | 282 | "filename_base": "UsageCompletion_BestOfN",
|
| 283 | + "agg_fn": "sum" |
298 | 284 | },
|
299 | 285 | ),
|
300 | 286 | ],
|
|
0 commit comments