@@ -29,7 +29,28 @@ def _sort_entries(entries):
2929 return non_fin + [(st , v ) for st , v in entries if st == "fin" ]
3030
3131
32- def main (datasets , cases , result_path ):
32+ # We ran evaluations for more than the task for which evaluation loss was
33+ # lowest. With this predicate, we filter for the winning tasks only.
34+ def _filter_dataset_case (
35+ dataset : str ,
36+ case : str ,
37+ task : str ,
38+ ) -> bool :
39+ if dataset .endswith ("_128k" ):
40+ # Not yet implemented!!
41+ return True
42+ # Filter out error in results:
43+ if task == "380" and case .startswith ("lr_" ) and dataset .startswith ("helmet_trivia" ):
44+ return False
45+ if task == "fin" :
46+ # Only those for which "fin" is the only result
47+ return dataset .startswith ("helmet_pop" ) and (
48+ case .startswith ("slr_" ) or case .startswith ("h2onorm_" )
49+ )
50+ return task != "010"
51+
52+
53+ def main (datasets , cases , result_path , final_table : bool ):
3354 base_path = result_path .parent
3455 col_labels = [
3556 d .removeprefix ("helmet_" ).rsplit ("_" , 1 )[0 ].replace ("_" , r"\_" )
@@ -48,23 +69,49 @@ def main(datasets, cases, result_path):
4869 else :
4970 df = pd .read_csv (csv_path )
5071 avg = df .groupby ("task" )["sub_exact_match" ].mean ()
51- row .append (_sort_entries ([(_short_task (t ), v ) for t , v in avg .items ()]))
72+ row .append (
73+ _sort_entries (
74+ [
75+ (_short_task (t ), v )
76+ for t , v in avg .items ()
77+ if not final_table
78+ or _filter_dataset_case (dataset , case_key , _short_task (t ))
79+ ]
80+ )
81+ )
5282 table .append (row )
5383
54- # Each dataset gets 2 sub-columns (l for task, r for value) for cross-cell alignment.
84+ # - final_table == False:
85+ # Each dataset gets 2 sub-columns (l for task, r for value) for cross-cell alignment.
86+ # - final_table == True:
87+ # Each dataset column features a single entry (r for value)
5588 N = len (datasets )
56- col_spec = "l" + "lr" * N
57- tex_lines = [
58- r"\begin{tabular}{" + col_spec + "}" ,
59- r"\noalign{\smallskip}\hline\noalign{\smallskip}" ,
60- " & " .join (["" ] + [r"\multicolumn{2}{c}{" + lbl + "}" for lbl in col_labels ])
61- + r" \\" ,
62- r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}" ,
63- ]
64- for i , case_label in enumerate (case_labels ):
65- row_entries = table [i ]
89+ if final_table :
90+ col_spec = "l" + "r" * N
91+ tex_lines = [
92+ r"\begin{tabular}{" + col_spec + "}" ,
93+ r"\noalign{\smallskip}\hline\noalign{\smallskip}" ,
94+ " & " .join (["" ] + col_labels ) + r" \\" ,
95+ r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}" ,
96+ ]
97+ else :
98+ col_spec = "l" + "lr" * N
99+ tex_lines = [
100+ r"\begin{tabular}{" + col_spec + "}" ,
101+ r"\noalign{\smallskip}\hline\noalign{\smallskip}" ,
102+ " & " .join (
103+ ["" ] + [r"\multicolumn{2}{c}{" + lbl + "}" for lbl in col_labels ]
104+ )
105+ + r" \\" ,
106+ r"\noalign{\smallskip}\hline\hline\noalign{\smallskip}" ,
107+ ]
108+ for case_label , row_entries in zip (case_labels , table ):
66109 max_rows = max ((len (e ) for e in row_entries ), default = 0 )
67110 max_rows = max (max_rows , 1 )
111+ if final_table and max_rows > 1 :
112+ print (
113+ f"{ case_label } : max_rows = { max_rows } > 1, must not happen for final_table=True"
114+ )
68115 for k in range (max_rows ):
69116 if k == 0 and max_rows > 1 :
70117 label_cell = r"\multirow{" + str (max_rows ) + r"}{*}{" + case_label + "}"
@@ -76,23 +123,28 @@ def main(datasets, cases, result_path):
76123 for entries in row_entries :
77124 if k < len (entries ):
78125 st , v = entries [k ]
79- cells .append (r"{\small " + st + r":}" )
126+ if not final_table :
127+ cells .append (r"{\small " + st + r":}" )
80128 cells .append (r"{\small\!" + f"{ v * 100 :.2f} " + "}" )
81129 else :
82- cells .append ("" )
130+ if not final_table :
131+ cells .append ("" )
83132 cells .append ("" )
84133 tex_lines .append (" & " .join (cells ) + r" \\" )
85134 tex_lines .append (r"\noalign{\smallskip}\hline\noalign{\smallskip}" )
86135 tex_lines .append (r"\end{tabular}" )
87136
137+ if result_path .exists ():
138+ result_path .unlink ()
88139 result_path .write_text ("\n " .join (tex_lines ) + "\n " )
89140
90141
142+ # TODO: If `final_table = True`, do not print the task ID, just the metric value
91143if __name__ == "__main__" :
92144 base_path = Path .home () / "out/finetune/neurips_exp/lora/qwen3_4b"
93145
94- # dataset_size = "64k"
95- dataset_size = "128k"
146+ dataset_size = "64k"
147+ # dataset_size = "128k"
96148 datasets = [
97149 f"helmet_nq_{ dataset_size } " ,
98150 f"helmet_trivia_qa_{ dataset_size } " ,
@@ -101,14 +153,16 @@ def main(datasets, cases, result_path):
101153 ]
102154 cases = [
103155 ("lr_4gpu_cs2048_lr5" , "lr_2048" ),
104- ("h2o_4gpu_cs2048_lr5" , "h2o_2048" ),
105156 ("slr_4gpu_cs2048_lr5" , "slr_2048" ),
106- # ("qh2o_4gpu_cs2048_lr5", "qh2o_2048"),
107- # ("h2onorm_4gpu_cs2048_lr5", "h2onorm_2048"),
108- # ("qh2onorm_4gpu_cs2048_lr5", "qh2onorm_2048"),
109- # ("lr_4gpu_cs1024_lr5", "lr_1024"),
110- # ("h2o_4gpu_cs1024_lr5", "h2o_1024"),
157+ ("h2o_4gpu_cs2048_lr5" , "h2o_2048" ),
158+ ("qh2o_4gpu_cs2048_lr5" , "qh2o_2048" ),
159+ ("h2onorm_4gpu_cs2048_lr5" , "h2onorm_2048" ),
160+ ("qh2onorm_4gpu_cs2048_lr5" , "qh2onorm_2048" ),
161+ ("lr_4gpu_cs1024_lr5" , "lr_1024" ),
162+ ("h2o_4gpu_cs1024_lr5" , "h2o_1024" ),
111163 ]
112164 result_path = base_path / f"results_{ dataset_size } .tex"
165+ # final_table = False
166+ final_table = True
113167
114- main (datasets , cases , result_path )
168+ main (datasets , cases , result_path , final_table )
0 commit comments