@@ -72,8 +72,8 @@ pytest test/
7272### FlashInfer CUDA Extension
7373
7474The library uses vendored FlashInfer CUDA kernels combined with a Triton
75- score-sum kernel for attention weight computation support KV cache policies
76- such as H2O. This must be built after installing the package.
75+ score-sum kernel for attention weight computation in order to support KV cache
76+ policies such as H2O. This must be built after installing the package.
7777
7878** Prerequisites:**
7979* NVIDIA GPU with compute capability >= 8.0 (A100, H100, etc.)
@@ -94,7 +94,7 @@ To verify the build worked:
9494pytest test/test_flashinfer_wrapper.py
9595```
9696
97- ### Installatiop with CUDA 12.8
97+ ### Installation with CUDA 12.8
9898
9999The following installation works if you are bound to use CUDA 12.8. Note that
100100this includes the FlashInfer extension.
@@ -111,10 +111,6 @@ rm constraints.txt
111111pip install 'litgpt[all,test,extra]'
112112cd keys_values
113113pip install -e .
114- ```
115-
116- Then:
117- ``` bash
118114python build_ext.py
119115```
120116
@@ -125,12 +121,29 @@ This example runs on a single `Nvidia A 100` GPU with 40 GB of RAM.
125121
126122``` bash
127123cd ${KEYS_VALUES_PATH}
128- python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-torch-quantized8 --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
124+ python3 keys_values/__main__.py finetune_long_lora \
125+ Qwen/Qwen2.5-0.5B \
126+ --out_dir /home/ubuntu/out/finetune/longcontext_lora \
127+ --data LongBenchV2 \
128+ --data.max_seq_length 100000 \
129+ --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data \
130+ --head_model seq_classification_on_logits \
131+ --precision bf16-true \
132+ --verbose some \
133+ --kv_cache.name h2o-torch-quantized8 \
134+ --kv_cache.cache_length 16384 \
135+ --kv_cache.chunk_size 1024 \
136+ --train.save_interval 10 \
137+ --train.micro_batch_size 4 \
138+ --eval.interval 10
129139```
130140
131141What is happening here?
132142
133143* ` finetune_long_lora ` : Default fine-tuning script for ` LoRA `
144+ * ` --out_dir ` : Path for results. For example, checkpoints are written to
145+ directories ` step-000010 ` , ` step-000020 ` , ... below this path (due to
146+ ` --train.save_interval 10 ` , checkpoints are written every 10 iterations).
134147* ` --data LongBenchV2 ` : Using the ` LongBenchV2 ` benchmark with its data loaders.
135148 ` --data.max_seq_length 100000 ` filters for sequences less than 100k tokens.
136149 ` --data.metadata_dir ` stores metadata information about the dataset, so this
@@ -152,13 +165,13 @@ What is happening here?
152165 which case we use gradient averaging.
153166
154167If you use an AWS ` p4d.24xlarge ` instance, you can use 8 A 100 GPUs in parallel.
155- At present, we support data parallelism via
156- [ Lightning Fabric] ( https://lightning.ai/docs/fabric/stable/ ) . Modifying the
157- CLI command above like runs training with an effective batch size of 32:
168+ Modifying the CLI command above as follows runs training with an effective batch size
169+ of 32:
158170
159171``` bash
160172cd ${KEYS_VALUES_PATH}
161- python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
173+ python3 keys_values/__main__.py finetune_long_lora \
174+ Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
162175```
163176
164177Here, ` --devices 8 --train.micro_batch_size 4 ` sets ` train.global_batch_size `
@@ -169,7 +182,10 @@ to 32, the per-device batch size to 4, and asks to use 8 devices.
169182* Try increasing ` kv_cache.cache_length ` and ` kv_cache.chunk_size ` . They have
170183 the [ largest impact on speed and accuracy] ( #cache-length-and-chunk-size ) .
171184* Play around with different [ cache policies] ( #kv-cache-policy-and-configuration ) ,
172- or try to use buffer quantization (both by ` kv_cache.name ` ).
185+ or try buffer quantization (both are selected via ` kv_cache.name ` ). For example,
186+ ` --kv_cache.name h2o-torch-quantized8 ` halves the amount of GPU memory
187+ required for KV cache buffers and may even run faster (our code offloads
188+ KV cache buffers to CPU, which is faster when the buffers are smaller).
173189* Play around with different datasets. ` --data Helmet ` gives access to datasets
174190 from the Helmet benchmark.
175191* Try using ` finetune_offload_lora ` instead of ` finetune_long_lora ` , and
@@ -196,14 +212,14 @@ for fast scaled dot product attention (SDPA).
196212
197213Having said that, we are aware that this is not competitive with leading
198214inference libraries, such as [ vLLM] ( https://github.com/vllm-project/vllm ) or
199- [ SGLang] ( https://github.com/sgl-project/sglang ) . Our library lacks support
215+ [ SGLang] ( https://github.com/sgl-project/sglang ) . Our library currently lacks support
200216for multi-device strategies (context parallelism in particular) as well as
201217many crucial optimizations.
202218
203219We provide better support for advanced KV cache strategies like
204220[ Heavy Hitter Oracle] ( https://arxiv.org/abs/2306.14048 ) than vLLM. One reason
205221why sparse attention techniques like H2O are used less often than they deserve,
206- is that they run slowly due to poor support of low-level SDPA kernels. We provide
222+ is that they run slowly due to poor support from low-level SDPA kernels. We provide
207223a modification of the [ FlashInfer] ( https://github.com/flashinfer-ai/flashinfer )
208224kernels with which H2O becomes competitive. Stay tuned for more efforts in this
209225direction.
@@ -222,8 +238,10 @@ being able to run inference with long contexts without having to spend a lot
222238of money on many GPUs, and we think that advanced selective KV cache policies
223239are an important direction towards this goal.
224240
225- A script for evaluating fine-tuned models on long context test data is provided
226- in [ finetune/longcontext_eval.py] ( ./keys_values/finetune/longcontext_eval.py ) .
241+ Scripts for evaluating fine-tuned models on long context test data are provided
242+ in [ finetune/longcontext_eval.py] ( ./keys_values/finetune/longcontext_eval.py ) and
243+ [ finetune/longcontext_eval_ext.py] ( ./keys_values/finetune/longcontext_eval_ext.py ) .
244+ More details are given [ below] ( #evaluation-of-fine-tuned-models ) .
227245
228246
229247## Long Context Fine-tuning
@@ -1024,6 +1042,154 @@ For a healthy run, you should see:
10241042 In particular, GPU memory should not build up across several snapshots
10251043
10261044
1045+ ## Evaluation of Fine-tuned Models
1046+
1047+ Our library provides scripts to evaluate fine-tuned models on test datasets.
1048+ During fine-tuning, a metric is evaluated on a validation set, but this is
1049+ usually just a part of the development set (which is split into training and
1050+ validation sets). In general, we also need to compute metrics which are different
1051+ from the loss which drives the training. Some terminology:
1052+
1053+ * A ** setup** is given by a base model, configuration, and dataset. The
1054+ dataset consists of a development and a test set. For fine-tuning, the
1055+ development set is typically split into training and validation set. The
1056+ model is fine-tuned on the training set, while a validation metric is
1057+ periodically computed on the validation set (every ` --eval.interval `
1058+ iterations). Moreover, ** checkpoints** are stored periodically (every
1059+ ` --train.save_interval ` iterations). Use the validation metric values for
1060+ early stopping, or to decide which checkpoints to use for test set
1061+ evaluation.
1062+ * A ** task** is a tuple of setup and checkpoint. For each evaluation metric,
1063+ the goal is to compute one value per task.
1064+ * The test dataset for a setup is partitioned into batches (these are
1065+ micro-batches in the naming used above). The evaluation scripts iterate over
1066+ tuples ` (task, batch) ` , which we call jobs. They can be run on any number of
1067+ devices in parallel. Jobs are assigned on a first-come-first-served basis. The outcome for a job is
1068+ a CSV file containing the metric values for data cases in a batch. These can
1069+ be aggregated into metric values over the whole test set.
1070+
1071+ The following scripts can be used for evaluation:
1072+
1073+ * [ longcontext_eval] ( ./keys_values/finetune/longcontext_eval.py ) : Short name ` eval_long ` .
1074+ Runs evaluation for a single setup.
1075+ * [ longcontext_eval_ext] ( ./keys_values/finetune/longcontext_eval_ext.py ) : Short
1076+ name ` eval_long_ext ` . Runs evaluation for several setups, each with its own tasks.
1077+
1078+ ### Evaluation for Single Setup: ` eval_long `
1079+
1080+ Example:
1081+ ``` bash
1082+ python keys_values/__main__.py eval_long \
1083+ /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5 \
1084+ --model_type lora \
1085+ --verbose some \
1086+ --devices 2 \
1087+ --batch_size 2 \
1088+ --use_sample_metric True \
1089+ --sample_metric_max_generated_tokens 20 \
1090+ --tasks "step-000310,final,step-000410"
1091+ ```
1092+
1093+ * ` /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5 ` is the
1094+ ` --out_dir ` path passed to the training run for the setup.
1095+ * ` --model_type ` : Can be "lora" or "full".
1096+ * ` --devices ` : How many devices should the evaluation script use?
1097+ * ` --batch_size ` : Micro batch size for evaluation. Overrides
1098+ ` eval.micro_batch_size ` from the configuration of the setup.
1099+ * ` --use_sample_metric ` : Some datasets define a sample-based evaluation metric.
1100+ If ` True ` , this one is computed. Otherwise, the training loss function is
1101+ computed (but on the test set).
1102+ * ` --tasks ` : Name of tasks (or checkpoints) for which evaluation is to run. If
1103+ this is not given, the script runs evaluation for all checkpoints detected
1104+ under the ` out_dir ` .
1105+
1106+ Note that dataset and configurations are taken from the hyperparameters stored
1107+ with checkpoints (these must be the same for all checkpoints). Some of them can
1108+ be overridden:
1109+ 
1110+ * ` --kv_cache.* ` : [ KVCacheArgs] ( ./keys_values/finetune/args.py#L51 ) . Allows
1111+ using a different KV cache policy or different parameters for evaluation than
1112+ those used for fine-tuning.
1113+ * ` --sdpa.* ` : [ SDPAArgs] ( ./keys_values/finetune/args.py#L555 ) . Allows
1114+ using a different SDPA kernel or different parameters for evaluation than
1115+ those used for fine-tuning.
1116+ * ` --lora_dropout ` : Overrides ` lora.dropout ` .
1117+
1118+ The evaluation script works like this:
1119+
1120+ * On each device, a list of all jobs (i.e., tuples ` (task, batch) ` ) is created.
1121+ * These jobs are worked on in parallel, on a first-come-first-served basis. The
1122+ outcome for a job is a file ` <out_dir>/<task>/eval/eval_metrics_<no>.csv ` , a
1123+ CSV file with one row per case in a batch. Here, ` <no> ` is the index of the
1124+ first case in the batch. For our example above, this could be
1125+ ` .../h2o_lr5/step-000310/eval/eval_metrics_256.csv ` .
1126+ * Jobs are iterated over in a nested loop, tasks in outer, batches in inner loop.
1127+ * A worker locks a job by writing the result file, but with bogus content. Once
1128+ the job is finished, this content is overwritten by the results.
1129+ * Whenever a worker switches to a new task, the respective checkpoint is loaded
1130+ there.
1131+
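The locking scheme described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the code used by the library: the names `try_claim`, `run_worker`, and `LOCK_CONTENT`, as well as the placeholder content itself, are hypothetical. The sketch relies on exclusive file creation (mode `"x"`), so that only one worker can claim a given job.

```python
from pathlib import Path

# Hypothetical placeholder written into a result file while a job is running.
LOCK_CONTENT = "LOCKED\n"


def try_claim(result_path: Path) -> bool:
    """Create the result file exclusively; return False if it already exists."""
    result_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        # Mode "x" fails if the file exists, so only one worker can succeed.
        with open(result_path, "x") as f:
            f.write(LOCK_CONTENT)
        return True
    except FileExistsError:
        return False


def run_worker(out_dir: Path, tasks, batch_starts, evaluate):
    """Work through all (task, batch) jobs, skipping those already claimed."""
    done = []
    for task in tasks:  # outer loop: tasks
        for start in batch_starts:  # inner loop: batches
            path = out_dir / task / "eval" / f"eval_metrics_{start}.csv"
            if not try_claim(path):
                continue  # claimed by another worker, or already finished
            rows = evaluate(task, start)  # CSV lines for the cases in the batch
            path.write_text("\n".join(rows) + "\n")  # overwrite the placeholder
            done.append((task, start))
    return done
```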
1132+ Once an evaluation has finished, result files for all jobs have been written.
1133+ The script [ collect_eval_results] ( ./keys_values/scripts/collect_eval_results.py )
1134+ can be used to collect all results into a single CSV file. Currently, this script
1135+ has to be adapted to work for different setups. If a setup is stored under ` out_dir ` ,
1136+ the outcome of this script is a file ` <out_dir>/eval_metrics_all.csv ` , which
1137+ collects all individual results. Moreover, the average evaluation metric
1138+ is printed for each task. The script also outputs the number of jobs which were
1139+ read for each task. If some of these numbers are too low, this may be due to lock
1140+ files which have not properly been removed for a failed worker. In this case,
1141+ clean up the lock files (see below) and run the evaluation script again: it will
1142+ compute only the missing jobs.
1143+
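To illustrate the aggregation step, here is a minimal sketch which averages a metric over all per-job result files of a setup. It is not the `collect_eval_results` script itself. The function name `average_metric_per_task` and the column name `metric` are assumptions for illustration, and the real result files may use different column names. Files without that column (for example, left-over lock files) are skipped, which is why the number of files read per task is returned as well.

```python
import csv
from collections import defaultdict
from pathlib import Path


def average_metric_per_task(out_dir: Path, metric_column: str = "metric"):
    """Return {task: (mean metric, number of result files read)}.

    `metric_column` is a hypothetical column name; adapt it to the real files.
    """
    values = defaultdict(list)
    num_files = defaultdict(int)
    # Result files follow the pattern <out_dir>/<task>/eval/eval_metrics_<no>.csv
    for path in sorted(out_dir.glob("*/eval/eval_metrics_*.csv")):
        task = path.parent.parent.name
        with open(path, newline="") as f:
            rows = list(csv.DictReader(f))
        if not rows or metric_column not in rows[0]:
            continue  # likely a left-over lock file or an unfinished job
        num_files[task] += 1
        values[task].extend(float(row[metric_column]) for row in rows)
    return {t: (sum(v) / len(v), num_files[t]) for t, v in values.items()}
```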
1144+ When workers are stopped before they can finish all jobs, there are in general
1145+ left-over lock files. Simply restarting the evaluation risks that metrics are not
1146+ evaluated for these jobs. In such a case, you obtain average metric values which
1147+ can be wrong. Use the script [ cleanup_evaluation] ( ./keys_values/scripts/cleanup_evaluation.py )
1148+ in order to remove left-over lock files. Currently, this script has to be adapted
1149+ to work for different setups.
1150+
1151+ ### Evaluation for Several Setups: ` eval_long_ext `
1152+
1153+ Example:
1154+ ``` bash
1155+ python keys_values/__main__.py eval_long_ext \
1156+ ./test_eval.yaml \
1157+ --verbose some \
1158+ --devices 2 \
1159+ --batch_size 2 \
1160+ --use_sample_metric True \
1161+ --sample_metric_max_generated_tokens 20
1162+ ```
1163+
1164+ Here, ` test_eval.yaml ` is a YAML file describing the setups and the tasks for each setup.
1165+ For example:
1166+ ``` yaml
1167+ - out_dir : /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5
1168+ model_type : lora
1169+ eval_tasks :
1170+ - step-000450
1171+ - step-000010
1172+ - final
1173+ - out_dir : /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_nq_64k/slr_lr5
1174+ model_type : lora
1175+ eval_tasks :
1176+ - step-000260
1177+ - step-000010
1178+ - final
1179+ - out_dir : /home/ubuntu/out/finetune/full/qwen3_4b/helmet_hotpot_qa_32k/h2o_lr5
1180+ model_type : full
1181+ eval_tasks :
1182+ - step-000420
1183+ - step-000010
1184+ - final
1185+ ```
1186+
1187+ A setup entry can also contain ` kv_cache ` and ` sdpa ` fields, which are nested
1188+ dictionaries. If an entry does not contain an ` eval_tasks ` field, then all
1189+ checkpoints found under its ` out_dir ` are tasks. Jobs are iterated over in a nested loop,
1190+ outer over setups, middle over tasks, inner over batches.
1191+
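The nested iteration order over setups, tasks, and batches can be sketched as follows. This is an illustration only: `iterate_jobs` and `find_checkpoints` are hypothetical names, the `setups` list stands in for the parsed content of the YAML file, and using the same number of test cases for every setup is a simplification.

```python
# Stands in for the parsed content of a setups file such as `test_eval.yaml`.
setups = [
    {
        "out_dir": "/path/to/setup_a",
        "model_type": "lora",
        "eval_tasks": ["step-000450", "final"],
    },
    {"out_dir": "/path/to/setup_b", "model_type": "full"},  # no `eval_tasks`
]


def find_checkpoints(out_dir):
    """Hypothetical stand-in for scanning `out_dir` for checkpoint directories."""
    return ["step-000010", "final"]


def iterate_jobs(setups, num_cases, batch_size):
    """Yield (out_dir, task, index of first case in batch) in nested-loop order."""
    for setup in setups:  # outer loop: setups
        # Without `eval_tasks`, all checkpoints found for the setup are tasks.
        tasks = setup.get("eval_tasks") or find_checkpoints(setup["out_dir"])
        for task in tasks:  # middle loop: tasks
            for start in range(0, num_cases, batch_size):  # inner loop: batches
                yield setup["out_dir"], task, start
```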
1192+
10271193## Implementing New KV Cache Policies
10281194
10291195Currently supported KV cache policies are detailed