Skip to content

Commit 0619b6c

Browse files
committed
Some fixes related to evaluation
1 parent e5525d4 commit 0619b6c

5 files changed

Lines changed: 93 additions & 29 deletions

File tree

keys_values/finetune/longcontext_eval_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ def eval_for_setup_internal(
700700
result_path = eval_metrics_path
701701
else:
702702
eval_fname = eval_metrics_path.stem
703-
suffix = "_".split(eval_fname)[-1]
703+
suffix = eval_fname.split("_")[-1]
704704
result_path = (
705705
eval_metrics_path.parent
706706
/ GENERATED_SAMPLES_FILENAME.format(suffix)

keys_values/scripts/cleanup_gen_samples.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ def main(control_file: Path):
3434

3535

3636
if __name__ == "__main__":
37-
# dataset_size = "64k"
38-
dataset_size = "128k"
37+
dataset_size = "64k"
38+
# dataset_size = "128k"
39+
# control_file = (
40+
# Path.home() / "sync" / "keys_values" / f"eval_inst1_{dataset_size}_h2o.yaml"
41+
# )
3942
control_file = (
40-
Path.home() / "sync" / "keys_values" / f"eval_inst1_{dataset_size}.yaml"
43+
Path.home() / "git" / "keys_values" / f"eval_inst2_3_{dataset_size}_h2o.yaml"
4144
)
42-
# control_file = Path.home() / "git" / "keys_values" / f"eval_inst2_3_{dataset_size}.yaml"
4345
main(control_file)

keys_values/scripts/collect_eval_results.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def main(
5959
print(f"Total number of records: {len(all_data)}")
6060
if all_data:
6161
combined_path = out_dir / EVAL_METRICS_ALL_FILENAME
62+
if combined_path.exists():
63+
combined_path.unlink()
6264
with open(combined_path, "w") as fp:
6365
writer = csv.writer(fp, delimiter=",")
6466
writer.writerow(column_names)
@@ -88,6 +90,8 @@ def main(
8890
"qh2onorm_4gpu_cs2048_lr5",
8991
"lr_4gpu_cs1024_lr5",
9092
"h2o_4gpu_cs1024_lr5",
93+
"slr_4gpu_cs1024_lr5",
94+
"h2onorm_4gpu_cs1024_lr5",
9195
]
9296
model_type = "lora"
9397
if mode == "collect":

keys_values/scripts/collect_gen_samples.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ def main(
5656
print(f"Total number of records: {num_total}")
5757
if all_data:
5858
combined_path = out_dir / GENERATED_SAMPLES_ALL_FILENAME
59+
if combined_path.exists():
60+
combined_path.unlink()
5961
with open(combined_path, "w") as fp:
6062
yaml.safe_dump(all_data, fp)
6163

@@ -82,6 +84,8 @@ def main(
8284
"qh2onorm_4gpu_cs2048_lr5",
8385
"lr_4gpu_cs1024_lr5",
8486
"h2o_4gpu_cs1024_lr5",
87+
"slr_4gpu_cs1024_lr5",
88+
"h2onorm_4gpu_cs1024_lr5",
8589
]
8690
model_type = "lora"
8791
if mode == "collect":

keys_values/scripts/create_result_table.py

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,28 @@ def _sort_entries(entries):
2929
return non_fin + [(st, v) for st, v in entries if st == "fin"]
3030

3131

32-
def main(datasets, cases, result_path):
32+
# We ran evaluations for more than the task for which evaluation loss was
33+
# lowest. With this predicate, we filter for the winning tasks only.
34+
def _filter_dataset_case(
35+
dataset: str,
36+
case: str,
37+
task: str,
38+
) -> bool:
39+
if dataset.endswith("_128k"):
40+
# Not yet implemented!!
41+
return True
42+
# Filter out error in results:
43+
if task == "380" and case.startswith("lr_") and dataset.startswith("helmet_trivia"):
44+
return False
45+
if task == "fin":
46+
# Only those for which "fin" is the only result
47+
return dataset.startswith("helmet_pop") and (
48+
case.startswith("slr_") or case.startswith("h2onorm_")
49+
)
50+
return task != "010"
51+
52+
53+
def main(datasets, cases, result_path, final_table: bool):
3354
base_path = result_path.parent
3455
col_labels = [
3556
d.removeprefix("helmet_").rsplit("_", 1)[0].replace("_", r"\_")
@@ -48,23 +69,49 @@ def main(datasets, cases, result_path):
4869
else:
4970
df = pd.read_csv(csv_path)
5071
avg = df.groupby("task")["sub_exact_match"].mean()
51-
row.append(_sort_entries([(_short_task(t), v) for t, v in avg.items()]))
72+
row.append(
73+
_sort_entries(
74+
[
75+
(_short_task(t), v)
76+
for t, v in avg.items()
77+
if not final_table
78+
or _filter_dataset_case(dataset, case_key, _short_task(t))
79+
]
80+
)
81+
)
5282
table.append(row)
5383

54-
# Each dataset gets 2 sub-columns (l for task, r for value) for cross-cell alignment.
84+
# - final_table == False:
85+
# Each dataset gets 2 sub-columns (l for task, r for value) for cross-cell alignment.
86+
# - final_table == True:
87+
# Each dataset column features a single entry (r for value)
5588
N = len(datasets)
56-
col_spec = "l" + "lr" * N
57-
tex_lines = [
58-
r"\begin{tabular}{" + col_spec + "}",
59-
r"\noalign{\smallskip}\hline\noalign{\smallskip}",
60-
" & ".join([""] + [r"\multicolumn{2}{c}{" + lbl + "}" for lbl in col_labels])
61-
+ r" \\",
62-
r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}",
63-
]
64-
for i, case_label in enumerate(case_labels):
65-
row_entries = table[i]
89+
if final_table:
90+
col_spec = "l" + "r" * N
91+
tex_lines = [
92+
r"\begin{tabular}{" + col_spec + "}",
93+
r"\noalign{\smallskip}\hline\noalign{\smallskip}",
94+
" & ".join([""] + col_labels) + r" \\",
95+
r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}",
96+
]
97+
else:
98+
col_spec = "l" + "lr" * N
99+
tex_lines = [
100+
r"\begin{tabular}{" + col_spec + "}",
101+
r"\noalign{\smallskip}\hline\noalign{\smallskip}",
102+
" & ".join(
103+
[""] + [r"\multicolumn{2}{c}{" + lbl + "}" for lbl in col_labels]
104+
)
105+
+ r" \\",
106+
r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}",
107+
]
108+
for case_label, row_entries in zip(case_labels, table):
66109
max_rows = max((len(e) for e in row_entries), default=0)
67110
max_rows = max(max_rows, 1)
111+
if final_table and max_rows > 1:
112+
print(
113+
f"{case_label}: max_rows = {max_rows} > 1, must not happen for final_table=True"
114+
)
68115
for k in range(max_rows):
69116
if k == 0 and max_rows > 1:
70117
label_cell = r"\multirow{" + str(max_rows) + r"}{*}{" + case_label + "}"
@@ -76,23 +123,28 @@ def main(datasets, cases, result_path):
76123
for entries in row_entries:
77124
if k < len(entries):
78125
st, v = entries[k]
79-
cells.append(r"{\small " + st + r":}")
126+
if not final_table:
127+
cells.append(r"{\small " + st + r":}")
80128
cells.append(r"{\small\!" + f"{v * 100:.2f}" + "}")
81129
else:
82-
cells.append("")
130+
if not final_table:
131+
cells.append("")
83132
cells.append("")
84133
tex_lines.append(" & ".join(cells) + r" \\")
85134
tex_lines.append(r"\noalign{\smallskip}\hline\noalign{\smallskip}")
86135
tex_lines.append(r"\end{tabular}")
87136

137+
if result_path.exists():
138+
result_path.unlink()
88139
result_path.write_text("\n".join(tex_lines) + "\n")
89140

90141

142+
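# TODO(review): If `final_table = True`, do not print the task ID, just the metric value. The cell-building loop in `main` already skips the task-ID cell when `final_table` is set -- confirm and remove this note.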
91143
if __name__ == "__main__":
92144
base_path = Path.home() / "out/finetune/neurips_exp/lora/qwen3_4b"
93145

94-
# dataset_size = "64k"
95-
dataset_size = "128k"
146+
dataset_size = "64k"
147+
# dataset_size = "128k"
96148
datasets = [
97149
f"helmet_nq_{dataset_size}",
98150
f"helmet_trivia_qa_{dataset_size}",
@@ -101,14 +153,16 @@ def main(datasets, cases, result_path):
101153
]
102154
cases = [
103155
("lr_4gpu_cs2048_lr5", "lr_2048"),
104-
("h2o_4gpu_cs2048_lr5", "h2o_2048"),
105156
("slr_4gpu_cs2048_lr5", "slr_2048"),
106-
# ("qh2o_4gpu_cs2048_lr5", "qh2o_2048"),
107-
# ("h2onorm_4gpu_cs2048_lr5", "h2onorm_2048"),
108-
# ("qh2onorm_4gpu_cs2048_lr5", "qh2onorm_2048"),
109-
# ("lr_4gpu_cs1024_lr5", "lr_1024"),
110-
# ("h2o_4gpu_cs1024_lr5", "h2o_1024"),
157+
("h2o_4gpu_cs2048_lr5", "h2o_2048"),
158+
("qh2o_4gpu_cs2048_lr5", "qh2o_2048"),
159+
("h2onorm_4gpu_cs2048_lr5", "h2onorm_2048"),
160+
("qh2onorm_4gpu_cs2048_lr5", "qh2onorm_2048"),
161+
("lr_4gpu_cs1024_lr5", "lr_1024"),
162+
("h2o_4gpu_cs1024_lr5", "h2o_1024"),
111163
]
112164
result_path = base_path / f"results_{dataset_size}.tex"
165+
# final_table = False
166+
final_table = True
113167

114-
main(datasets, cases, result_path)
168+
main(datasets, cases, result_path, final_table)

0 commit comments

Comments
 (0)