forked from hiyouga/EasyR1
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvllm_rollout_spmd.py
More file actions
226 lines (194 loc) · 9.94 KB
/
vllm_rollout_spmd.py
File metadata and controls
226 lines (194 loc) · 9.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import torch.distributed
from tensordict import TensorDict
from transformers import PreTrainedTokenizer
from vllm import LLM, RequestOutput, SamplingParams
from ...protocol import DataProto
from ...utils import torch_functional as VF
from ...utils.dataset import process_image
from ...utils.tokenizer import get_processor
from ...utils.torch_dtypes import PrecisionType
from .base import BaseRollout
from .config import RolloutConfig
def _repeat_interleave(value: Union[torch.Tensor, np.ndarray], repeats: int) -> Union[torch.Tensor, List[Any]]:
# repeat the elements, supports both tensor and numpy array
if isinstance(value, torch.Tensor):
return value.repeat_interleave(repeats, dim=0)
else:
return np.repeat(value, repeats, axis=0)
def _get_logit_bias(model_path: str, trust_remote_code: bool) -> Optional[Dict[int, float]]:
    """Build a logit bias that keeps vLLM from emitting the image token.

    Args:
        model_path: Model location forwarded to ``get_processor``.
        trust_remote_code: Forwarded to ``get_processor``.

    Returns:
        ``{image_token_id: -100}`` when a processor exposing ``image_token``
        is found, otherwise ``None``.
    """
    # TODO: add video token
    processor = get_processor(model_path, trust_remote_code=trust_remote_code)
    if processor is None or not hasattr(processor, "image_token"):
        return None

    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    return {image_token_id: -100}
def _process_multi_modal_data(multi_modal_data: Dict[str, Any], min_pixels: int, max_pixels: int) -> Dict[str, Any]:
    """Run every entry of ``multi_modal_data["images"]`` through ``process_image``.

    Args:
        multi_modal_data: Mapping that must contain an ``"images"`` iterable.
        min_pixels: Forwarded to ``process_image``.
        max_pixels: Forwarded to ``process_image``.

    Returns:
        A dict with a single ``"image"`` key holding the processed images in
        their original order.
    """
    # may convert image path to image object
    # TODO: add video
    return {
        "image": [
            process_image(raw_image, min_pixels=min_pixels, max_pixels=max_pixels)
            for raw_image in multi_modal_data["images"]
        ]
    }
class vLLMRollout(BaseRollout):
    """Rollout backend that generates responses with a vLLM ``LLM`` engine."""

    def __init__(self, model_path: str, config: RolloutConfig, tokenizer: PreTrainedTokenizer):
        """A vLLM rollout. It requires the model to be supported by vLLM.

        Args:
            model_path: Model location passed to ``LLM(model=...)`` and used to
                look up the processor for the image-token logit bias.
            config: Rollout configuration. Engine options and every field that
                is also a ``SamplingParams`` attribute are read from it.
            tokenizer: The task/model tokenizer; only ``pad_token_id`` is read.

        Raises:
            ValueError: If ``tensor_parallel_size`` exceeds the distributed
                world size, or if ``max_num_batched_tokens`` is smaller than
                ``prompt_length + response_length``.
        """
        super().__init__()
        self.rank = int(os.getenv("RANK", "0"))
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id
        if config.tensor_parallel_size > torch.distributed.get_world_size():
            raise ValueError("Tensor parallelism size should be less than world size.")

        if config.max_num_batched_tokens < config.prompt_length + config.response_length:
            raise ValueError("max_num_batched_tokens should be greater than prompt_length + response_length.")

        engine_kwargs = {}
        if config.limit_images:
            engine_kwargs["limit_mm_per_prompt"] = {"image": config.limit_images}
            # NOTE(review): the flattened source does not show whether this flag
            # sits inside or after the `limit_images` branch - confirm upstream.
            engine_kwargs["disable_mm_preprocessor_cache"] = True

        self.inference_engine = LLM(
            model=model_path,
            skip_tokenizer_init=False,
            trust_remote_code=config.trust_remote_code,
            # NOTE(review): "dummy" load format - presumably the real weights
            # are synced into the engine elsewhere; confirm against the caller.
            load_format="dummy",
            dtype=PrecisionType.to_str(PrecisionType.to_dtype(config.dtype)),
            seed=config.seed,
            max_model_len=config.max_model_len or config.prompt_length + config.response_length,
            distributed_executor_backend="external_launcher",
            tensor_parallel_size=config.tensor_parallel_size,
            gpu_memory_utilization=config.gpu_memory_utilization,
            max_num_batched_tokens=config.max_num_batched_tokens,
            disable_log_stats=config.disable_log_stats,
            enforce_eager=config.enforce_eager,
            disable_custom_all_reduce=True,
            enable_chunked_prefill=config.enable_chunked_prefill,
            enable_sleep_mode=True,
            **engine_kwargs,
        )

        # Offload vllm model to reduce peak memory usage
        self.inference_engine.sleep(level=1)

        sampling_kwargs = {
            "max_tokens": config.response_length,
            "detokenize": False,
            "logit_bias": _get_logit_bias(model_path, trust_remote_code=config.trust_remote_code),
        }
        # Copy over every config field that SamplingParams also defines.
        default_sampling_params = SamplingParams()
        for key in config.to_dict():
            if hasattr(default_sampling_params, key):
                sampling_kwargs[key] = getattr(config, key)

        print(f"Sampling params: {sampling_kwargs}.")
        self.sampling_params = SamplingParams(**sampling_kwargs)

    @contextmanager
    def update_sampling_params(self, **kwargs):
        """Temporarily override attributes of ``self.sampling_params``.

        Only keys that already exist as attributes of the current sampling
        params are applied; all other keys are ignored. The previous values
        are restored when the ``with`` block exits, including when the body
        raises, so a failed generation cannot leak overrides into later calls.

        Args:
            **kwargs: Candidate overrides, e.g. the batch ``meta_info``.
        """
        previous_values = {}
        try:
            for key, value in kwargs.items():
                if hasattr(self.sampling_params, key):
                    previous_values[key] = getattr(self.sampling_params, key)
                    setattr(self.sampling_params, key, value)

            yield
        finally:
            # roll back to previous sampling params
            for key, value in previous_values.items():
                setattr(self.sampling_params, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto) -> DataProto:
        """Generate responses for a batch of prompts with the vLLM engine.

        Args:
            prompts: Input batch. Reads ``input_ids``, ``attention_mask`` and
                ``position_ids`` from ``prompts.batch``; ``eos_token_id`` (plus
                ``min_pixels`` / ``max_pixels`` when multi-modal data is
                present) from ``prompts.meta_info``; and ``raw_prompt_ids`` /
                optional ``multi_modal_data`` from ``prompts.non_tensor_batch``.
                ``raw_prompt_ids`` is popped from ``non_tensor_batch`` in place.
                Any ``meta_info`` key matching a sampling-param attribute
                overrides it for this call only.

        Returns:
            A ``DataProto`` whose batch holds ``prompts``, ``responses``,
            ``input_ids`` (prompt + response), ``attention_mask``,
            ``response_mask`` and ``position_ids``. When ``n > 1`` every prompt
            row is repeated ``n`` times consecutively. The input's
            ``non_tensor_batch`` and ``meta_info`` objects are reused.

        Raises:
            RuntimeError: If the number of ``raw_prompt_ids`` differs from the
                tensor batch size.
        """
        if self.rank == 0:
            print("[Rollout] Start generating sequences.")

        # left-padded attention_mask
        input_ids: torch.Tensor = prompts.batch["input_ids"]  # (bs, prompt_length)
        attention_mask: torch.Tensor = prompts.batch["attention_mask"]
        position_ids: torch.Tensor = prompts.batch["position_ids"]
        eos_token_id: int = prompts.meta_info["eos_token_id"]
        batch_size = input_ids.size(0)

        non_tensor_batch = prompts.non_tensor_batch
        batch_raw_prompt_ids = non_tensor_batch.pop("raw_prompt_ids")
        batch_multi_modal_data = non_tensor_batch.get("multi_modal_data")  # do not pop it
        if batch_size != len(batch_raw_prompt_ids):
            raise RuntimeError("vllm sharding manager is not work properly.")

        if batch_multi_modal_data is not None:
            min_pixels, max_pixels = prompts.meta_info["min_pixels"], prompts.meta_info["max_pixels"]
            vllm_inputs = [
                {
                    "prompt_token_ids": list(raw_prompt_ids),
                    "multi_modal_data": _process_multi_modal_data(multi_modal_data, min_pixels, max_pixels),
                }
                for raw_prompt_ids, multi_modal_data in zip(batch_raw_prompt_ids, batch_multi_modal_data)
            ]
        else:
            vllm_inputs = [{"prompt_token_ids": list(raw_prompt_ids)} for raw_prompt_ids in batch_raw_prompt_ids]

        # users can customize different sampling_params at different run
        with self.update_sampling_params(**prompts.meta_info):
            completions: List[RequestOutput] = self.inference_engine.generate(
                prompts=vllm_inputs, sampling_params=self.sampling_params, use_tqdm=False
            )
            response_ids = [output.token_ids for completion in completions for output in completion.outputs]
            response_ids = VF.pad_2d_list_to_length(
                response_ids, self.pad_token_id, max_length=self.config.response_length
            ).to(input_ids.device)

            # Must stay inside the `with`: `n` may be overridden per call.
            num_samples = self.sampling_params.n
            if num_samples > 1:
                batch_size = batch_size * num_samples
                input_ids = _repeat_interleave(input_ids, num_samples)
                attention_mask = _repeat_interleave(attention_mask, num_samples)
                position_ids = _repeat_interleave(position_ids, num_samples)
                if "multi_modal_data" in non_tensor_batch:
                    non_tensor_batch["multi_modal_data"] = _repeat_interleave(
                        non_tensor_batch["multi_modal_data"], num_samples
                    )

        sequence_ids = torch.cat([input_ids, response_ids], dim=-1)
        response_length = response_ids.size(1)
        delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
        delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
        if position_ids.dim() == 3:  # qwen2vl mrope
            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)

        # prompt: left pad + response: right pad
        # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
        # position_ids:   [0,0,0,0,0,1,2,3 | 4,5,6,7,8,9,10,11]
        response_position_ids = position_ids[..., -1:] + delta_position_id
        position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
        response_mask = VF.get_response_mask(
            response_ids=response_ids, eos_token_id=eos_token_id, dtype=attention_mask.dtype
        )
        attention_mask = torch.cat((attention_mask, response_mask), dim=-1)

        # all the tp ranks should contain the same data here. data in all ranks are valid
        batch = TensorDict(
            {
                "prompts": input_ids,
                "responses": response_ids,
                "input_ids": sequence_ids,  # here input_ids become the whole sentences
                "attention_mask": attention_mask,
                "response_mask": response_mask,
                "position_ids": position_ids,
            },
            batch_size=batch_size,
        )
        if self.rank == 0:
            print("[Rollout] Finish generating sequences.")

        return DataProto(batch=batch, non_tensor_batch=non_tensor_batch, meta_info=prompts.meta_info)