Skip to content

Commit 688141b

Browse files
committed
Evaluation scripts can also write generated samples to files
1 parent cbefbbf commit 688141b

9 files changed

Lines changed: 413 additions & 51 deletions

File tree

keys_values/evaluation/evaluator.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Optional, List, Union
14+
from typing import Any, Dict, Optional, List, Union, Tuple
1515

1616
import torch
1717

@@ -129,7 +129,8 @@ def __call__(
129129
model: LongContextInferenceModel,
130130
prompts: torch.Tensor,
131131
targets: List[TargetType],
132-
) -> Dict[str, torch.Tensor]:
132+
return_samples: bool = False,
133+
) -> Tuple[Dict[str, torch.Tensor], Optional[List[str]]]:
133134
"""
134135
Computes metric values for data case `(prompts, targets)`. The
135136
metrics to be computed are in `metrics`.
@@ -141,11 +142,15 @@ def __call__(
141142
targets: List of targets of length `batch_size`. Each entry is a
142143
string or list of strings. Some metrics allow for lists of
143144
strings, others require a single string
145+
return_samples: If `True`, we also return a list of generated
146+
sequences (of length `batch_size`)
144147
145148
Returns:
146149
Dictionary with entries `{name: values}`, where `name in self.metrics`
147150
and `values.shape = (batch_size,)`, the metric values for each
148151
entry in the batch.
152+
If `return_samples == True`, we also return a list of generated
153+
sequences as second entry of the returned tuple. Otherwise, this second entry is `None`.
149154
150155
"""
151156
assert prompts.ndim == 2
@@ -186,4 +191,4 @@ def __call__(
186191
device=prompts.device,
187192
)
188193
for metric in self.metrics
189-
}
194+
}, (outputs if return_samples else None)

keys_values/evaluation/tasks.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from filelock import FileLock, Timeout
1515
from pathlib import Path
1616
import re
17-
from typing import List, Dict, Any, Optional, Iterable, Tuple
17+
from typing import List, Dict, Any, Optional, Iterable, Tuple, Literal
1818

1919
from keys_values.data.base import (
2020
LIT_MODEL_FNAME,
@@ -25,8 +25,6 @@
2525

2626
EVAL_METRICS_FNAME = "eval/eval_metrics_{}.csv"
2727

28-
EVAL_METRICS_GLOB = EVAL_METRICS_FNAME.replace("{}", "*")
29-
3028
REGEX_TASKNAME = re.compile(r"step-[0-9]{6}|final")
3129

3230
_REQUIRED_FILES = [
@@ -57,12 +55,17 @@ def __init__(
5755
model_type: str,
5856
tasks: Optional[List[str]] = None,
5957
collect_results: bool = False,
58+
eval_metrics_filename: Optional[str] = None,
6059
):
6160
if isinstance(out_dir, str):
6261
out_dir = Path(out_dir)
6362
self._out_dir = out_dir
6463
self.model_type = model_type
6564
self._tasks = tasks.copy() if tasks is not None else None
65+
if eval_metrics_filename is None:
66+
eval_metrics_filename = EVAL_METRICS_FNAME
67+
self._eval_metrics_filename = eval_metrics_filename
68+
self._eval_metrics_glob = eval_metrics_filename.replace("{}", "*")
6669
self._init_task_names(collect_results)
6770

6871
def _init_task_names(self, collect_results: bool):
@@ -100,9 +103,8 @@ def _init_task_names(self, collect_results: bool):
100103
elif self._num_result_files(path) == 0:
101104
raise ValueError(f"{path} contains no evaluation result files")
102105

103-
@staticmethod
104-
def _num_result_files(path: Path) -> int:
105-
return len(list(path.glob(EVAL_METRICS_GLOB)))
106+
def _num_result_files(self, path: Path) -> int:
107+
return len(list(path.glob(self._eval_metrics_glob)))
106108

107109
@property
108110
def tasks(self) -> List[str]:
@@ -127,37 +129,44 @@ def check_complete(task_path: Path, model_type: str) -> bool:
127129

128130
def eval_result_files(
129131
self,
130-
return_incompletes: bool = False,
132+
mode: Literal["non-lock", "lock", "all"] = "non-lock",
131133
) -> Iterable[Tuple[str, List[Path]]]:
132134
"""
133135
Args:
134-
return_incompletes: If `True`, we return the complete lock files.
135-
Defaults to `False`, so lock files are filtered out.
136+
mode: For "non-lock", we return complete files (not locks). For
137+
"lock", we return incomplete lock files. For "all", we
138+
return all files.
136139
Yields:
137140
`(task_name, result_file_paths)`, where `result_file_paths`
138141
is list of paths of evaluation result files for this task name.
139-
These files are filtered to not contain incomplete lock files.
140-
But if `return_incompletes == True`, only incomplete files are
141-
returned.
142+
This list is filtered depending on `mode`.
142143
143144
"""
145+
choices = ("non-lock", "lock", "all")
146+
if mode not in choices:
147+
raise ValueError(f"Invalid mode = {mode}, must be in {choices}")
144148
for task_name in self._tasks:
145149
result_file_paths = self._filter_incomplete_files(
146-
(self._out_dir / task_name).glob(EVAL_METRICS_GLOB),
147-
return_incompletes=return_incompletes,
150+
(self._out_dir / task_name).glob(self._eval_metrics_glob),
151+
mode=mode,
148152
)
149153
if result_file_paths:
150154
yield task_name, result_file_paths
151155

152156
@staticmethod
153157
def _filter_incomplete_files(
154158
paths: Iterable[Path],
155-
return_incompletes: bool = False,
159+
mode: Literal["non-lock", "lock", "all"],
156160
) -> List[Path]:
157161
result = []
162+
return_all = mode == "all"
163+
return_incompletes = mode == "lock"
158164
for path in paths:
159165
with path.open("r") as fp:
160-
if fp.readline().startswith(FILE_LOCK_TEXT) == return_incompletes:
166+
if (
167+
return_all
168+
or fp.readline().startswith(FILE_LOCK_TEXT) == return_incompletes
169+
):
161170
result.append(path)
162171
return result
163172

@@ -172,11 +181,19 @@ class EvaluationWithTasksHelper:
172181
dataloader we use.
173182
"""
174183

175-
def __init__(self, out_dir: Path, tag: Optional[str] = None):
184+
def __init__(
185+
self,
186+
out_dir: Path,
187+
tag: Optional[str] = None,
188+
eval_metrics_filename: Optional[str] = None,
189+
):
176190
self._out_dir = out_dir
177191
if tag is None:
178192
tag = ""
179193
self._tag = tag
194+
if eval_metrics_filename is None:
195+
eval_metrics_filename = EVAL_METRICS_FNAME
196+
self._eval_metrics_filename = eval_metrics_filename
180197

181198
def evaluation_metrics_path(self, batch: Dict[str, Any]) -> Path:
182199
"""
@@ -197,7 +214,7 @@ def evaluation_metrics_path(self, batch: Dict[str, Any]) -> Path:
197214
f"batch[{TASK_NAME}] = {task}."
198215
)
199216
suffix = self._tag + str(orig_idxs[0])
200-
fname = EVAL_METRICS_FNAME.format(suffix)
217+
fname = self._eval_metrics_filename.format(suffix)
201218
return self._out_dir / task / fname
202219

203220
def get_lock(self, batch: Dict[str, Any]) -> Optional[Path]:

keys_values/finetune/longcontext_eval.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def setup(
3636
use_sample_metric: bool = True,
3737
sample_metric_max_generated_tokens: int = 20,
3838
sample_metric_kwargs: Optional[Dict[str, Any]] = None,
39+
num_store_generated_samples: Optional[int] = None,
40+
skip_eval: bool = False,
3941
) -> None:
4042
"""Evaluate a range of model checkpoints on a test set
4143
@@ -101,6 +103,15 @@ def setup(
101103
for sample-based metric evaluation
102104
sample_metric_kwargs: Keyword arguments for token sampling (params
103105
can be "temperature", "top_k", "top_p")
106+
num_store_generated_samples: If given and positive, we write files
107+
containing the generated sequences along with SFT targets and raw
108+
targets. These files are written alongside metric files, using the
109+
same naming convention. They are written for the initial test set
110+
batches, until `num_store_generated_samples` cases are covered
111+
(rounded up to a multiple of `batch_size`). This option requires
112+
`use_sample_metric == True`.
113+
skip_eval: If `True`, we skip evaluations and only write files related
114+
to `num_store_generated_samples`.
104115
105116
"""
106117
entry = {
@@ -124,4 +135,6 @@ def setup(
124135
use_sample_metric=use_sample_metric,
125136
sample_metric_max_generated_tokens=sample_metric_max_generated_tokens,
126137
sample_metric_kwargs=sample_metric_kwargs,
138+
num_store_generated_samples=num_store_generated_samples,
139+
skip_eval=skip_eval,
127140
)

0 commit comments

Comments
 (0)