
Commit f9546ba

[ColossalEval] support for vllm (#6056)

* support vllm
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* modify vllm and update readme
* run pre-commit
* remove duplicated lines and refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update param name
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refine code
* update readme
* refine code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 4fa6b95 commit f9546ba

File tree: 19 files changed (+576, -35 lines)


applications/ColossalEval/README.md

Lines changed: 40 additions & 9 deletions
@@ -154,7 +154,7 @@ inference_kwargs = {
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32
 }
 ```
@@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields:
 - `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated
 - `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None.
 - `language` (str, compulsory): The language for the subcategory.
-- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length.
+- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss of sentences or not if the dataset is a pretrain dataset. It is usually used for calculating perplexity when you want to evaluate a model with extended context length.
 - `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference.
 
 For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key `all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other datasets such as GAOKAO-bench contain different formats of questions and lack some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly.
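To make the renamed `calculate_overall_loss` flag concrete: when it is true, loss is computed over every token of the concatenated prompt-plus-target sequence rather than only over target tokens (see the `_get_input_ids_and_labels` change further down this page). A minimal, hypothetical sketch of that overall-loss/perplexity computation with a Hugging Face causal LM; the checkpoint path and helper name are illustrative, not ColossalEval code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def overall_loss(model, tokenizer, text: str) -> float:
    # Labels equal input_ids, so the cross-entropy loss is averaged over
    # ALL tokens of the sequence, not just a target answer span.
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, labels=inputs["input_ids"])
    return out.loss.item()

tokenizer = AutoTokenizer.from_pretrained("path/to/model")
model = AutoModelForCausalLM.from_pretrained("path/to/model")
loss = overall_loss(model, tokenizer, "A long pretrain-style passage ...")
perplexity = torch.exp(torch.tensor(loss)).item()  # perplexity = exp(mean loss)
```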
@@ -230,7 +230,7 @@ Example:
 In this step, you will configure your tokenizer and model arguments to infer on the given datasets.
 
 A config file consists of two parts.
-1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields.
+1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with the vLLM offline-inference `LLM` class. You can check all model classes in `colossal_eval/models/__init__.py`. If your model requires `trust_remote_code`, set it to true in the `tokenizer_kwargs` and `model_kwargs` fields.
 2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench and few-shot on dataset MMLU, CMMLU, AGIEval and GSM8K. If you want to enable few shot, set `few_shot` as true. You can check all dataset classes in `colossal_eval/dataset/__init__.py`.
 
 Once you have all config ready, the program will run inference on all the given datasets on all the given models.
@@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM
 }
 ```
 
-Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong.
+An example config using model class `vLLMModel` and dataset class `CMMLUDataset` can be:
+```json
+{
+    "model": [
+        {
+            "name": "model name",
+            "model_class": "vLLMModel",
+            "parameters": {
+                "path": "path to model",
+                "model_max_length": 2048,
+                "tokenizer_path": "",
+                "tokenizer_kwargs": {
+                    "trust_remote_code": true
+                },
+                "model_kwargs": {
+                    "trust_remote_code": true
+                },
+                "prompt_template": "plain",
+                "batch_size": 4
+            }
+        }
+    ],
+    "dataset": [
+        {
+            "name": "dataset name",
+            "dataset_class": "CMMLUDataset",
+            "debug": false,
+            "few_shot": true,
+            "path": "path to original dataset",
+            "save_path": "path to save converted dataset"
+        }
+    ]
+}
+```
+
+Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, `tokenizer_kwargs` contains the arguments used in `AutoTokenizer.from_pretrained()` and `model_kwargs` contains the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. For vLLM models, the `tokenizer_kwargs` and `model_kwargs` are loaded together in the `LLM` class. Set `few_shot` to true if you want to enable few-shot prompting for the dataset, and set `debug` to true if you want to verify whether your prompt is right or wrong.
 
 > For GSM8K dataset, you can set additional flags `load_train` or `load_reference` for dataset configuration as true and during the inference process, the program will calculate loss summation over all tokens for each data sample. During the evaluation process, you can use metric `loss_over_all_tokens` to calculate the overall loss and use it for data leakage evaluation.
 
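As a rough illustration of "loaded together in the `LLM` class": vLLM's offline-inference entry point takes both model and tokenizer options in a single constructor, unlike the separate `AutoTokenizer`/`AutoModel` calls used for Hugging Face models. A hedged sketch of that pattern; how `vLLMModel` actually forwards these kwargs is defined in `colossal_eval/models/vllm.py`, which this diff page does not display:

```python
from vllm import LLM, SamplingParams

tokenizer_kwargs = {"trust_remote_code": True}
model_kwargs = {"trust_remote_code": True}

# One constructor receives the merged options (assumption: later keys win).
llm = LLM(
    model="path to model",
    max_model_len=2048,
    **{**tokenizer_kwargs, **model_kwargs},
)

outputs = llm.generate(
    ["What is the capital of France?"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```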
@@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \
     --inference_save_path "path to save inference results"
 ```
 
-You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size.
+You should specify the path to the config file in `config`. You can run the script without specifying `load_dataset` if you already saved the converted dataset, or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate the data parallel size (currently not supported for `vLLMModel`).
 
 ### Evaluation
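The data parallel size presumably follows from dividing the launched world size by `--tp_size`; a tiny illustrative calculation (variable names are ours, not the script's):

```python
world_size = 4  # processes launched via torchrun --nproc_per_node=4
tp_size = 2     # --tp_size
assert world_size % tp_size == 0, "world size must be divisible by tp size"
dp_size = world_size // tp_size  # 4 // 2 -> 2 data-parallel groups
```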

@@ -530,10 +565,6 @@ class CustomizedModel(BaseModel):
 
 Once you have successfully added your own model, you can specify your model class in your inference config.
 
-## To do
-
-- [ ] Add visualization code for evaluation results on public dataset
-- [ ] Improve the way to label target tokens
 
 ## Citations
 
applications/ColossalEval/colossal_eval/dataset/agieval.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/ceval.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/cmmlu.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/colossalai.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
     "calculate_loss": False,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }

applications/ColossalEval/colossal_eval/dataset/cvalues.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
     "calculate_loss": False,
     "all_classes": ["A", "B"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/gaokaobench.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/gsm.py

Lines changed: 2 additions & 2 deletions
@@ -72,7 +72,7 @@
     "calculate_loss": True,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 256,
 }
 
@@ -114,7 +114,7 @@ def load(
         dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs)
 
         if forward_only:
-            dataset[split][subject]["inference_kwargs"]["pretrain"] = True
+            dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True
 
         if split == "test" and few_shot:
             dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data()

applications/ColossalEval/colossal_eval/dataset/longbench.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
     "calculate_loss": True,
     "all_classes": None,
     "language": "Chinese",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/mmlu.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     "calculate_loss": True,
     "all_classes": ["A", "B", "C", "D"],
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/mtbench.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
     "calculate_loss": False,
     "all_classes": None,
     "language": "English",
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 1024,
     "turns": 2,
 }

applications/ColossalEval/colossal_eval/dataset/safetybench_en.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
     "calculate_loss": False,
     "all_classes": ["A", "B", "C", "D"],
     "language": LANGUAGE,
-    "pretrain": False,
+    "calculate_overall_loss": False,
     "max_new_tokens": 32,
 }

applications/ColossalEval/colossal_eval/models/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 from .base import BaseModel
 from .chatglm import ChatGLM2Model, ChatGLMModel
 from .huggingface import HuggingFaceCausalLM, HuggingFaceModel
+from .vllm import vLLMModel
 
-__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"]
+__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"]
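With this export in place, the new class is importable next to the existing ones. A hypothetical construction sketch whose keyword arguments simply mirror the `parameters` block of the README config above; the real signature lives in `colossal_eval/models/vllm.py`, which is not shown on this page:

```python
from colossal_eval.models import vLLMModel

# Assumption: parameters map 1:1 from the JSON config's "parameters" block.
model = vLLMModel(
    path="path to model",
    model_max_length=2048,
    tokenizer_path="",
    tokenizer_kwargs={"trust_remote_code": True},
    model_kwargs={"trust_remote_code": True},
    prompt_template="plain",
    batch_size=4,
)
```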

applications/ColossalEval/colossal_eval/models/chatglm.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List
 
     @torch.no_grad()
     def get_loss(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
     ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
 
@@ -225,7 +225,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str
 
     @torch.no_grad()
     def get_loss(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False
     ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
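Callers migrating across this commit only need the keyword rename; a hypothetical before/after for any `ChatGLMModel`/`ChatGLM2Model` instance:

```python
# Before: losses = model.get_loss(batch_prompt, batch_target, pretrain=False)
# After the rename in this commit:
losses = model.get_loss(batch_prompt, batch_target, calculate_overall_loss=False)
```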

applications/ColossalEval/colossal_eval/models/huggingface.py

Lines changed: 19 additions & 9 deletions
@@ -105,6 +105,12 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw
         elif hasattr(self.tokenizer, "eod_id"):
             # Qwen has an eod token "<|endoftext|>".
             self.tokenizer.pad_token_id = self.tokenizer.eod_id
+        else:
+            self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.")
+            raise ValueError(
+                "The tokenizer does not have a pad_token_id, eos_token, or eod_id. "
+                "Please set pad_token_id manually."
+            )
 
     def _load_model(
         self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None
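If the new `ValueError` fires, the usual remedy is to give the tokenizer a pad token yourself before evaluation; a hedged sketch using standard Hugging Face calls (the checkpoint path is a placeholder):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/model", trust_remote_code=True)
if tokenizer.pad_token_id is None:
    # Reuse an existing special token as padding ...
    tokenizer.pad_token = tokenizer.eos_token  # only if an eos_token exists
    # ... or register a new one (then resize the model's embeddings):
    # tokenizer.add_special_tokens({"pad_token": "<pad>"})
```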
@@ -245,7 +251,7 @@ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[L
         return input_ids_list, labels_list, bytes_list
 
     def _get_input_ids_and_labels(
-        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
     ) -> Tuple[List[torch.LongTensor]]:
         """
         Get input_ids and labels for the given data.
@@ -258,7 +264,7 @@ def _get_input_ids_and_labels(
             Input_ids and labels for the given batch.
 
         """
-        if pretrain:
+        if calculate_overall_loss:
             batch = []
             # Concatenate prompt and target answers.
             # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space.
@@ -342,7 +348,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
         calculate_loss = inference_kwargs["calculate_loss"]
         classes = inference_kwargs["all_classes"]
         language = inference_kwargs["language"]
-        pretrain = inference_kwargs["pretrain"]
+        calculate_overall_loss = inference_kwargs["calculate_overall_loss"]
         max_new_tokens = inference_kwargs["max_new_tokens"]
         few_shot_data = inference_kwargs.get("few_shot_data", None)
 
@@ -384,12 +390,12 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
             self.logger.info("-" * 120)
             self.logger.info(batch_prompt[0] + batch_target[0][0])
 
-            if not pretrain:
+            if not calculate_overall_loss:
                 batch_decodes, scores = self.generate(batch_prompt, max_new_tokens)
 
             if calculate_loss:
                 batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss(
-                    batch_prompt, batch_target, pretrain
+                    batch_prompt, batch_target, calculate_overall_loss
                 )
 
             probs = []
@@ -409,7 +415,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d
             ]
 
             for j in range(len(batch)):
-                if not pretrain:
+                if not calculate_overall_loss:
                     if isinstance(batch[j]["output"], list):
                         batch[j]["output"].append(batch_decodes[j].strip())
                     else:
@@ -496,7 +502,9 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str
         return decoded_sequences, scores
 
     @torch.no_grad()
-    def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]:
+    def get_loss(
+        self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool
+    ) -> List[List[float]]:
         """
         Calculate loss only on target tokens.
 
@@ -513,13 +521,15 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr
         # We don't need to generate new tokens.
         # Target answer's length is usually << model_max_length, but we still call it in case.
         # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens.
-        if not pretrain:
+        if not calculate_overall_loss:
             batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]
 
         # Get the number of target answers for different questions
         batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
 
-        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain)
+        input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(
+            batch_prompt, batch_target, calculate_overall_loss
+        )
 
         # Because of multiple target answers, the final batch size may be greater than self.batch_size.
         # We will generate new batches.
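The two comments above describe re-batching after each prompt has been expanded with all of its target answers; a hedged, self-contained sketch of that chunking idea (names are illustrative, not the actual implementation):

```python
from typing import List, Tuple

def rebatch(pairs: List[Tuple[str, str]], batch_size: int) -> List[List[Tuple[str, str]]]:
    # Split the expanded (prompt, target) list into chunks of at most batch_size.
    return [pairs[i : i + batch_size] for i in range(0, len(pairs), batch_size)]

# 2 prompts x 3 target answers = 6 pairs -> with batch_size=4: chunks of 4 and 2
pairs = [(f"prompt{p}", t) for p in (1, 2) for t in ("A", "B", "C")]
print([len(b) for b in rebatch(pairs, batch_size=4)])  # [4, 2]
```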
