200 changes: 183 additions & 17 deletions README.md
Expand Up @@ -72,8 +72,8 @@ pytest test/
### FlashInfer CUDA Extension

The library uses vendored FlashInfer CUDA kernels combined with a Triton
score-sum kernel for attention weight computation support KV cache policies
such as H2O. This must be built after installing the package.
score-sum kernel for attention weight computation in order to support KV cache
policies such as H2O. This must be built after installing the package.

**Prerequisites:**
* NVIDIA GPU with compute capability >= 8.0 (A100, H100, etc.)
Expand All @@ -94,7 +94,7 @@ To verify the build worked:
pytest test/test_flashinfer_wrapper.py
```

### Installatiop with CUDA 12.8
### Installation with CUDA 12.8

The following installation works if you are bound to use CUDA 12.8. Note that
this includes the FlashInfer extension.
Expand All @@ -111,10 +111,6 @@ rm constraints.txt
pip install 'litgpt[all,test,extra]'
cd keys_values
pip install -e .
```

Then:
```bash
python build_ext.py
```

Expand All @@ -125,12 +121,29 @@ This example runs on a single `Nvidia A 100` GPU with 40 GB of RAM.

```bash
cd ${KEYS_VALUES_PATH}
python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-torch-quantized8 --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
python3 keys_values/__main__.py finetune_long_lora \
Qwen/Qwen2.5-0.5B \
--out_dir /home/ubuntu/out/finetune/longcontext_lora \
--data LongBenchV2 \
--data.max_seq_length 100000 \
--data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data \
--head_model seq_classification_on_logits \
--precision bf16-true \
--verbose some \
--kv_cache.name h2o-torch-quantized8 \
--kv_cache.cache_length 16384 \
--kv_cache.chunk_size 1024 \
--train.save_interval 10 \
--train.micro_batch_size 4 \
--eval.interval 10
```

What is happening here?

* `finetune_long_lora`: Default fine-tuning script for `LoRA`
* `--out_dir`: Path for results. For example, checkpoints are written to
directories `step-000010`, `step-000020`, ... below this path (due to
`--train.save_interval 10`, checkpoints are written every 10 iterations).
* `--data LongBenchV2`: Using the `LongBenchV2` benchmark with its data loaders.
`--data.max_seq_length 100000` filters for sequences less than 100k tokens.
`--data.metadata_dir` stores metadata information about the dataset, so this
Expand All @@ -152,13 +165,13 @@ What is happening here?
which case we use gradient averaging.

If you use an AWS `p4d.24xlarge` instance, you can use 8 A100 GPUs in parallel.
At present, we support data parallelism via
[Lightning Fabric](https://lightning.ai/docs/fabric/stable/). Modifying the
CLI command above like runs training with an effective batch size of 32:
Modifying the CLI command above as follows runs training with an effective
batch size of 32:

```bash
cd ${KEYS_VALUES_PATH}
python3 keys_values/__main__.py finetune_long_lora Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
python3 keys_values/__main__.py finetune_long_lora \
Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/longcontext_lora --devices 8 --data LongBenchV2 --data.max_seq_length 100000 --data.metadata_dir /home/ubuntu/out/finetune/longcontext_lora/data --head_model seq_classification_on_logits --precision bf16-true --verbose some --kv_cache.name h2o-default --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --train.save_interval 10 --train.micro_batch_size 4 --eval.interval 10
```

Here, `--devices 8 --train.micro_batch_size 4` sets `train.global_batch_size`
Expand All @@ -169,7 +182,10 @@ to 32, the per-device batch size to 4, and asks to use 8 devices.
* Try increasing `kv_cache.cache_length` and `kv_cache.chunk_size`. They have
the [largest impact on speed and accuracy](#cache-length-and-chunk-size).
* Play around with different [cache policies](#kv-cache-policy-and-configuration),
or try to use buffer quantization (both by `kv_cache.name`).
or try to use buffer quantization (both by `kv_cache.name`). For example,
`--kv_cache.name h2o-torch-quantized8` halves the amount of GPU memory
required for KV cache buffers and may even run faster (our code offloads
KV cache buffers to CPU, which is faster when the buffers are smaller).
* Play around with different datasets. `--data Helmet` gives access to datasets
from the Helmet benchmark.
* Try using `finetune_offload_lora` instead of `finetune_long_lora`, and
Expand All @@ -196,14 +212,14 @@ for fast scaled dot product attention (SDPA).

Having said that, we are aware that this is not competitive with leading
inference libraries, such as [vLLM](https://github.com/vllm-project/vllm) or
[SGLang](https://github.com/sgl-project/sglang). Our library lacks support
[SGLang](https://github.com/sgl-project/sglang). Our library currently lacks support
for multi-device strategies (context parallelism in particular) as well as
many crucial optimizations.

We provide better support for advanced KV cache strategies like
[Heavy Hitter Oracle](https://arxiv.org/abs/2306.14048) than vLLM. One reason
why sparse attention techniques like H2O are used less often than they deserve,
is that they run slowly due to poor support of low-level SDPA kernels. We provide
is that they run slowly due to poor support from low-level SDPA kernels. We provide
a modification of the [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
kernels with which H2O becomes competitive. Stay tuned for more efforts in this
direction.
Expand All @@ -222,8 +238,10 @@ being able to run inference with long contexts without having to spend a lot
of money on many GPUs, and we think that advanced selective KV cache policies
are an important direction towards this goal.

A script for evaluating fine-tuned models on long context test data is provided
in [finetune/longcontext_eval.py](./keys_values/finetune/longcontext_eval.py).
Scripts for evaluating fine-tuned models on long context test data are provided
in [finetune/longcontext_eval.py](./keys_values/finetune/longcontext_eval.py) and
[finetune/longcontext_eval_ext.py](./keys_values/finetune/longcontext_eval_ext.py).
More details are given [below](#evaluation-of-fine-tuned-models).


## Long Context Fine-tuning
Expand Down Expand Up @@ -1024,6 +1042,154 @@ For a healthy run, you should see:
In particular, GPU memory should not build up across several snapshots


## Evaluation of Fine-tuned Models

Our library provides scripts to evaluate fine-tuned models on test datasets.
During fine-tuning, a metric is evaluated on a validation set, but this is
usually just a part of the development set (which is split into training and
validation sets). In general, we also need to compute metrics which differ
from the loss that drives training. Some terminology:

* A **setup** is given by a base model, configuration, and dataset. The
dataset consists of a development and a test set. For fine-tuning, the
development set is typically split into training and validation set. The
model is fine-tuned on the training set, while a validation metric is
periodically computed on the validation set (every `--eval.interval`
iterations). Moreover, **checkpoints** are stored periodically (every
`--train.save_interval` iterations). Use the validation metric values for
early stopping, or to decide which checkpoints to use for test set
evaluation.
* A **task** is a tuple of setup and checkpoint. For each evaluation metric,
the goal is to compute one value per task.
* The test dataset for a setup is partitioned into batches (these are
micro-batches in the naming used above). A **job** is a tuple `(task, batch)`,
and the evaluation scripts iterate over all jobs. They can be run on any
number of devices in parallel, with jobs assigned on a first-come-first-served
basis. The outcome for a job is a CSV file containing the metric values for
data cases in a batch. These can be aggregated into metric values over the
whole test set.

The following scripts can be used for evaluation:

* [longcontext_eval](./keys_values/finetune/longcontext_eval.py): Short name
`eval_long`. Runs evaluation for a single setup.
* [longcontext_eval_ext](./keys_values/finetune/longcontext_eval_ext.py): Short
name `eval_long_ext`. Runs evaluation for several setups, each with its own tasks.

### Evaluation for Single Setup: `eval_long`

Example:
```bash
python keys_values/__main__.py eval_long \
/home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5 \
--model_type lora \
--verbose some \
--devices 2 \
--batch_size 2 \
--use_sample_metric True \
--sample_metric_max_generated_tokens 20 \
--tasks "step-000310,final,step-000410"
```

* `/home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5` is the
`--out_dir` path passed to the training run for the setup.
* `--model_type`: Can be "lora" or "full".
* `--devices`: How many devices should the evaluation script use?
* `--batch_size`: Micro batch size for evaluation. Overrides
`eval.micro_batch_size` from the configuration of the setup.
* `--use_sample_metric`: Some datasets define a sample-based evaluation metric.
If `True`, this one is computed. Otherwise, the training loss function is
computed (but on the test set).
* `--tasks`: Names of tasks (or checkpoints) for which evaluation is to run. If
this is not given, the script runs evaluation for all checkpoints detected
under the `out_dir`.

Note that dataset and configurations are taken from the hyperparameters stored
with checkpoints (these must be the same for all checkpoints). Some of them can
be overwritten:

* `--kv_cache.*`: [KVCacheArgs](./keys_values/finetune/args.py#L51). Allows
using a different KV cache policy or different parameters for evaluation than
what has been used for fine-tuning.
* `--sdpa.*`: [SDPAArgs](./keys_values/finetune/args.py#L555). Allows
using a different SDPA kernel or different parameters for evaluation than
what has been used for fine-tuning.
* `--lora_dropout`: Overwrites `lora.dropout`.

The evaluation script works like this:

* On each device, a list of all jobs (i.e., tuples `(task, batch)`) is created.
* These jobs are worked on in parallel, on a first-come-first-served basis. The
outcome for a job is a file `<out_dir>/<task>/eval/eval_metrics_<no>.csv`, a
CSV file with one row per case in a batch. Here, `<no>` is the index of the
first case in the batch. For our example above, this could be
`.../h2o_lr5/step-000310/eval_metrics_256.csv`.
* Jobs are iterated over in a nested loop, tasks in outer, batches in inner loop.
* A worker locks a job by writing the result file, but with bogus content. Once
the job is finished, this content is overwritten by the results.
* Whenever a worker switches to a new task, the respective checkpoint is loaded
there.
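
The locking scheme described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: `LOCK_TEXT`, `iter_jobs`, `try_claim`, and `run_worker` are hypothetical names, the value of the real lock marker is not shown in this document, and claiming a file with the exclusive-create mode `"x"` is an assumption about how a worker could avoid races.

```python
from pathlib import Path
from typing import Callable, Iterator, List, Tuple

# Hypothetical placeholder; the library's actual lock marker is not shown here.
LOCK_TEXT = "LOCKED"


def iter_jobs(tasks: List[str], batch_starts: List[int]) -> Iterator[Tuple[str, int]]:
    """Nested loop over jobs: tasks in the outer loop, batches in the inner loop."""
    for task in tasks:
        for start in batch_starts:
            yield task, start


def try_claim(result_path: Path) -> bool:
    """Claim a job by creating its result file with placeholder content.

    Mode "x" raises FileExistsError if the file already exists, so at most
    one worker can claim a given job (an assumption made for this sketch).
    """
    result_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        with result_path.open("x") as fp:
            fp.write(LOCK_TEXT + "\n")
        return True
    except FileExistsError:
        return False


def run_worker(
    out_dir: Path,
    tasks: List[str],
    batch_starts: List[int],
    evaluate: Callable[[str, int], List[str]],
) -> int:
    """Process all jobs not yet claimed; returns the number of jobs done."""
    done = 0
    for task, start in iter_jobs(tasks, batch_starts):
        path = out_dir / task / "eval" / f"eval_metrics_{start}.csv"
        if not try_claim(path):
            continue  # Another worker owns (or has finished) this job
        rows = evaluate(task, start)  # CSV lines for the cases in this batch
        path.write_text("\n".join(rows) + "\n")  # Overwrite the placeholder
        done += 1
    return done
```

In this sketch, a second worker started on the same `out_dir` skips every job whose result file already exists, which is also why a left-over lock file prevents a job from being recomputed.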

Once an evaluation has finished, result files for all jobs have been written.
The script [collect_eval_results](./keys_values/scripts/collect_eval_results.py)
can be used to collect all results into a single CSV file. Currently, this script
has to be adapted to work for different setups. If a setup is stored at `out_dir`,
the outcome of this script is a file `<out_dir>/eval_metrics_all.csv`, which
collects all individual results. Moreover, the average evaluation metric per task
is printed for each task. The script also outputs the number of jobs which were
read for each task. If some of these numbers are too low, this may be due to lock
files which have not been removed properly after a worker failed. In this case,
clean up the lock files (see below) and run the evaluation script again: it will
compute only the missing jobs.
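
To make the aggregation step concrete, here is a minimal sketch of averaging one metric over the result files of a single task while skipping lock files. It is not the code of `collect_eval_results`: the column name `metric`, the marker `LOCK_TEXT`, and the function names are hypothetical and chosen only for illustration.

```python
import csv
from pathlib import Path
from statistics import mean
from typing import Tuple

# Hypothetical placeholder; the library's actual lock marker is not shown here.
LOCK_TEXT = "LOCKED"


def is_lock_file(path: Path) -> bool:
    """A file whose first line starts with the lock marker is an unfinished job."""
    with path.open("r") as fp:
        return fp.readline().startswith(LOCK_TEXT)


def average_metric(eval_dir: Path, column: str = "metric") -> Tuple[int, float]:
    """Returns (number of finished jobs read, average of `column` over all rows)."""
    values = []
    num_jobs = 0
    for path in sorted(eval_dir.glob("eval_metrics_*.csv")):
        if is_lock_file(path):
            continue  # Skip jobs which are locked but not finished
        num_jobs += 1
        with path.open("r", newline="") as fp:
            values.extend(float(row[column]) for row in csv.DictReader(fp))
    return num_jobs, (mean(values) if values else float("nan"))
```

Comparing the returned job count against the expected number of batches is a quick way to spot tasks with missing results.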

When workers are stopped before they can finish all jobs, there are in general
left-over lock files. Simply restarting the evaluation risks that metrics are not
evaluated for these jobs. In such a case, you obtain average metric values which
can be wrong. Use the script [cleanup_evaluation](./keys_values/scripts/cleanup_evaluation.py)
in order to remove left-over lock files. Currently, this script has to be adapted
to work for different setups.
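
As an illustration of what such a cleanup involves, the sketch below removes left-over lock files from one directory of result files. It mirrors the first-line check used by `_filter_incomplete_files` in `keys_values/evaluation/tasks.py`, but it is not the code of `cleanup_evaluation`: `LOCK_TEXT`, `filter_files`, and `remove_lock_files` are hypothetical names, and the real marker value is not shown in this document.

```python
from pathlib import Path
from typing import Iterable, List

# Hypothetical placeholder; the library's actual lock marker is not shown here.
LOCK_TEXT = "LOCKED"


def filter_files(paths: Iterable[Path], return_incompletes: bool = False) -> List[Path]:
    """Keep finished result files, or only lock files if `return_incompletes`."""
    result = []
    for path in paths:
        with path.open("r") as fp:
            if fp.readline().startswith(LOCK_TEXT) == return_incompletes:
                result.append(path)
    return result


def remove_lock_files(eval_dir: Path) -> List[Path]:
    """Delete left-over lock files so that their jobs can be claimed again."""
    stale = filter_files(eval_dir.glob("eval_metrics_*.csv"), return_incompletes=True)
    for path in stale:
        path.unlink()
    return stale
```

A cleanup like this should presumably only run while no worker is active, since a lock file of a running job looks the same as a left-over one.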

### Evaluation for Several Setups: `eval_long_ext`

Example:
```bash
python keys_values/__main__.py eval_long_ext \
./test_eval.yaml \
--verbose some \
--devices 2 \
--batch_size 2 \
--use_sample_metric True \
--sample_metric_max_generated_tokens 20
```

Here, `test_eval.yaml` is a YAML file describing the setups and the tasks for each setup.
For example:
```yaml
- out_dir: /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_hotpot_qa_64k/h2o_lr5
  model_type: lora
  eval_tasks:
    - step-000450
    - step-000010
    - final
- out_dir: /home/ubuntu/out/finetune/lora/qwen3_4b/helmet_nq_64k/slr_lr5
  model_type: lora
  eval_tasks:
    - step-000260
    - step-000010
    - final
- out_dir: /home/ubuntu/out/finetune/full/qwen3_4b/helmet_hotpot_qa_32k/h2o_lr5
  model_type: full
  eval_tasks:
    - step-000420
    - step-000010
    - final
```

A setup entry can also contain `kv_cache` and `sdpa` fields, which are nested
dictionaries. If an entry does not contain an `eval_tasks` field, all
checkpoints found under its `out_dir` are used as tasks. Jobs are iterated over
in a nested loop: outer over setups, middle over tasks, inner over batches.
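
The iteration order can be sketched as below, with the parsed YAML represented as a list of dictionaries. This is only an illustration: `iter_setup_jobs` is a hypothetical name, and `find_checkpoints` and `batch_starts` are hypothetical callables standing in for however the script discovers checkpoints and partitions the test set.

```python
from typing import Callable, Dict, Iterator, List, Tuple


def iter_setup_jobs(
    setups: List[Dict],
    find_checkpoints: Callable[[str], List[str]],
    batch_starts: Callable[[str], List[int]],
) -> Iterator[Tuple[str, str, int]]:
    """Yields (out_dir, task, batch_start): setups outer, tasks middle, batches inner."""
    for setup in setups:
        out_dir = setup["out_dir"]
        tasks = setup.get("eval_tasks")
        if tasks is None:
            # No `eval_tasks` field: every checkpoint found is a task
            tasks = find_checkpoints(out_dir)
        for task in tasks:
            for start in batch_starts(out_dir):
                yield out_dir, task, start


# Example with two setups, the second one without an `eval_tasks` field
setups = [
    {"out_dir": "/tmp/setup_a", "model_type": "lora", "eval_tasks": ["final"]},
    {"out_dir": "/tmp/setup_b", "model_type": "full"},
]
jobs = list(
    iter_setup_jobs(
        setups,
        find_checkpoints=lambda out_dir: ["step-000010", "final"],
        batch_starts=lambda out_dir: [0, 2],
    )
)
```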


## Implementing New KV Cache Policies

Currently supported KV cache policies are detailed
Expand Down
4 changes: 3 additions & 1 deletion keys_values/__main__.py
Expand Up @@ -22,6 +22,7 @@
from litgpt.__main__ import PARSER_DATA as PARSER_DATA_LITGPT

from keys_values.finetune.longcontext_eval import setup as eval_long_fn
from keys_values.finetune.longcontext_eval_ext import setup as eval_long_ext_fn
from keys_values.finetune.longcontext_full import setup as finetune_long_full_fn
from keys_values.finetune.longcontext_lora import setup as finetune_long_lora_fn
from keys_values.finetune.longcon_offload_full import setup as finetune_offload_full_fn
Expand All @@ -35,9 +36,10 @@

PARSER_DATA = {
**PARSER_DATA_LITGPT,
"eval_long": eval_long_fn,
"eval_long_ext": eval_long_ext_fn,
"finetune_long_full": finetune_long_full_fn,
"finetune_long_lora": finetune_long_lora_fn,
"eval_long": eval_long_fn,
"finetune_offload_full": finetune_offload_full_fn,
"finetune_offload_lora": finetune_offload_lora_fn,
}
Expand Down
24 changes: 18 additions & 6 deletions keys_values/evaluation/tasks.py
Expand Up @@ -32,8 +32,6 @@
_REQUIRED_FILES = [
"hyperparameters.yaml",
"model_config.yaml",
"tokenizer.json",
"tokenizer_config.json",
]

REQUIRED_FILES = {
Expand All @@ -60,6 +58,8 @@ def __init__(
tasks: Optional[List[str]] = None,
collect_results: bool = False,
):
if isinstance(out_dir, str):
out_dir = Path(out_dir)
self._out_dir = out_dir
self.model_type = model_type
self._tasks = tasks.copy() if tasks is not None else None
Expand Down Expand Up @@ -125,27 +125,39 @@ def check_complete(task_path: Path, model_type: str) -> bool:
else:
return True

def eval_result_files(self) -> Iterable[Tuple[str, List[Path]]]:
def eval_result_files(
self,
return_incompletes: bool = False,
) -> Iterable[Tuple[str, List[Path]]]:
"""
Args:
return_incompletes: If `True`, only incomplete (lock) files are returned.
Defaults to `False`, so lock files are filtered out.
Yields:
`(task_name, result_file_paths)`, where `result_file_paths`
is list of paths of evaluation result files for this task name.
These files are filtered to not contain incomplete lock files.
But if `return_incompletes == True`, only incomplete files are
returned.

"""
for task_name in self._tasks:
result_file_paths = self._filter_incomplete_files(
(self._out_dir / task_name).glob(EVAL_METRICS_GLOB)
(self._out_dir / task_name).glob(EVAL_METRICS_GLOB),
return_incompletes=return_incompletes,
)
if result_file_paths:
yield task_name, result_file_paths

@staticmethod
def _filter_incomplete_files(paths: Iterable[Path]) -> List[Path]:
def _filter_incomplete_files(
paths: Iterable[Path],
return_incompletes: bool = False,
) -> List[Path]:
result = []
for path in paths:
with path.open("r") as fp:
if not fp.readline().startswith(FILE_LOCK_TEXT):
if fp.readline().startswith(FILE_LOCK_TEXT) == return_incompletes:
result.append(path)
return result

Expand Down