2323from lightning .fabric .plugins import BitsandbytesPrecision
2424from lightning .fabric .strategies import DDPStrategy
2525from torch .utils .data import DataLoader
26- from torch .nn .attention import SDPBackend
2726import yaml
2827
2928from litgpt .config import Config as ConfigFull
3938)
4039
4140from keys_values .array_limit import TemporaryArrayLimit
42- from keys_values .attention_utils import DEFAULT_TMP_ARRAY_LIMIT_GB
41+ from keys_values .attention_utils import (
42+ DEFAULT_TMP_ARRAY_LIMIT_GB ,
43+ SDPA_KERNELS_BEST_ORDERING ,
44+ )
4345from keys_values .data import LongBenchV2 , INPUT_IDS_NAME
4446from keys_values .data .evaluation import (
4547 TASK_NAME ,
4648 ORIG_IDX_NAME ,
4749 EvaluationTasks ,
4850 EvaluationWithTasksHelper ,
4951)
52+ from keys_values .finetune .args import KVCacheArgs
5053from keys_values .finetune .batch_transform import BatchTransformFactory
5154from keys_values .finetune .longcontext_full import (
5255 wrap_gpt_model ,
5356)
54- from keys_values .utils import flush_io_streams
5557from keys_values .finetune .utils import (
5658 LIT_MODEL_FNAME ,
5759 HEAD_MODEL_FNAME ,
5860 LORA_WEIGHTS_FNAME ,
5961 LORA_WEIGHTS_FNAME_OLD ,
6062 check_kv_cache ,
63+ adapt_requires_grad ,
6164)
6265from keys_values .head_model_factory import HeadModelFactory
63- from keys_values .kvcache .utils import VerbosityLevels
66+ from keys_values .kvcache .utils import VerbosityLevels , fabric_precision_to_dtype
6467from keys_values .long_context import (
6568 LongContextInferenceModel ,
6669)
67- from keys_values .finetune .args import KVCacheArgs
6870from keys_values .lora import GPT as GPTLoRA
6971from keys_values .model import GPT as GPTFull
72+ from keys_values .pos_encoding import position_encoding_factory
73+ from keys_values .utils import flush_io_streams
7074
7175
7276@dataclass (frozen = True )
@@ -96,19 +100,19 @@ def setup(
96100 batch_size : Optional [int ] = None ,
97101 kv_cache : KVCacheArgs = KVCacheArgs (
98102 name = "h2o-torch-quantized8" ,
99- cache_length = 8192 ,
100- layers_per_cell = 1 ,
101- chunk_size = 256 ,
103+ cache_length = 16384 ,
104+ chunk_size = 1024 ,
102105 cache_kwargs = {
103106 "replay_log_blocksize" : 1024 ,
104107 "allocate_buffers" : False ,
105108 "max_num_ranges" : 4 ,
106109 },
107110 randomize_chunk_sizes = False ,
108- single_tokens_for_targets = False ,
109- verbose = VerbosityLevels .SOME .value ,
110- attention_forward_temp_size_gb = 4 ,
111+ allocate_buffers = False ,
111112 ),
113+ verbose : Optional [str ] = None ,
114+ attention_forward_temp_size_gb : Optional [float ] = None ,
115+ yarn_rope : bool = True ,
112116) -> None :
113117 """Evaluate a range of model checkpoints on a test set
114118
@@ -121,8 +125,19 @@ def setup(
121125 seed: The random seed to use for reproducibility.
122126 access_token: Optional API token to access models with restrictions.
123127 batch_size: Size for test set batches
124- kv_cache: Configuration for the KV caches. If not given, the configuration
125- of the checkpoints is being used.
128+ kv_cache: Configuration for the KV caches. See
129+ ``keys_values.finetune.args.KVCacheArgs`` for details. Defaults to
130+ H2O with PyTorch 8-bit quantization. Make sure to adjust
131+ `kv_cache.cache_length`.
132+ verbose: Verbosity level for logging outputs.
133+ attention_forward_temp_size_gb: Size of GPU memory buffers (in GB) used
134+ in naive SDPA. At present, naive SDPA is used with KV caches which
135+ require attention weights (e.g., H2O).
136+ yarn_rope: Should YaRN be used to adjust RoPE (position encoding) to the
137+ sequence length for each batch? Defaults to `True`. If not, RoPE is
138+ determined by the model configuration, and is static (no dependence
139+ on sequence length).
140+ TODO: Should be stored as hyperparameter and loaded with checkpoint!
126141
127142 """
128143 # Collect evaluation tasks
@@ -163,6 +178,20 @@ def setup(
163178 kv_cache = KVCacheArgs (** hyp_pars ["kv_cache" ])
164179 check_kv_cache (kv_cache )
165180 check_valid_checkpoint_dir (checkpoint_dir )
181+ # Legacy arguments
182+ if verbose is None :
183+ if kv_cache .verbose is not None :
184+ verbose = kv_cache .verbose
185+ kv_cache .verbose = None
186+ else :
187+ verbose = VerbosityLevels .SOME .value
188+ verbose = VerbosityLevels (verbose )
189+ if attention_forward_temp_size_gb is None :
190+ if kv_cache .attention_forward_temp_size_gb is not None :
191+ attention_forward_temp_size_gb = kv_cache .attention_forward_temp_size_gb
192+ kv_cache .attention_forward_temp_size_gb = None
193+ else :
194+ attention_forward_temp_size_gb = 4
166195
167196 precision = hyp_pars ["precision" ] or get_default_supported_precision (training = True )
168197 if devices * num_nodes > 1 :
@@ -189,6 +218,9 @@ def setup(
189218 model_config = model_config ,
190219 eval_tasks = eval .tasks ,
191220 devices = devices ,
221+ verbose = verbose ,
222+ attention_forward_temp_size_gb = attention_forward_temp_size_gb ,
223+ yarn_rope = yarn_rope ,
192224 )
193225
194226
@@ -204,6 +236,9 @@ def main(
204236 model_config : ModelConfiguration ,
205237 eval_tasks : List [str ],
206238 devices : int ,
239+ verbose : VerbosityLevels ,
240+ attention_forward_temp_size_gb : float ,
241+ yarn_rope : bool ,
207242) -> None :
208243 tokenizer = Tokenizer (checkpoint_dir )
209244 # Test dataloader is over cross product of test dataset and evaluation
@@ -229,15 +264,6 @@ def main(
229264
230265 with fabric .init_module (empty_init = (fabric .world_size > 1 )):
231266 # Order of preference for SDPA kernels
232- sdpa_kernels = [
233- SDPBackend .FLASH_ATTENTION ,
234- SDPBackend .EFFICIENT_ATTENTION ,
235- SDPBackend .CUDNN_ATTENTION ,
236- SDPBackend .MATH ,
237- ]
238- mha_kwargs = {"sdpa_kernels" : sdpa_kernels }
239- if "sdpa_kernels" not in kv_cache .cache_kwargs :
240- kv_cache .cache_kwargs ["sdpa_kernels" ] = sdpa_kernels
241267 limit_gb = kv_cache .attention_forward_temp_size_gb
242268 if limit_gb is None :
243269 limit_gb = DEFAULT_TMP_ARRAY_LIMIT_GB
@@ -246,8 +272,17 @@ def main(
246272 init_val = limit_gb ,
247273 name = "attention_forward_temp_size_gb" ,
248274 )
249- mha_kwargs ["tmp_array_limit_gb" ] = tmp_array_limit_forward
275+ mha_kwargs = {
276+ "sdpa_kernels" : SDPA_KERNELS_BEST_ORDERING ,
277+ "tmp_array_limit_gb" : tmp_array_limit_forward ,
278+ "pos_encoding" : position_encoding_factory (
279+ model_config .config , do_yarn = yarn_rope ,
280+ ),
281+ }
282+ if "sdpa_kernels" not in kv_cache .cache_kwargs :
283+ kv_cache .cache_kwargs ["sdpa_kernels" ] = SDPA_KERNELS_BEST_ORDERING
250284 kv_cache .cache_kwargs ["tmp_array_limit_gb" ] = tmp_array_limit_forward
285+ kv_cache .cache_kwargs ["pos_encoding" ] = mha_kwargs ["pos_encoding" ]
251286 if model_type == "full" :
252287 gpt_model = GPTFull (model_config .config , ** mha_kwargs )
253288 else :
@@ -272,16 +307,20 @@ def main(
272307 data = data ,
273308 ** model_config .head_model_kwargs ,
274309 )
310+ if model_type == "lora" :
311+ mark_only_lora_as_trainable (gpt_model )
312+ adapt_requires_grad (gpt_model , head_model )
275313 model = wrap_gpt_model (
276- fabric = fabric ,
277314 gpt_model = gpt_model ,
278315 head_model = head_model ,
279316 kv_cache = kv_cache ,
317+ grad = None ,
318+ verbose = verbose ,
319+ attention_backward_temp_size_gb = None ,
280320 max_batch_size = batch_size ,
281- model_for_training = False ,
321+ dtype = fabric_precision_to_dtype (fabric ._precision .precision ),
322+ fabric = fabric ,
282323 )
283- if model_type == "lora" :
284- mark_only_lora_as_trainable (model .gpt_model )
285324 model = fabric .setup_module (model )
286325
287326 # Load base model
0 commit comments