1414from filelock import FileLock , Timeout
1515from pathlib import Path
1616import re
17- from typing import List , Dict , Any , Optional , Iterable , Tuple
17+ from typing import List , Dict , Any , Optional , Iterable , Tuple , Literal
1818
1919from keys_values .data .base import (
2020 LIT_MODEL_FNAME ,
2525
2626EVAL_METRICS_FNAME = "eval/eval_metrics_{}.csv"
2727
28- EVAL_METRICS_GLOB = EVAL_METRICS_FNAME .replace ("{}" , "*" )
29-
3028REGEX_TASKNAME = re .compile (r"step-[0-9]{6}|final" )
3129
3230_REQUIRED_FILES = [
@@ -57,12 +55,17 @@ def __init__(
5755 model_type : str ,
5856 tasks : Optional [List [str ]] = None ,
5957 collect_results : bool = False ,
58+ eval_metrics_filename : Optional [str ] = None ,
6059 ):
6160 if isinstance (out_dir , str ):
6261 out_dir = Path (out_dir )
6362 self ._out_dir = out_dir
6463 self .model_type = model_type
6564 self ._tasks = tasks .copy () if tasks is not None else None
65+ if eval_metrics_filename is None :
66+ eval_metrics_filename = EVAL_METRICS_FNAME
67+ self ._eval_metrics_filename = eval_metrics_filename
68+ self ._eval_metrics_glob = eval_metrics_filename .replace ("{}" , "*" )
6669 self ._init_task_names (collect_results )
6770
6871 def _init_task_names (self , collect_results : bool ):
@@ -100,9 +103,8 @@ def _init_task_names(self, collect_results: bool):
100103 elif self ._num_result_files (path ) == 0 :
101104 raise ValueError (f"{ path } contains no evaluation result files" )
102105
def _num_result_files(self, path: Path) -> int:
    """Count the evaluation result files found under ``path``.

    Args:
        path: Directory that is searched with this instance's result-file
            glob pattern (``self._eval_metrics_glob``).

    Returns:
        Number of paths matching the pattern; 0 if there are none.
    """
    # Count the matches lazily instead of materialising them in a list
    # that is thrown away right after taking its length.
    return sum(1 for _ in path.glob(self._eval_metrics_glob))
106108
107109 @property
108110 def tasks (self ) -> List [str ]:
@@ -127,37 +129,44 @@ def check_complete(task_path: Path, model_type: str) -> bool:
127129
def eval_result_files(
    self,
    mode: Literal["non-lock", "lock", "all"] = "non-lock",
) -> Iterable[Tuple[str, List[Path]]]:
    """
    Iterate over the evaluation result files of every task.

    Args:
        mode: For "non-lock", we return complete files (not locks). For
            "lock", we return incomplete lock files. For "all", we
            return all files.
    Yields:
        `(task_name, result_file_paths)`, where `result_file_paths`
        is list of paths of evaluation result files for this task name.
        This list is filtered depending on `mode`. Tasks for which the
        filtered list is empty are skipped.
    Raises:
        ValueError: If `mode` is not one of the supported choices. This
            is raised when the method is called, not on first iteration.

    """
    choices = ("non-lock", "lock", "all")
    # Validate eagerly: if this function itself contained the `yield`,
    # the check would only run once the caller starts iterating.
    if mode not in choices:
        raise ValueError(f"Invalid mode = {mode}, must be in {choices}")

    def _generate():
        for task_name in self._tasks:
            result_file_paths = self._filter_incomplete_files(
                (self._out_dir / task_name).glob(self._eval_metrics_glob),
                mode=mode,
            )
            if result_file_paths:
                yield task_name, result_file_paths

    return _generate()
151155
@staticmethod
def _filter_incomplete_files(
    paths: Iterable[Path],
    mode: Literal["non-lock", "lock", "all"],
) -> List[Path]:
    """Select result files depending on whether they are lock files.

    A file counts as a lock file if its first line starts with
    `FILE_LOCK_TEXT`.

    Args:
        paths: Candidate result file paths.
        mode: "non-lock" keeps files which are not lock files, "lock"
            keeps only lock files, "all" keeps every path.

    Returns:
        List of the selected paths, in the order they were provided.
    """
    if mode == "all":
        # Nothing to filter, so there is no need to open any file.
        return list(paths)
    return_incompletes = mode == "lock"
    result = []
    for path in paths:
        # Only the first line is needed; close the file before deciding.
        with path.open("r") as fp:
            is_lock = fp.readline().startswith(FILE_LOCK_TEXT)
        if is_lock == return_incompletes:
            result.append(path)
    return result
163172
@@ -172,11 +181,19 @@ class EvaluationWithTasksHelper:
172181 dataloader we use.
173182 """
174183
def __init__(
    self,
    out_dir: Path,
    tag: Optional[str] = None,
    eval_metrics_filename: Optional[str] = None,
):
    """
    Args:
        out_dir: Base directory below which evaluation metrics files are
            placed. A `str` is accepted as well and converted to `Path`.
        tag: Prefix of the suffix inserted into the metrics filename.
            Defaults to the empty string.
        eval_metrics_filename: Filename template containing a `{}`
            placeholder which is filled with the suffix. Defaults to
            `EVAL_METRICS_FNAME`.
    """
    # Paths are later built with the `/` operator, which needs a `Path`
    # on the left-hand side.
    if isinstance(out_dir, str):
        out_dir = Path(out_dir)
    self._out_dir = out_dir
    if tag is None:
        tag = ""
    self._tag = tag
    if eval_metrics_filename is None:
        eval_metrics_filename = EVAL_METRICS_FNAME
    self._eval_metrics_filename = eval_metrics_filename
180197
181198 def evaluation_metrics_path (self , batch : Dict [str , Any ]) -> Path :
182199 """
@@ -197,7 +214,7 @@ def evaluation_metrics_path(self, batch: Dict[str, Any]) -> Path:
197214 f"batch[{ TASK_NAME } ] = { task } ."
198215 )
199216 suffix = self ._tag + str (orig_idxs [0 ])
200- fname = EVAL_METRICS_FNAME .format (suffix )
217+ fname = self . _eval_metrics_filename .format (suffix )
201218 return self ._out_dir / task / fname
202219
203220 def get_lock (self , batch : Dict [str , Any ]) -> Optional [Path ]:
0 commit comments