# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from pathlib import Path
from typing import Optional

import torch
import torch._dynamo.config
import torch._inductor.config

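# torch.compile settings used below (a brief note, not exhaustive): enable dynamic-shape handling
# and CUDA graphs for repeated single-batch calls, keep Triton kernel names unique for easier
# profiling, and raise the Dynamo cache limit so many distinct sequence lengths don't exhaust it.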
torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

from sentencepiece import SentencePieceProcessor

from model import Transformer

try:
    import lm_eval
    lm_eval_available = True
except ImportError:
    lm_eval_available = False

from generate import _load_model, encode_tokens, model_forward

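# lm_eval's public API moved between releases 0.3 and 0.4; the block below is a small shim so the
# rest of this file can use eval_wrapper, get_task_dict and evaluate uniformly for either version.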
if lm_eval_available:
    try:  # lm_eval version 0.4
        from lm_eval.models.huggingface import HFLM as eval_wrapper
        from lm_eval.tasks import get_task_dict
        from lm_eval.evaluator import evaluate
    except ImportError:  # lm_eval version 0.3
        from lm_eval import base
        from lm_eval import tasks
        from lm_eval import evaluator
        eval_wrapper = base.BaseLM
        get_task_dict = tasks.get_task_dict
        evaluate = evaluator.evaluate


def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
    model: Transformer,
    prompt: torch.Tensor,
    max_new_tokens: int,
    max_seq_length: Optional[int] = None,
):
    """
    Sets up the model cache and does some bookkeeping calculations for prompt, input_pos and
    max_seq_length that are needed for prefill or model_forward.

    Args:
        model (Transformer): The model whose cache gets set up.
        prompt (torch.Tensor): Tensor of shape (T,) with indices of the prompt sequence.
        max_new_tokens (int): The desired maximum number of new tokens that can be generated.
        max_seq_length (Optional[int], optional): The maximum sequence length allowed.

    Returns:
        seq (torch.Tensor): The prompt, copied into a larger tensor of length T + max_new_tokens.
        input_pos (torch.Tensor): Tensor of integers in increasing order (the prompt positions).
        max_seq_length (int): The maximum sequence length allowed, updated based on the other arguments.
    """
    T = prompt.size(0)
    T_new = T + max_new_tokens
    if max_seq_length is None:
        max_seq_length = min(T_new, model.config.block_size)

    device, dtype = prompt.device, prompt.dtype
    # create an empty tensor of the expected final shape and fill in the current tokens
    empty = torch.empty(T_new, dtype=dtype, device=device)
    empty[:T] = prompt
    seq = empty
    input_pos = torch.arange(0, T, device=device)

    with torch.device(device):
        model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

    return seq, input_pos, max_seq_length

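# A sketch of how the helper above is typically used (it mirrors GPTFastEvalWrapper._model_call below):
#   seq, input_pos, max_seq_length = \
#       setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(model, prompt, max_new_tokens=1)
#   x = seq.index_select(0, input_pos).view(1, -1)
#   logits = model_forward(model, x, input_pos)
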
class GPTFastEvalWrapper(eval_wrapper):
    """
    A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
    """
    def __init__(
        self,
        model: Transformer,
        tokenizer,
        max_seq_length: Optional[int] = None,
    ):
        super().__init__()
        self._model = model
        self._tokenizer = tokenizer
        self._device = torch.device('cuda')
        self._max_seq_length = 2048 if max_seq_length is None else max_seq_length

    @property
    def eot_token_id(self):
        return self._tokenizer.eos_id()

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        encoded = encode_tokens(self._tokenizer,
            string, bos=True, device=self._device)
        # encoded is a pytorch tensor, but some internal logic in the
        # eval harness expects it to be a list instead
        # TODO: verify this for multi-batch as well
        encoded = encoded.tolist()
        return encoded

    def tok_decode(self, tokens):
        decoded = self._tokenizer.decode(tokens)
        return decoded

    def _model_call(self, inps):
        # TODO: make batches work
        inps = inps.squeeze(0)

        max_new_tokens = 1
        seq, input_pos, max_seq_length = \
            setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
                self._model,
                inps,
                max_new_tokens,
                self.max_length,
            )
        x = seq.index_select(0, input_pos).view(1, -1)
        logits = model_forward(self._model, x, input_pos)
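        # the eval harness consumes these per-position logits over the full vocabulary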
        return logits

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError('generation is not implemented for this eval wrapper')


@torch.no_grad()
def eval(
    model: Transformer,
    tokenizer,
    tasks: list = ["hellaswag"],
    limit: Optional[int] = None,
    max_seq_length: Optional[int] = None,
) -> dict:
    """
    Evaluates a language model on the specified tasks using the lm-evaluation-harness library.

    Args:
        model (Transformer): The pre-trained language model to evaluate.
        tokenizer: The tokenizer to use for encoding/decoding text.
        tasks (list): The names of the evaluation tasks to perform.
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.

    Returns:
        eval_results (dict): A dictionary of evaluation results for the specified task(s).
    """
    model_eval_wrapper = GPTFastEvalWrapper(
        model,
        tokenizer,
        max_seq_length,
    )

    try:
        lm_eval.tasks.initialize_tasks()
    except Exception:
        # tolerate lm_eval versions that do not provide initialize_tasks()
        pass

    if 'hendrycks_test' in tasks:
        tasks.remove('hendrycks_test')
        tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()]
    task_dict = get_task_dict(tasks)

    eval_results = evaluate(
        model_eval_wrapper,
        task_dict,
        limit=limit,
    )
    return eval_results


def main(
    checkpoint_path: Path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"),
    compile: bool = False,
    tasks: list = ["hellaswag"],
    limit: Optional[int] = None,
    max_seq_length: Optional[int] = None,
) -> None:
    """Evaluates a model on tasks from the `lm-evaluation-harness` library.

    Args:
        checkpoint_path (Path): The path to the model checkpoint file to load.
        compile (bool): Whether or not to compile the model for optimization.
        tasks (list): The name(s) of the evaluation tasks to perform.
        limit (Optional[int]): The maximum number of samples to evaluate (None for all available).
        max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
    """

    assert checkpoint_path.is_file(), checkpoint_path

    tokenizer_path = checkpoint_path.parent / "tokenizer.model"
    assert tokenizer_path.is_file(), tokenizer_path

    device = 'cuda'
    precision = torch.bfloat16

    print("Loading model ...")
    t0 = time.time()
    model = _load_model(checkpoint_path, device, precision, False)

    torch.cuda.synchronize()
    print(f"Time to load model: {time.time() - t0:.02f} seconds.")

    model.eval()

    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

    torch.manual_seed(1234)

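    # Optional compilation of the shared forward helper: "reduce-overhead" pairs with the CUDA graph
    # setting at the top of the file, and dynamic=True avoids recompiling for every prompt length.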
    if compile:
        global model_forward
        model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True)
        torch._inductor.config.coordinate_descent_tuning = True

    t1 = time.time()
    result = eval(
        model,
        tokenizer,
        tasks,
        limit,
        max_seq_length,
    )
    print(f"Time to run eval: {time.time() - t1:.02f} seconds.")
    print(f"For model {checkpoint_path}")
    for task, res in result["results"].items():
        print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Evaluate a model checkpoint with the lm-evaluation-harness.')

    parser.add_argument('--checkpoint_path', type=Path, default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth"), help='Model checkpoint path.')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["hellaswag"], help='List of lm-evaluation-harness tasks to evaluate, e.g. --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of samples to evaluate.')
    parser.add_argument('--max_seq_length', type=int, default=None, help='Maximum sequence length to evaluate.')

    args = parser.parse_args()
    main(
        Path(args.checkpoint_path), args.compile, args.tasks, args.limit, args.max_seq_length,
    )
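
# Example invocation (the script name and checkpoint path are illustrative; adjust to your setup):
#   python eval.py --checkpoint_path checkpoints/meta-llama/Llama-2-7b-chat-hf/lit_model.pth \
#       --tasks hellaswag --limit 100 --compile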