Skip to content

Commit f9af3bd

Browse files
committed
New evaluation script which iterates over several setups (#106)
1 parent ce3bccb commit f9af3bd

8 files changed

Lines changed: 987 additions & 593 deletions

File tree

README.md

Lines changed: 183 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ pytest test/
7272
### FlashInfer CUDA Extension
7373

7474
The library uses vendored FlashInfer CUDA kernels combined with a Triton
75-
score-sum kernel for attention weight computation support KV cache policies
76-
such as H2O. This must be built after installing the package.
75+
score-sum kernel for attention weight computation in order to support KV cache
76+
policies such as H2O. This must be built after installing the package.
7777

7878
**Prerequisites:**
7979
* NVIDIA GPU with compute capability >= 8.0 (A100, H100, etc.)
@@ -94,7 +94,7 @@ To verify the build worked:
9494
pytest test/test_flashinfer_wrapper.py
9595
```
9696

97-
### Installatiop with CUDA 12.8
97+
### Installation with CUDA 12.8
9898

9999
The following installation works if you are bound to use CUDA 12.8. Note that
100100
this includes the FlashInfer extension.
@@ -111,10 +111,6 @@ rm constraints.txt
111111
pip install 'litgpt[all,test,extra]'
112112
cd keys_values
113113
pip install -e .
114-
```
115-
116-
Then:
117-
```bash
118114
python build_ext.py
119115
```
120116

@@ -125,12 +121,29 @@ This example runs on a single `Nvidia A 100` GPU with 40 GB of RAM.
125121

126122
```bash
127123
cd ${KEYS_VALUES_PATH}
128-
python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-torch-quantized8 --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
124+
python3 keys_values/__main__.py finetune_long_lora \
125+
Qwen/Qwen2.5-0.5B \
126+
--out_dir /home/ubuntu/out/finetune/longcontext_lora \
127+
--data LongBenchV2 \
128+
--data.max_seq_length 100000 \
129+
--data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data \
130+
--head_model seq_classification_on_logits \
131+
--precision bf16-true \
132+
--verbose some \
133+
--kv_cache.name h2o-torch-quantized8 \
134+
--kv_cache.cache_length 16384 \
135+
--kv_cache.chunk_size 1024 \
136+
--train.save_interval 10 \
137+
--train.micro_batch_size 4 \
138+
--eval.interval 10
129139
```
130140

131141
What is happening here?
132142

133143
* `finetune_long_lora`: Default fine-tuning script for `LoRA`
144+
* `--out_dir`: Path for results. For example, checkpoints are written to
145+
directories `step-000010`, `step-000020`, ... below this path (due to
146+
`--train.save_interval 10`, checkpoints are written every 10 iterations).
134147
* `--data LongBenchV2`: Using the `LongBenchV2` benchmark with its data loaders.
135148
`--data.max_seq_length 100000` filters for sequences less than 100k tokens.
136149
`--data.metadata_dir` stores metadata information about the dataset, so this
@@ -152,13 +165,13 @@ What is happening here?
152165
which case we use gradient averaging.
153166

154167
If you use an AWS `p4d.24xlarge` instance, you can use 8 A 100 GPUs in parallel.
155-
At present, we support data parallelism via
156-
[Lightning Fabric](https://lightning.ai/docs/fabric/stable/). Modifying the
157-
CLI command above like runs training with an effective batch size of 32:
168+
Modifying the CLI command above like runs training with an effective batch size
169+
of 32:
158170

159171
```bash
160172
cd ${KEYS_VALUES_PATH}
161-
python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
173+
python3 keys_values/__main__.py finetune_long_lora \
174+
Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
162175
```
163176

164177
Here, `--devices 8 --train.micro_batch_size 4` sets `train.global_batch_size`
@@ -169,7 +182,10 @@ to 32, the per-device batch size to 4, and asks to use 8 devices.
169182
* Try increasing `kv_cache.cache_length` and `kv_cache.chunk_size`. They have
170183
the [largest impact on speed and accuracy](#cache-length-and-chunk-size).
171184
* Play around with different [cache policies](#kv-cache-policy-and-configuration),
172-
or try to use buffer quantization (both by `kv_cache.name`).
185+
or try to use buffer quantization (both by `kv_cache.name`). For example,
186+
`--kv_cache.name h2o-torch-quantized8` halves the amount of GPU memory
187+
required for KV cache buffers and may even run faster (our code offloads
188+
KV cache buffers to CPU, which runs faster for less memory).
173189
* Play round with different datasets. `--data Helmet` gives access to datasets
174190
from the Helmet benchmark.
175191
* Try using `finetune_offload_lora` instead of `finetune_long_lora`, and
@@ -196,14 +212,14 @@ for fast scaled dot product attention (SDPA).
196212

197213
Having said that, we are aware that this is not competitive with leading
198214
inference libraries, such as [vLLM](https://github.com/vllm-project/vllm) or
199-
[SGLang](https://github.com/sgl-project/sglang). Our library lacks support
215+
[SGLang](https://github.com/sgl-project/sglang). Our library currently lacks support
200216
for multi-device strategies (context parallelism in particular) as well as
201217
many crucial optimizations.
202218

203219
We are providing a better support of advanced KV cache strategies like
204220
[Heavy Hitter Oracle](https://arxiv.org/abs/2306.14048) than vLLM. One reason
205221
why sparse attention techniques like H2O are used less often than they deserve,
206-
is that they run slowly due to poor support of low-level SDPA kernels. We provide
222+
is that they run slowly due to poor support from low-level SDPA kernels. We provide
207223
a modification of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
208224
kernels with which H2O becomes competitive. Stay tuned for more efforts in this
209225
direction.
@@ -222,8 +238,10 @@ being able to run inference with long contexts without having to spend a lot
222238
of money on many GPUs, and we think that advanced selective KV cache policies
223239
are an important direction towards this goal.
224240

225-
A script for evaluating fine-tuned models on long context test data is provided
226-
in [finetune/longcontext_eval.py](./keys_values/finetune/longcontext_eval.py).
241+
Scripts for evaluating fine-tuned models on long context test data are provided
242+
in [finetune/longcontext_eval.py](./keys_values/finetune/longcontext_eval.py) and
243+
[finetune/longcontext_eval_ext.py](./keys_values/finetune/longcontext_eval_ext.py),
244+
more details are given [below](#evaluation-of-fine-tuned-models).
227245

228246

229247
## Long Context Fine-tuning
@@ -1024,6 +1042,154 @@ For a healthy run, you should see:
10241042
In particular, GPU memory should not build up across several snapshots
10251043

10261044

1045+
## Evaluation of Fine-tuned Models
1046+
1047+
Our library provides scripts to evaluate fine-tuned models on test datasets.
1048+
While during fine-tuning, a metric is evaluated on a validation set, this is
1049+
usually just a part of the development set (which is split into training and
1050+
validation set). In general, we also need to compute metrics which are different
1051+
from the loss which drives the training. Some naming:
1052+
1053+
* A **setup** is given by a base model, configuration, and dataset. The
1054+
dataset consists of a development and a test set. For fine-tuning, the
1055+
development set is typically split into training and validation set. The
1056+
model is fine-tuned on the training set, while a validation metric is
1057+
periodically computed on the validation set (every `--eval.interval`
1058+
iterations). Moreover, **checkpoints** are stored periodically (every
1059+
`--train.save_interval` iterations). Use the validation metric values for
1060+
early stopping, or to decide which checkpoints to use for test set
1061+
evaluation.
1062+
* A **task** is a tuple of setup and checkpoint. For each evaluation metric,
1063+
the goal is to compute one value per task.
1064+
* The test dataset for a setup is partitioned into batches (these are
1065+
micro-batches in the naming used above). The evaluation scripts iterate over
1066+
tuples `(task, batch)`. They can be run on any number of devices in parallel,
1067+
jobs are assigned on a first-come-first-saved basis. The outcome for a job is
1068+
a CSV file containing the metric values for data cases in a batch. These can
1069+
be aggregated into metric values over the whole test set.
1070+
1071+
The following scripts can be used for evaluation:
1072+
1073+
* [longcontext_eval](./keys_values/finetune/longcontext_eval.py): Short `eval_long`.
1074+
Run evaluation for a single setup.
1075+
* [longcontext_eval_ext](./keys_values/finetune/longcontext_eval_ext.py): Short
1076+
`eval_long_ext`. Run evaluation for several setups, each with its own tasks.
1077+
1078+
### Evaluation for Single Setup: `eval_long`
1079+
1080+
Example:
1081+
```bash
1082+
python keys_values/__main__.py eval_long \
1083+
/home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5 \
1084+
--model_type lora \
1085+
--verbose some \
1086+
--devices 2 \
1087+
--batch_size 2 \
1088+
--use_sample_metric True \
1089+
--sample_metric_max_generated_tokens 20 \
1090+
--tasks "step-000310,final,step-000410"
1091+
```
1092+
1093+
* `/home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5` is the
1094+
`--out_dir` path passed to the training run for the setup.
1095+
* `--model_type`: Can be "lora" or "full".
1096+
* `--devices`: How many devices should the evaluation script use?
1097+
* `--batch_size`: Micro batch size for evaluation. Overrides
1098+
`eval.micro_batch_size` from the configuration of the setup.
1099+
* `--use_sample_metric`: Some datasets define a sample-based evaluation metric.
1100+
If `True`, this one is computed. Otherwise, the training loss function is
1101+
computed (but on the test set).
1102+
* `--tasks`: Name of tasks (or checkpoints) for which evaluation is to run. If
1103+
this is not given, the script runs evaluation for all checkpoints detected
1104+
under the `out_dir`.
1105+
1106+
Note that dataset and configurations are taken from the hyperparameters stored
1107+
with checkpoints (these must be the same for all checkpoints). Some of them can
1108+
be overwritten:
1109+
1110+
* `--kv_cache.*`: [KVCacheArgs](./keys_values/finetune/args.py#L51). Allows to
1111+
use a different KV cache policy or different parameters for evaluation than
1112+
what has been used for fine-tuning.
1113+
* `--sdpa.*`: [SDPAArgs](./keys_values/finetune/args.py#L555). Allows to
1114+
use a different SDPA kernel or different parameters for evaluation than
1115+
what has been used for fine-tuning.
1116+
* `--lora_dropout`: Overwrites `lora.dropout`.
1117+
1118+
The evaluation script works like this:
1119+
1120+
* On each device, a list of all jobs (i.e., tuples `(task, batch)`) is created.
1121+
* These jobs are worked on in parallel, on a first-come-first-served basis. The
1122+
outcome for a job is a file `<out_dir>/<task>/eval/eval_metrics_<no>.csv`, a
1123+
CSV file with one row per case in a batch. Here, `<no>` is the index of the
1124+
first case in the batch. For our example above, this could be
1125+
`.../h2o_lr5/step-000310/eval_metrics_256.csv`.
1126+
* Jobs are iterated over in a nested loop, tasks in outer, batches in inner loop.
1127+
* A worker locks a job by writing the result file, but with bogus content. Once
1128+
the job is finished, this content is overwritten by the results.
1129+
* Whenever a worker switches to a new task, the respective checkpoint is loaded
1130+
there.
1131+
1132+
Once an evaluation has finished, result files for all jobs have been written.
1133+
The script [collect_eval_results](./keys_values/scripts/collect_eval_results.py)
1134+
can be used to collect all results into a single CSV file. Currently, this script
1135+
has to be adapted to work for different setups. If a setup is stored out `out_dir`,
1136+
the outcome of this script is a file `<out_dir>/eval_metrics_all.csv`, which
1137+
collects all individual results. Moreover, the average evaluation metric per task
1138+
is printed for each task. The script also outputs the number of jobs which were
1139+
read for each task. If some of these numbers are too low, this may be due to lock
1140+
files which have not properly been removed for a failed worker. In this case,
1141+
clean up the lock files (see below) and run the script again: it will compute only
1142+
the missing jobs.
1143+
1144+
When workers are stopped before they can finish all jobs, there are in general
1145+
left-over lock files. Simply restarting the evaluation risks that metrics are not
1146+
evaluated for these jobs. In such a case, you obtain average metric values which
1147+
can be wrong. Use the script [cleanup_evaluation](./keys_values/scripts/cleanup_evaluation.py)
1148+
in order to remove left-over lock files. Currently, this script has to be adapted
1149+
to work for different setups.
1150+
1151+
### Evaluation for Several Setups: `eval_long_ext`
1152+
1153+
Example:
1154+
```bash
1155+
python keys_values/__main__.py eval_long_ext \
1156+
./test_eval.yaml \
1157+
--verbose some \
1158+
--devices 2 \
1159+
--batch_size 2 \
1160+
--use_sample_metric True \
1161+
--sample_metric_max_generated_tokens 20
1162+
```
1163+
1164+
Here, `test_eval.yaml` is a YAML file describing the setups and the tasks for setup.
1165+
For example:
1166+
```yaml
1167+
- out_dir: /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5
1168+
model_type: lora
1169+
eval_tasks:
1170+
- step-000450
1171+
- step-000010
1172+
- final
1173+
- out_dir: /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_nq_64k/slr_lr5
1174+
model_type: lora
1175+
eval_tasks:
1176+
- step-000260
1177+
- step-000010
1178+
- final
1179+
- out_dir: /home/ubuntu/out/finetune/full/qwen3_4b/helmet_hotpot_qa_32k/h2o_lr5
1180+
model_type: full
1181+
eval_tasks:
1182+
- step-000420
1183+
- step-000010
1184+
- final
1185+
```
1186+
1187+
A setup entry can also contain `kv_cache` and `sdpa` fields, being nested
1188+
dictionaries. If an entry does not contain a `eval_tasks` field, then all
1189+
checkpoints found there are tasks. Jobs are iterated over in a nested loop,
1190+
outer over setups, middle over tasks, inner over batches.
1191+
1192+
10271193
## Implementing New KV Cache Policies
10281194

10291195
Currently supported KV cache policies are detailed

keys_values/__main__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from litgpt.__main__ import PARSER_DATA as PARSER_DATA_LITGPT
2323

2424
from keys_values.finetune.longcontext_eval import setup as eval_long_fn
25+
from keys_values.finetune.longcontext_eval_ext import setup as eval_long_ext_fn
2526
from keys_values.finetune.longcontext_full import setup as finetune_long_full_fn
2627
from keys_values.finetune.longcontext_lora import setup as finetune_long_lora_fn
2728
from keys_values.finetune.longcon_offload_full import setup as finetune_offload_full_fn
@@ -35,9 +36,10 @@
3536

3637
PARSER_DATA = {
3738
**PARSER_DATA_LITGPT,
39+
"eval_long": eval_long_fn,
40+
"eval_long_ext": eval_long_ext_fn,
3841
"finetune_long_full": finetune_long_full_fn,
3942
"finetune_long_lora": finetune_long_lora_fn,
40-
"eval_long": eval_long_fn,
4143
"finetune_offload_full": finetune_offload_full_fn,
4244
"finetune_offload_lora": finetune_offload_lora_fn,
4345
}

keys_values/evaluation/tasks.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@
3232
_REQUIRED_FILES = [
3333
"hyperparameters.yaml",
3434
"model_config.yaml",
35-
"tokenizer.json",
36-
"tokenizer_config.json",
3735
]
3836

3937
REQUIRED_FILES = {
@@ -60,6 +58,8 @@ def __init__(
6058
tasks: Optional[List[str]] = None,
6159
collect_results: bool = False,
6260
):
61+
if isinstance(out_dir, str):
62+
out_dir = Path(out_dir)
6363
self._out_dir = out_dir
6464
self.model_type = model_type
6565
self._tasks = tasks.copy() if tasks is not None else None
@@ -125,27 +125,39 @@ def check_complete(task_path: Path, model_type: str) -> bool:
125125
else:
126126
return True
127127

128-
def eval_result_files(self) -> Iterable[Tuple[str, List[Path]]]:
128+
def eval_result_files(
129+
self,
130+
return_incompletes: bool = False,
131+
) -> Iterable[Tuple[str, List[Path]]]:
129132
"""
133+
Args:
134+
return_incompletes: If `True`, we return the complete lock files.
135+
Defaults to `False`, so lock files are filtered out.
130136
Yields:
131137
`(task_name, result_file_paths)`, where `result_file_paths`
132138
is list of paths of evaluation result files for this task name.
133139
These files are filtered to not contain incomplete lock files.
140+
But if `return_incompletes == True`, only incomplete files are
141+
returned.
134142
135143
"""
136144
for task_name in self._tasks:
137145
result_file_paths = self._filter_incomplete_files(
138-
(self._out_dir / task_name).glob(EVAL_METRICS_GLOB)
146+
(self._out_dir / task_name).glob(EVAL_METRICS_GLOB),
147+
return_incompletes=return_incompletes,
139148
)
140149
if result_file_paths:
141150
yield task_name, result_file_paths
142151

143152
@staticmethod
144-
def _filter_incomplete_files(paths: Iterable[Path]) -> List[Path]:
153+
def _filter_incomplete_files(
154+
paths: Iterable[Path],
155+
return_incompletes: bool = False,
156+
) -> List[Path]:
145157
result = []
146158
for path in paths:
147159
with path.open("r") as fp:
148-
if not fp.readline().startswith(FILE_LOCK_TEXT):
160+
if fp.readline().startswith(FILE_LOCK_TEXT) == return_incompletes:
149161
result.append(path)
150162
return result
151163

0 commit comments

Comments
 (0)