Skip to content

Commit d3186e9

Browse files
committed
Fix issue with filter_sdpa_kernels; refactor; update longcontext_eval
1 parent e176346 commit d3186e9

10 files changed

Lines changed: 123 additions & 117 deletions

keys_values/attention_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
from keys_values.utils import repeat_interleave, index_to_3d
2727

2828

29+
SDPA_KERNELS_BEST_ORDERING = [
30+
SDPBackend.FLASH_ATTENTION,
31+
SDPBackend.EFFICIENT_ATTENTION,
32+
SDPBackend.CUDNN_ATTENTION,
33+
SDPBackend.MATH,
34+
]
35+
36+
2937
def filter_sdpa_kernels(
3038
sdpa_kernels: List[SDPBackend],
3139
query: torch.Tensor,
@@ -40,7 +48,9 @@ def filter_sdpa_kernels(
4048
params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa)
4149
new_kernels = []
4250
for kernel in sdpa_kernels:
43-
if kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params):
51+
if not torch.cuda.is_available() and kernel != SDPBackend.MATH:
52+
continue
53+
elif kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params):
4454
continue
4555
elif (
4656
kernel == SDPBackend.EFFICIENT_ATTENTION

keys_values/finetune/longcon_offload_full.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,10 @@
4646
)
4747

4848
from keys_values.array_limit import TemporaryArrayLimit
49-
from keys_values.attention_utils import DEFAULT_TMP_ARRAY_LIMIT_GB
49+
from keys_values.attention_utils import (
50+
DEFAULT_TMP_ARRAY_LIMIT_GB,
51+
SDPA_KERNELS_BEST_ORDERING,
52+
)
5053
from keys_values.data import LongBenchV2, INPUT_IDS_NAME
5154
from keys_values.finetune.args import (
5255
EvalArgs,
@@ -384,12 +387,6 @@ def main(
384387

385388
os.makedirs(out_dir, exist_ok=True)
386389
# Order of preference for SDPA kernels
387-
sdpa_kernels = [
388-
SDPBackend.FLASH_ATTENTION,
389-
SDPBackend.EFFICIENT_ATTENTION,
390-
SDPBackend.CUDNN_ATTENTION,
391-
SDPBackend.MATH,
392-
]
393390
limit_gb = attention_forward_temp_size_gb
394391
if limit_gb is None:
395392
limit_gb = DEFAULT_TMP_ARRAY_LIMIT_GB
@@ -399,12 +396,12 @@ def main(
399396
name="attention_forward_temp_size_gb",
400397
)
401398
mha_kwargs = dict(
402-
sdpa_kernels=sdpa_kernels,
399+
sdpa_kernels=SDPA_KERNELS_BEST_ORDERING,
403400
tmp_array_limit_gb=tmp_array_limit_forward,
404401
pos_encoding=position_encoding_factory(config, do_yarn=yarn_rope),
405402
)
406403
if "sdpa_kernels" not in kv_cache.cache_kwargs:
407-
kv_cache.cache_kwargs["sdpa_kernels"] = sdpa_kernels
404+
kv_cache.cache_kwargs["sdpa_kernels"] = SDPA_KERNELS_BEST_ORDERING
408405
kv_cache.cache_kwargs["tmp_array_limit_gb"] = tmp_array_limit_forward
409406
kv_cache.cache_kwargs["pos_encoding"] = mha_kwargs["pos_encoding"]
410407
# We create the GPT model on the device, then copy. This is faster

keys_values/finetune/longcon_offload_lora.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,12 @@
2424

2525
import torch
2626
from torch.utils.data import DataLoader
27-
from torch.nn.attention import SDPBackend
2827
from torchmetrics import RunningMean
2928

3029
from litgpt.args import TrainArgs
3130
from litgpt.data import DataModule
3231
from litgpt.lora import Config, mark_only_lora_as_trainable, lora_filter
3332
from litgpt.prompts import save_prompt_style
34-
from litgpt.scripts.merge_lora import merge_lora
3533
from litgpt.tokenizer import Tokenizer
3634
from litgpt.utils import (
3735
CycleIterator,
@@ -46,7 +44,10 @@
4644
)
4745

4846
from keys_values.array_limit import TemporaryArrayLimit
49-
from keys_values.attention_utils import DEFAULT_TMP_ARRAY_LIMIT_GB
47+
from keys_values.attention_utils import (
48+
DEFAULT_TMP_ARRAY_LIMIT_GB,
49+
SDPA_KERNELS_BEST_ORDERING,
50+
)
5051
from keys_values.data import LongBenchV2, INPUT_IDS_NAME
5152
from keys_values.finetune.args import (
5253
EvalArgs,
@@ -415,12 +416,6 @@ def main(
415416

416417
os.makedirs(out_dir, exist_ok=True)
417418
# Order of preference for SDPA kernels
418-
sdpa_kernels = [
419-
SDPBackend.FLASH_ATTENTION,
420-
SDPBackend.EFFICIENT_ATTENTION,
421-
SDPBackend.CUDNN_ATTENTION,
422-
SDPBackend.MATH,
423-
]
424419
limit_gb = attention_forward_temp_size_gb
425420
if limit_gb is None:
426421
limit_gb = DEFAULT_TMP_ARRAY_LIMIT_GB
@@ -430,12 +425,12 @@ def main(
430425
name="attention_forward_temp_size_gb",
431426
)
432427
mha_kwargs = dict(
433-
sdpa_kernels=sdpa_kernels,
428+
sdpa_kernels=SDPA_KERNELS_BEST_ORDERING,
434429
tmp_array_limit_gb=tmp_array_limit_forward,
435430
pos_encoding=position_encoding_factory(config, do_yarn=yarn_rope),
436431
)
437432
if "sdpa_kernels" not in kv_cache.cache_kwargs:
438-
kv_cache.cache_kwargs["sdpa_kernels"] = sdpa_kernels
433+
kv_cache.cache_kwargs["sdpa_kernels"] = SDPA_KERNELS_BEST_ORDERING
439434
kv_cache.cache_kwargs["tmp_array_limit_gb"] = tmp_array_limit_forward
440435
kv_cache.cache_kwargs["pos_encoding"] = mha_kwargs["pos_encoding"]
441436
# We create the GPT model on the device, then copy. This is faster
@@ -600,11 +595,6 @@ def main(
600595
save_hyperparameters(setup, save_dir)
601596
if hasattr(data, "prompt_style"):
602597
save_prompt_style(data.prompt_style, save_dir)
603-
merge_lora(
604-
checkpoint_dir=save_dir,
605-
lora_fname=LORA_WEIGHTS_FNAME,
606-
pretrained_fname=LIT_MODEL_FNAME,
607-
)
608598

609599

610600
def fit(

keys_values/finetune/longcontext_eval.py

Lines changed: 66 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from lightning.fabric.plugins import BitsandbytesPrecision
2424
from lightning.fabric.strategies import DDPStrategy
2525
from torch.utils.data import DataLoader
26-
from torch.nn.attention import SDPBackend
2726
import yaml
2827

2928
from litgpt.config import Config as ConfigFull
@@ -39,34 +38,39 @@
3938
)
4039

4140
from keys_values.array_limit import TemporaryArrayLimit
42-
from keys_values.attention_utils import DEFAULT_TMP_ARRAY_LIMIT_GB
41+
from keys_values.attention_utils import (
42+
DEFAULT_TMP_ARRAY_LIMIT_GB,
43+
SDPA_KERNELS_BEST_ORDERING,
44+
)
4345
from keys_values.data import LongBenchV2, INPUT_IDS_NAME
4446
from keys_values.data.evaluation import (
4547
TASK_NAME,
4648
ORIG_IDX_NAME,
4749
EvaluationTasks,
4850
EvaluationWithTasksHelper,
4951
)
52+
from keys_values.finetune.args import KVCacheArgs
5053
from keys_values.finetune.batch_transform import BatchTransformFactory
5154
from keys_values.finetune.longcontext_full import (
5255
wrap_gpt_model,
5356
)
54-
from keys_values.utils import flush_io_streams
5557
from keys_values.finetune.utils import (
5658
LIT_MODEL_FNAME,
5759
HEAD_MODEL_FNAME,
5860
LORA_WEIGHTS_FNAME,
5961
LORA_WEIGHTS_FNAME_OLD,
6062
check_kv_cache,
63+
adapt_requires_grad,
6164
)
6265
from keys_values.head_model_factory import HeadModelFactory
63-
from keys_values.kvcache.utils import VerbosityLevels
66+
from keys_values.kvcache.utils import VerbosityLevels, fabric_precision_to_dtype
6467
from keys_values.long_context import (
6568
LongContextInferenceModel,
6669
)
67-
from keys_values.finetune.args import KVCacheArgs
6870
from keys_values.lora import GPT as GPTLoRA
6971
from keys_values.model import GPT as GPTFull
72+
from keys_values.pos_encoding import position_encoding_factory
73+
from keys_values.utils import flush_io_streams
7074

7175

7276
@dataclass(frozen=True)
@@ -96,19 +100,19 @@ def setup(
96100
batch_size: Optional[int] = None,
97101
kv_cache: KVCacheArgs = KVCacheArgs(
98102
name="h2o-torch-quantized8",
99-
cache_length=8192,
100-
layers_per_cell=1,
101-
chunk_size=256,
103+
cache_length=16384,
104+
chunk_size=1024,
102105
cache_kwargs={
103106
"replay_log_blocksize": 1024,
104107
"allocate_buffers": False,
105108
"max_num_ranges": 4,
106109
},
107110
randomize_chunk_sizes=False,
108-
single_tokens_for_targets=False,
109-
verbose=VerbosityLevels.SOME.value,
110-
attention_forward_temp_size_gb=4,
111+
allocate_buffers=False,
111112
),
113+
verbose: Optional[str] = None,
114+
attention_forward_temp_size_gb: Optional[float] = None,
115+
yarn_rope: bool = True,
112116
) -> None:
113117
"""Evaluate a range of model checkpoints on a test set
114118
@@ -121,8 +125,19 @@ def setup(
121125
seed: The random seed to use for reproducibility.
122126
access_token: Optional API token to access models with restrictions.
123127
batch_size: Size for test set batches
124-
kv_cache: Configuration for the KV caches. If not given, the configuration
125-
of the checkpoints is being used.
128+
kv_cache: Configuration for the KV caches. See
129+
``keys_values.finetune.args.KVCacheArgs`` for details. Defaults to
130+
H2O with PyTorch 8-bit quantization. Make sure to adjust
131+
`kv_cache.cache_length`.
132+
verbose: Verbosity level for logging outputs.
133+
attention_forward_temp_size_gb: Size of GPU memory buffers (in GB) used
134+
in naive SDPA. At present, naive SDPA is used with KV caches which
135+
require attention weights (e.g., H2O).
136+
yarn_rope: Should YaRN be used to adjust RoPE (position encoding) to the
137+
sequence length for each batch? Defaults to `True`. If `False`, RoPE is
138+
determined by the model configuration, and is static (no dependence
139+
on sequence length).
140+
TODO: Should be stored as hyperparameter and loaded with checkpoint!
126141
127142
"""
128143
# Collect evaluation tasks
@@ -163,6 +178,20 @@ def setup(
163178
kv_cache = KVCacheArgs(**hyp_pars["kv_cache"])
164179
check_kv_cache(kv_cache)
165180
check_valid_checkpoint_dir(checkpoint_dir)
181+
# Legacy arguments
182+
if verbose is None:
183+
if kv_cache.verbose is not None:
184+
verbose = kv_cache.verbose
185+
kv_cache.verbose = None
186+
else:
187+
verbose = VerbosityLevels.SOME.value
188+
verbose = VerbosityLevels(verbose)
189+
if attention_forward_temp_size_gb is None:
190+
if kv_cache.attention_forward_temp_size_gb is not None:
191+
attention_forward_temp_size_gb = kv_cache.attention_forward_temp_size_gb
192+
kv_cache.attention_forward_temp_size_gb = None
193+
else:
194+
attention_forward_temp_size_gb = 4
166195

167196
precision = hyp_pars["precision"] or get_default_supported_precision(training=True)
168197
if devices * num_nodes > 1:
@@ -189,6 +218,9 @@ def setup(
189218
model_config=model_config,
190219
eval_tasks=eval.tasks,
191220
devices=devices,
221+
verbose=verbose,
222+
attention_forward_temp_size_gb=attention_forward_temp_size_gb,
223+
yarn_rope=yarn_rope,
192224
)
193225

194226

@@ -204,6 +236,9 @@ def main(
204236
model_config: ModelConfiguration,
205237
eval_tasks: List[str],
206238
devices: int,
239+
verbose: VerbosityLevels,
240+
attention_forward_temp_size_gb: float,
241+
yarn_rope: bool,
207242
) -> None:
208243
tokenizer = Tokenizer(checkpoint_dir)
209244
# Test dataloader is over cross product of test dataset and evaluation
@@ -229,15 +264,6 @@ def main(
229264

230265
with fabric.init_module(empty_init=(fabric.world_size > 1)):
231266
# Order of preference for SDPA kernels
232-
sdpa_kernels = [
233-
SDPBackend.FLASH_ATTENTION,
234-
SDPBackend.EFFICIENT_ATTENTION,
235-
SDPBackend.CUDNN_ATTENTION,
236-
SDPBackend.MATH,
237-
]
238-
mha_kwargs = {"sdpa_kernels": sdpa_kernels}
239-
if "sdpa_kernels" not in kv_cache.cache_kwargs:
240-
kv_cache.cache_kwargs["sdpa_kernels"] = sdpa_kernels
241267
limit_gb = kv_cache.attention_forward_temp_size_gb
242268
if limit_gb is None:
243269
limit_gb = DEFAULT_TMP_ARRAY_LIMIT_GB
@@ -246,8 +272,17 @@ def main(
246272
init_val=limit_gb,
247273
name="attention_forward_temp_size_gb",
248274
)
249-
mha_kwargs["tmp_array_limit_gb"] = tmp_array_limit_forward
275+
mha_kwargs = {
276+
"sdpa_kernels": SDPA_KERNELS_BEST_ORDERING,
277+
"tmp_array_limit_gb": tmp_array_limit_forward,
278+
"pos_encoding": position_encoding_factory(
279+
model_config.config, do_yarn=yarn_rope,
280+
),
281+
}
282+
if "sdpa_kernels" not in kv_cache.cache_kwargs:
283+
kv_cache.cache_kwargs["sdpa_kernels"] = SDPA_KERNELS_BEST_ORDERING
250284
kv_cache.cache_kwargs["tmp_array_limit_gb"] = tmp_array_limit_forward
285+
kv_cache.cache_kwargs["pos_encoding"] = mha_kwargs["pos_encoding"]
251286
if model_type == "full":
252287
gpt_model = GPTFull(model_config.config, **mha_kwargs)
253288
else:
@@ -272,16 +307,20 @@ def main(
272307
data=data,
273308
**model_config.head_model_kwargs,
274309
)
310+
if model_type == "lora":
311+
mark_only_lora_as_trainable(gpt_model)
312+
adapt_requires_grad(gpt_model, head_model)
275313
model = wrap_gpt_model(
276-
fabric=fabric,
277314
gpt_model=gpt_model,
278315
head_model=head_model,
279316
kv_cache=kv_cache,
317+
grad=None,
318+
verbose=verbose,
319+
attention_backward_temp_size_gb=None,
280320
max_batch_size=batch_size,
281-
model_for_training=False,
321+
dtype=fabric_precision_to_dtype(fabric._precision.precision),
322+
fabric=fabric,
282323
)
283-
if model_type == "lora":
284-
mark_only_lora_as_trainable(model.gpt_model)
285324
model = fabric.setup_module(model)
286325

287326
# Load base model

0 commit comments

Comments
 (0)