Skip to content

Commit 4123c89

Browse files
author
Safoora Yousefi
committed
col name simplification
1 parent a738891 commit 4123c89

File tree

1 file changed

+24
-43
lines changed

1 file changed

+24
-43
lines changed

eureka_ml_insights/user_configs/aime_seq.py

+24-43
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from eureka_ml_insights.data_utils import (
2020
AddColumnAndData,
2121
ColumnRename,
22+
CopyColumn,
2223
DataReader,
2324
RunPythonTransform,
2425
SamplerTransform,
@@ -33,40 +34,26 @@
3334

3435
from .aime import AIME_PIPELINE
3536

36-
DEFAULT_N_ITER = 2
37-
38-
39-
resume_from_dict = {
40-
1: "/home/sayouse/git/eureka-ml-insights/logs/AIME_SEQ_PIPELINE/2025-03-04-21-07-09.687511/student_inference_result_1/inference_result.jsonl",
41-
2: None,
42-
}
37+
DEFAULT_N_ITER = 3
38+
RESULT_COLS = [
39+
"attempt_id",
40+
"model_output",
41+
"uid",
42+
"prompt",
43+
"ground_truth",
44+
"Year",
45+
"ID",
46+
"student_extracted_answer",
47+
"verification_result"
48+
]
49+
resume_from_dict = {}
4350

4451

4552
class AIME_SEQ_PIPELINE(AIME_PIPELINE):
4653
"""This class specifies the config for running AIME benchmark on any model"""
4754

48-
def get_result_columns(self, i: int) -> list[str]:
49-
"""Get the desired result columns to be saved for the given iteration
50-
Args:
51-
i (int): The iteration number
52-
Returns:
53-
list[str]: The list of columns to be saved
54-
"""
55-
verification_cols_so_far = [f"verification_result_{j}" for j in range(1, i)]
56-
extracted_ans_cols_so_far = [f"student_extracted_answer_{j}" for j in range(1, i)]
57-
return (
58-
[
59-
"attempt_id",
60-
"model_output",
61-
"uid",
62-
"prompt",
63-
"ground_truth",
64-
"Year",
65-
"ID",
66-
]
67-
+ verification_cols_so_far
68-
+ extracted_ans_cols_so_far
69-
)
55+
56+
7057

7158
def configure_pipeline(
7259
self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
@@ -77,10 +64,6 @@ def configure_pipeline(
7764

7865
n_iter = kwargs.get("n_iter", DEFAULT_N_ITER)
7966

80-
self.data_processing_comp.data_reader_config.init_args["transform"].transforms.append(
81-
SamplerTransform(random_seed=40, sample_count=1)
82-
)
83-
8467
component_configs = [self.data_processing_comp]
8568
for i in range(1, n_iter + 1):
8669
# Student inference component, reads prompts from the last prompt processing component
@@ -114,9 +97,12 @@ def configure_pipeline(
11497
"transform": SequenceTransform(
11598
[
11699
# extract and verify the student answer
117-
AIMEExtractAnswer(f"model_output", f"student_extracted_answer_{i}"),
118-
MetricBasedVerifier(ExactMatch, f"student_extracted_answer_{i}"),
100+
AIMEExtractAnswer(f"model_output", f"student_extracted_answer"),
101+
MetricBasedVerifier(ExactMatch, f"student_extracted_answer"),
119102
AddColumnAndData("attempt_id", i),
103+
CopyColumn(
104+
column_name_src="model_output",
105+
column_name_dst=f"student_output")
120106
]
121107
),
122108
},
@@ -145,7 +131,7 @@ def configure_pipeline(
145131
"format": ".jsonl",
146132
},
147133
),
148-
output_data_columns=self.get_result_columns(i),
134+
output_data_columns=RESULT_COLS,
149135
output_dir=os.path.join(self.log_dir, f"last_inference_result_join_{i}"),
150136
)
151137
last_agg_dir = self.last_inference_result_join_comp.output_dir
@@ -176,12 +162,6 @@ def configure_pipeline(
176162
{
177163
"path": os.path.join(self.filtering_comp.output_dir, "transformed_data.jsonl"),
178164
"format": ".jsonl",
179-
"transform": ColumnRename(
180-
name_mapping={
181-
"verification_result": f"verification_result_{i}",
182-
"model_output": f"student_output_{i}",
183-
}
184-
),
185165
},
186166
),
187167
prompt_template_path=os.path.join(
@@ -228,8 +208,9 @@ def configure_pipeline(
228208

229209
# Pass the combined results from all iterations to the eval reporting component
230210
self.evalreporting_comp.data_reader_config.init_args["path"] = os.path.join(
231-
self.last_inference_result_join_comp.output_dir, "transformed_data.jsonl"
211+
last_agg_dir, "transformed_data.jsonl"
232212
)
213+
self.evalreporting_comp.metric_config.init_args["model_output_col"] = "student_extracted_answer"
233214

234215
component_configs.append(self.evalreporting_comp)
235216

0 commit comments

Comments
 (0)