forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdefault.py
More file actions
210 lines (194 loc) · 8.59 KB
/
default.py
File metadata and controls
210 lines (194 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.encoder_cache import EncoderCache
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.mm.xdrope_utils import XDRopeState
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class DefaultModelState(ModelState):
    """Default per-model GPU worker state.

    Bundles the optional multimodal encoder machinery (encoder cache +
    encoder runner) together with the M-RoPE / XD-RoPE position bookkeeping,
    and produces the extra per-step model inputs ("positions",
    "inputs_embeds") and the attention metadata for a batch.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        model: nn.Module,
        encoder_cache: EncoderCache | None,
        device: torch.device,
    ):
        model_cfg = vllm_config.model_config
        sched_cfg = vllm_config.scheduler_config

        self.vllm_config = vllm_config
        self.model_config = model_cfg
        self.scheduler_config = sched_cfg
        self.model = model
        self.device = device
        # Multimodal support is signalled by the caller passing an encoder
        # cache; no cache means text-only operation.
        self.supports_mm_inputs = encoder_cache is not None
        self.max_model_len = model_cfg.max_model_len
        self.max_num_reqs = sched_cfg.max_num_seqs
        self.max_num_tokens = sched_cfg.max_num_batched_tokens
        self.inputs_embeds_size = model_cfg.get_inputs_embeds_size()
        self.dtype = model_cfg.dtype

        if self.supports_mm_inputs:
            assert encoder_cache is not None
            self.encoder_cache = encoder_cache
            self.encoder_runner = EncoderRunner(
                model=self.model,
                max_num_tokens=self.max_num_tokens,
                hidden_size=self.inputs_embeds_size,
                encoder_cache=encoder_cache,
                dtype=self.dtype,
                device=self.device,
            )

        self.uses_mrope = model_cfg.uses_mrope
        if self.uses_mrope:
            self.mrope_state = MRopeState(
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

        # XD-RoPE state is only allocated when the model declares a nonzero
        # XD-RoPE dimension; None otherwise.
        self.xdrope_state: XDRopeState | None = None
        if model_cfg.uses_xdrope_dim > 0:
            self.xdrope_state = XDRopeState(
                uses_xdrope_dim=model_cfg.uses_xdrope_dim,
                max_num_reqs=self.max_num_reqs,
                max_num_tokens=self.max_num_tokens,
                max_model_len=self.max_model_len,
                device=self.device,
            )

    def add_request(self, req_index: int, new_req_data: NewRequestData) -> None:
        """Precompute rotary-position state for a newly scheduled request.

        M-RoPE takes precedence over XD-RoPE; models using neither need no
        per-request setup here.
        """
        if not self.uses_mrope and self.xdrope_state is None:
            return
        prefill_ids = new_req_data.prefill_token_ids
        assert prefill_ids is not None
        if self.uses_mrope:
            # Pre-compute M-RoPE positions for the prefill tokens.
            self.mrope_state.init_prefill_mrope_positions(
                req_index,
                self.model,  # type: ignore
                prefill_ids,
                mm_features=new_req_data.mm_features,
            )
        else:
            # Pre-compute XD-RoPE positions for the prefill tokens.
            self.xdrope_state.init_prefill_xdrope_positions(
                req_index,
                self.model,  # type: ignore
                prefill_ids,
                mm_features=new_req_data.mm_features,
            )

    def apply_staged_writes(self) -> None:
        """Flush any staged rope-position writes to their backing buffers."""
        if self.uses_mrope:
            self.mrope_state.apply_staged_writes()
        elif self.xdrope_state is not None:
            self.xdrope_state.apply_staged_writes()

    def get_mm_embeddings(
        self,
        scheduled_encoder_inputs: dict[str, list[int]],
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> torch.Tensor:
        """Run the multimodal encoder as needed and build input embeddings.

        Returns the combined (text + multimodal) input embeddings, sliced to
        the CUDA-graph-padded token count of the batch.
        """
        hashes, kwargs_batch = self.encoder_runner.prepare_mm_inputs(
            scheduled_encoder_inputs
        )
        if kwargs_batch:
            # Run the encoder, then cache its outputs keyed by mm_hash so
            # later steps can reuse them.
            outputs = self.encoder_runner.execute_mm_encoder(kwargs_batch)
            self.encoder_cache.encoder_outputs.update(zip(hashes, outputs))

        idx_map = input_batch.idx_mapping_np
        mm_embeds, is_mm_embed = self.encoder_runner.gather_mm_embeddings(
            input_batch.req_ids,
            input_batch.num_tokens,
            input_batch.num_scheduled_tokens,
            input_batch.query_start_loc_np,
            req_states.prefill_len.np[idx_map],
            req_states.num_computed_prefill_tokens[idx_map],
        )

        # input_batch.input_ids may be padded for CUDA graphs; slice to the
        # unpadded length so it lines up with is_mm_embed (num_tokens).
        unpadded_ids = input_batch.input_ids[: input_batch.num_tokens]
        inputs_embeds = self.encoder_runner.get_inputs_embeds(
            unpadded_ids, mm_embeds, is_mm_embed
        )
        return inputs_embeds[: input_batch.num_tokens_after_padding]

    def prepare_inputs(
        self, input_batch: InputBatch, req_states: RequestState
    ) -> dict[str, torch.Tensor | None]:
        """Return extra model inputs for this batch.

        Empty for the common 1D-position case; otherwise a "positions"
        entry holding the M-RoPE or XD-RoPE positions, sliced to the padded
        token count.
        """
        padded_len = input_batch.num_tokens_after_padding
        if self.uses_mrope:
            self.mrope_state.prepare_mrope_positions(
                input_batch.idx_mapping,
                input_batch.query_start_loc,
                req_states.prefill_len.gpu,
                req_states.num_computed_tokens.gpu,
            )
            return {
                "positions": self.mrope_state.mrope_positions[:, :padded_len]
            }
        if self.xdrope_state is not None:
            self.xdrope_state.prepare_xdrope_positions(
                input_batch.idx_mapping,
                input_batch.query_start_loc,
                req_states.prefill_len.gpu,
                req_states.num_computed_tokens.gpu,
            )
            return {
                "positions": self.xdrope_state.xdrope_positions[:, :padded_len]
            }
        # Common case: plain 1D positions, nothing extra to pass.
        return {}

    def prepare_dummy_inputs(self, num_reqs: int, num_tokens: int) -> dict[str, Any]:
        """Build placeholder model inputs for warmup / CUDA-graph capture."""
        inputs: dict[str, Any] = {}
        if self.supports_mm_inputs:
            inputs["inputs_embeds"] = self.encoder_runner.inputs_embeds[:num_tokens]
        if self.uses_mrope:
            inputs["positions"] = self.mrope_state.mrope_positions[:, :num_tokens]
        elif self.xdrope_state is not None:
            inputs["positions"] = self.xdrope_state.xdrope_positions[:, :num_tokens]
        return inputs

    def prepare_attn(
        self,
        input_batch: InputBatch,
        cudagraph_mode: CUDAGraphMode,
        block_tables: tuple[torch.Tensor, ...],
        slot_mappings: torch.Tensor,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        for_capture: bool = False,
    ) -> dict[str, Any]:
        """Assemble attention metadata for the batch.

        Full CUDA graphs use the padded request/token counts (padding is
        handled by model_runner.prepare_attn); piecewise graphs and eager
        execution use the unpadded counts.
        """
        full_graph = cudagraph_mode == CUDAGraphMode.FULL
        num_reqs = (
            input_batch.num_reqs_after_padding if full_graph else input_batch.num_reqs
        )
        num_tokens = (
            input_batch.num_tokens_after_padding
            if full_graph
            else input_batch.num_tokens
        )

        return build_attn_metadata(
            attn_groups=attn_groups,
            num_reqs=num_reqs,
            num_tokens=num_tokens,
            query_start_loc_gpu=input_batch.query_start_loc,
            query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
            max_query_len=input_batch.num_scheduled_tokens.max().item(),
            seq_lens=input_batch.seq_lens,
            max_seq_len=self.max_model_len,
            block_tables=block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
            dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
        )