forked from vllm-project/tpu-inference
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtpu_runner.py
More file actions
1845 lines (1611 loc) · 82.3 KB
/
tpu_runner.py
File metadata and controls
1845 lines (1611 loc) · 82.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import logging
import random
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
import jax
import jax.numpy as jnp
import jaxtyping
import numpy as np
import vllm.envs as vllm_envs
from flax import nnx
from jax._src.pallas.utils import next_power_of_2
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.forward_context import set_forward_context
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, KVConnectorOutput, LogprobsLists,
ModelRunnerOutput)
from vllm.v1.request import Request
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorModelRunnerMixin
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
import tpu_inference.envs as envs
from tpu_inference import utils as common_utils
from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
MESH_AXIS_NAMES_2D,
ShardingAxisName,
ShardingConfigManager)
from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
gather_logprobs, sample)
from tpu_inference.layers.jax.sample.sampling_metadata import \
TPUSupportedSamplingMetadata
from tpu_inference.logger import init_logger
from tpu_inference.models.common.interface import PoolerFunc
from tpu_inference.models.common.model_loader import get_model
from tpu_inference.models.jax.jax_intermediate_tensor import \
JaxIntermediateTensors
from tpu_inference.models.jax.utils.weight_utils import (
shard_put, transfer_state_with_mappings)
from tpu_inference.runner import utils as runner_utils
from tpu_inference.runner.compilation_manager import CompilationManager
from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
from tpu_inference.runner.kv_cache_manager import KVCacheManager
from tpu_inference.runner.lora_utils import LoraUtils
from tpu_inference.runner.multimodal_manager import MultiModalManager
from tpu_inference.runner.persistent_batch_manager import \
PersistentBatchManager
from tpu_inference.runner.speculative_decoding_manager import (
SpecDecodeMetadata, SpeculativeDecodingManager)
from tpu_inference.runner.structured_decoding_manager import \
StructuredDecodingManager
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
from tpu_inference.utils import (device_array, make_optimized_mesh,
time_function, to_jax_dtype, to_torch_dtype)
# Module-level logger obtained through the project's init_logger helper.
logger = init_logger(__name__)
# Only let ERROR-and-above records through for the "torchax.tensor" logger.
logging.getLogger("torchax.tensor").setLevel(logging.ERROR)
# NOTE(review): presumably a sentinel meaning "no valid token"; it is not
# referenced in the visible part of this file — confirm against its users.
INVALID_TOKEN_ID = -1
# Smallest output size. Also used in this file as the lower bound for the
# maximum request count and for the request / logits padding buckets.
MIN_NUM_SEQS = 8
class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
    """Holds asynchronous model output specifically from a TPU runner.

    This class acts as a wrapper around the standard ModelRunnerOutput. Its
    primary purpose is to hold references to data still on the TPU device
    (like the `next_tokens` JAX array) without blocking the main thread.

    The `get_output()` method is called to resolve these async results,
    triggering the JAX device-to-host (CPU) data transfer and populating
    the final `ModelRunnerOutput` object.
    """

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        next_tokens: jax.Array,
        num_reqs: int,
        discard_sampled_tokens_req_indices: list[int],
        logits_indices_selector: Optional[List[int]] = None,
    ):
        """Store the pending output and what is needed to finalize it.

        Args:
            model_runner_output: Output object whose `sampled_token_ids` is
                filled in by `get_output()`.
            next_tokens: Sampled tokens; copied to the host lazily.
            num_reqs: Number of leading entries of the (optionally
                reordered) tokens that are kept.
            discard_sampled_tokens_req_indices: Request indices whose
                sampled token is dropped from the result.
            logits_indices_selector: Optional index list applied to the host
                copy of `next_tokens` before it is truncated to `num_reqs`.
        """
        self._model_runner_output = model_runner_output
        self._next_tokens = next_tokens
        self._num_reqs = num_reqs
        self._discard_sampled_tokens_req_indices = discard_sampled_tokens_req_indices
        # The selector is optional (defaults to None), so annotate it as such.
        self.logits_indices_selector: Optional[
            List[int]] = logits_indices_selector

    def get_output(self) -> ModelRunnerOutput:
        """Copy the tokens to the host and return the completed output.

        Returns:
            The wrapped `ModelRunnerOutput`, with `sampled_token_ids` set to
            one single-token list per request (an empty list for discarded
            requests).
        """
        next_tokens_cpu = np.asarray(jax.device_get(self._next_tokens))
        if self.logits_indices_selector is not None:
            next_tokens_cpu = next_tokens_cpu[self.logits_indices_selector]
        # Shape (num_reqs,) -> (num_reqs, 1) so that tolist() yields one
        # single-element list per request.
        selected_token_ids = np.expand_dims(next_tokens_cpu[:self._num_reqs],
                                            1)
        valid_sampled_token_ids = selected_token_ids.tolist()
        for i in self._discard_sampled_tokens_req_indices:
            valid_sampled_token_ids[i].clear()
        self._model_runner_output.sampled_token_ids = valid_sampled_token_ids
        return self._model_runner_output
@dataclass
class AsyncPreResults:
    """Results of a previous step that are resolved on a later step.

    Read by `TPUModelRunner._modify_prev_results()`.
    """
    # Request ids, indexed by the request index stored in `request_seq_lens`.
    req_ids: list[str]
    # Sampled tokens; copied to the host with `jax.device_get` when resolved.
    next_tokens: jax.Array
    # Tuples unpacked as (request index, request state, <third value>); the
    # third value is not read in the visible part of this file.
    request_seq_lens: list[tuple[int, CachedRequestState, int]]
    # Request indices whose sampled token is dropped when resolving.
    discard_sampled_tokens_req_indices: list[int]
    # NOTE(review): presumably the mapping returned by
    # `_update_placeholder()` (request id -> index of its token in
    # `next_tokens`) — confirm at the construction site.
    placeholder_req_id_to_index: dict[str, int]
    # When set, used to reorder the host copy of `next_tokens`.
    logits_indices_selector: Optional[List[int]] = None
@dataclass
class ExecuteModelState:
    """Ephemeral cached state transferred between execute_model() and
    sample_tokens(), after execute_model() returns None."""
    # sample_tokens() reads every field below, clears the cached state and
    # forwards the values to _sample_from_logits().
    scheduler_output: "VllmSchedulerOutput"
    attn_metadata: AttentionMetadata
    input_ids: Optional[jax.Array]
    hidden_states: jax.Array
    # May be replaced by the structured-decoding result in sample_tokens()
    # when a grammar output is supplied.
    logits: jax.Array
    aux_hidden_states: Optional[jax.Array]
    spec_decode_metadata: Optional[SpecDecodeMetadata]
    kv_connector_output: Optional[KVConnectorOutput]
    logits_indices_selector: Optional[List[int]] = None
    padded_num_reqs: Optional[int] = None
@jax.jit(donate_argnums=(0, 1, 2))
def _substitute_placeholder_token(
input_ids: jax.Array, token_in_tpu_cur_input_indices: jax.Array,
token_in_tpu_pre_next_tokens_indices: jax.Array,
next_tokens: jax.Array, placeholder_num: int):
"""Substitute placeholder tokens from TPU for async scheduler
Padding for parallelisation of the substitute_placeholder_token_fn
[1, 3] => [1, 3, 0, 2, 4, 5, 6, 7, 8]
The reason for such a special padding instead of padding with -1 is:
An edge case when the end index needs to be updated and padding is required.
If we pad the array with -1, the _substitute_placeholder_token_fn will repeatedly update the end element with the original value
Although such a scenario is unlikely to happen in vLLM, it is best to eliminate any potential risks.
Args:
input_ids: possible input_ids size
token_in_tpu_cur_input_indices: replace holder idx in input_ids. Length the same to input_ids.
token_in_tpu_pre_next_tokens_indices: value idx in next_tokens. Length the same to input_ids.
next_tokens: next tokens on the TPU from previous step.
placeholder_num: number of placeholders. placeholder_num <= len(token_in_tpu_cur_input_indices)
Return:
input_ids after replace placeholder tokens
"""
assert input_ids.shape == token_in_tpu_cur_input_indices.shape == token_in_tpu_pre_next_tokens_indices.shape, \
f"Shape mismatch: input_ids and index arrays must have identical shapes due to precompilation assumptions. " \
f"Got: {input_ids.shape=}, {token_in_tpu_cur_input_indices.shape=}, {token_in_tpu_pre_next_tokens_indices.shape=}"
# updates the input_ids for all placeholders.
mask = jnp.arange(input_ids.shape[0]) < placeholder_num
new_token_values = next_tokens[token_in_tpu_pre_next_tokens_indices]
original_values = input_ids[token_in_tpu_cur_input_indices]
update_values = jnp.where(mask, new_token_values, original_values)
return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
def _jax_logprobs_to_lists(logprobs_tensors,
                           logits_indices_selector=None,
                           cu_num_generated_tokens=None):
    """Build a `LogprobsLists` from a logprobs tensor bundle.

    The three arrays of `logprobs_tensors` (`logprob_token_ids`, `logprobs`
    and `selected_token_ranks`) are turned into Python lists, optionally
    re-ordered row-wise, and wrapped into NumPy arrays.

    Args:
        logprobs_tensors: Object exposing the three arrays named above.
        logits_indices_selector: Optional row indices; when given, row `k`
            of every output is row `logits_indices_selector[k]` of the
            corresponding input.
        cu_num_generated_tokens: Passed through unchanged.

    Returns:
        The populated `LogprobsLists`.
    """
    token_id_rows = logprobs_tensors.logprob_token_ids.tolist()
    logprob_rows = logprobs_tensors.logprobs.tolist()
    rank_rows = logprobs_tensors.selected_token_ranks.tolist()

    if logits_indices_selector is not None:
        # Apply the same row selection to all three lists.
        token_id_rows, logprob_rows, rank_rows = (
            [rows[i] for i in logits_indices_selector]
            for rows in (token_id_rows, logprob_rows, rank_rows))

    return LogprobsLists(
        logprob_token_ids=np.asarray(token_id_rows),
        logprobs=np.asarray(logprob_rows),
        sampled_token_ranks=np.asarray(rank_rows),
        cu_num_generated_tokens=cu_num_generated_tokens,
    )
class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
def __init__(
    self,
    vllm_config: VllmConfig,
    devices: List[Any],
    rank: int = 0,
    is_first_rank: bool = True,
    is_last_rank: bool = True,
):
    """Set up configs, the device mesh, host buffers and helper managers.

    Args:
        vllm_config: Top-level config; its sub-configs are stored as
            attributes.
        devices: Devices used to build the mesh.
        rank: Stored as `self.rank`.
        is_first_rank: Stored as `self.is_first_rank`.
        is_last_rank: Stored as `self.is_last_rank`; the speculative and
            structured decoding managers are only created when it is true.
    """
    self.vllm_config = vllm_config

    # Direct handles on the individual sub-configs.
    self.model_config = vllm_config.model_config
    # TODO(jevinjiang): override block size based on RPA v3.
    self.cache_config = vllm_config.cache_config
    self.lora_config = vllm_config.lora_config
    self.load_config = vllm_config.load_config
    self.parallel_config = vllm_config.parallel_config
    self.scheduler_config = vllm_config.scheduler_config
    self.speculative_config = vllm_config.speculative_config
    self.observability_config = vllm_config.observability_config
    self.device_config = vllm_config.device_config

    self.devices = devices
    self.dtype = self.model_config.dtype

    # Context manager that is a no-op unless recompilation checking is on.
    if envs.VLLM_XLA_CHECK_RECOMPILATION:
        self.maybe_forbid_compile = runner_utils.ForbidCompile()
    else:
        self.maybe_forbid_compile = nullcontext()

    self.dp_size = self.vllm_config.sharding_config.total_dp_size
    self.rank = rank
    self.is_first_rank = is_first_rank
    self.is_last_rank = is_last_rank

    # The initialisation steps run in this fixed order.
    for init_step in (
            self._init_random,
            self._init_mesh,
            self._init_phased_profiling,
            self._init_mm,
            self._init_inputs,
            self._init_speculative_decoding,
    ):
        init_step()

    # Delegate functions to specific manager classes.
    self.compilation_manager = CompilationManager(self)
    if self.is_last_rank:
        self.speculative_decoding_manager = SpeculativeDecodingManager(
            self)
        self.structured_decoding_manager = StructuredDecodingManager(self)
    self.kv_cache_manager = KVCacheManager(self)
    self.mm_manager = MultiModalManager(self)
    self.persistent_batch_manager = PersistentBatchManager(
        self.requests, self.input_batch, self.encoder_cache,
        self.uses_mrope, self.model_config, self.is_last_rank)
    self.lora_utils = LoraUtils(self)

    # A cache dtype of "auto" falls back to the model dtype.
    configured_cache_dtype = self.cache_config.cache_dtype
    self.kv_cache_dtype = to_torch_dtype(
        self.dtype
        if configured_cache_dtype == "auto" else configured_cache_dtype)

    self._pre_async_results: AsyncPreResults | None = None
    self._substitute_placeholder_token_fn = _substitute_placeholder_token
    self.execute_model_state: ExecuteModelState | None = None
    self.kv_caches: list[jax.Array] = []
    self.layer_name_to_kvcache_index: dict[str, int] = {}

    # Generative model or pooling model select different computations.
    self.is_pooling_model: bool = self.model_config.runner_type == "pooling"
def _init_random(self):
if self.model_config.seed is None:
self.model_config.seed = 0
random.seed(self.model_config.seed)
np.random.seed(self.model_config.seed)
self.rng_key = jax.random.key(self.model_config.seed)
def _init_mesh(self) -> None:
    """Create `self.mesh`, choosing the builder from `envs.NEW_MODEL_DESIGN`."""
    # NOTE(wenxindongwork): The new MoE kernel expects a 2D mesh, so we need
    # to create a 2D mesh for now. We should make the new_model_mesh as the default
    # in the future.
    build_mesh = (self._create_new_model_mesh
                  if envs.NEW_MODEL_DESIGN else self._create_2d_mesh)
    self.mesh = build_mesh()
    logger.info(f"Init mesh | mesh={self.mesh}")
def _create_new_model_mesh(self) -> jax.sharding.Mesh:
    """Build the mesh over `MESH_AXIS_NAMES` for one or several slices.

    The slice count is read from `envs.NUM_SLICES`; a count of 1 uses the
    single-slice device layout, anything else the multi-slice one.
    """
    num_slices = envs.NUM_SLICES
    logger.info(f"Creating new model mesh | devices={len(self.devices)}, "
                f"num_slices={num_slices}")

    devices_array = (self._create_single_slice_mesh() if num_slices == 1
                     else self._create_multi_slice_mesh(num_slices))
    return jax.sharding.Mesh(devices_array, MESH_AXIS_NAMES)
def _create_single_slice_mesh(self) -> jax.Array:
sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
mesh_shape = (
sharding_strategy.model_dp_size,
sharding_strategy.attn_dp_size,
sharding_strategy.attn_dp_expert_size,
sharding_strategy.expert_size,
sharding_strategy.tp_size,
)
# Attempt to create a physically optimized mesh. Fall back to a simple
# logical reshape for non-power-of-two device counts (e.g., DP=6) to
# bypass strict physical topology constraints.
try:
return mesh_utils.create_device_mesh(
mesh_shape,
self.devices,
allow_split_physical_axes=True,
)
except (AssertionError, ValueError, RuntimeError) as e:
logger.warning(
"Physical mesh creation failed (shape=%s, devices=%d). "
"Falling back to logical reshape. Error: %s", mesh_shape,
len(self.devices), e)
return np.array(self.devices).reshape(mesh_shape)
def _create_multi_slice_mesh(self, num_slices: int) -> np.ndarray:
    """Arrange `self.devices` into a 5-axis device array across slices.

    Args:
        num_slices: Number of slices; the model data-parallel size is
            integer-divided by it to get the per-slice (ICI) size of the
            first axis.

    Returns:
        A NumPy array of devices (not a `jax.Array`). In the fallback path
        its shape is the element-wise product of the ICI and DCN shapes.
    """
    sharding_strategy: ShardingConfigManager = self.vllm_config.sharding_config
    dp_inner = sharding_strategy.model_dp_size // num_slices
    # Splits data parallelism across multiple slices.
    ici_mesh_shape = (
        dp_inner,
        sharding_strategy.attn_dp_size,
        sharding_strategy.attn_dp_expert_size,
        sharding_strategy.expert_size,
        sharding_strategy.tp_size,
    )
    # Only the first (data-parallel) axis spans slices.
    dcn_mesh_shape = (num_slices, 1, 1, 1, 1)
    # Attempt to create a physically optimized hybrid mesh (ICI + DCN).
    # Fall back to a logical reshape for non-power-of-two device counts
    # to bypass strict hardware topology constraints across slices.
    try:
        return mesh_utils.create_hybrid_device_mesh(
            mesh_shape=ici_mesh_shape,
            dcn_mesh_shape=dcn_mesh_shape,
            devices=self.devices,
            allow_split_physical_axes=True,
        )
    except (AssertionError, ValueError, RuntimeError) as e:
        logger.warning(
            "Hybrid physical mesh creation failed. Falling back to logical reshape. "
            "ICI shape: %s, DCN shape: %s, Error: %s", ici_mesh_shape,
            dcn_mesh_shape, e)
        return np.array(self.devices).reshape(
            tuple(i * d for i, d in zip(ici_mesh_shape, dcn_mesh_shape)))
def _create_2d_mesh(self) -> jax.sharding.Mesh:
    """Build the (model_dp, tp) mesh over `MESH_AXIS_NAMES_2D`.

    When the sharding config carries a non-empty `device_indexes`, the mesh
    is built with `jax.make_mesh`; otherwise `make_optimized_mesh` is used.
    """
    sharding_config = self.vllm_config.sharding_config
    mesh_shape = (sharding_config.model_dp_size, sharding_config.tp_size)

    requested_indexes = sharding_config.device_indexes
    if requested_indexes is not None and len(requested_indexes) > 0:
        mesh_factory = jax.make_mesh
    else:
        mesh_factory = make_optimized_mesh
    return mesh_factory(mesh_shape,
                        MESH_AXIS_NAMES_2D,
                        devices=self.devices)
def _init_phased_profiling(self) -> None:
    """Create the phase-based profiler when a profiling directory is set.

    The directory comes from `envs.PHASED_PROFILING_DIR`; when it is falsy,
    `self.phase_based_profiler` stays `None`.
    """
    profiling_dir = envs.PHASED_PROFILING_DIR
    self.phased_profiling_dir = profiling_dir
    self.phase_based_profiler = None
    if not profiling_dir:
        return
    self.phase_based_profiler = runner_utils.PhasedBasedProfiler(
        profiling_dir)
def _init_mm(self) -> None:
self.is_multimodal_model = None
self.uses_mrope = self.model_config.uses_mrope
self.supports_mm_inputs = True
def _init_speculative_decoding(self) -> None:
self.drafter = None
if self.speculative_config:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "eagle3":
self.drafter = Eagle3Proposer(self.vllm_config, self)
else:
raise NotImplementedError(
"Unsupported speculative decoding method: "
f"{self.speculative_config.method}")
self.rejection_sampler = RejectionSampler()
def _init_inputs(self) -> None:
    """Derive batch limits and padding buckets and allocate host buffers.

    Reads the model, cache and scheduler configs plus `self.dp_size`, then
    sets the size limits (`max_num_reqs`, `max_num_tokens`, ...), the
    token / request / logits padding lists, the request bookkeeping
    containers, the `InputBatch` and the NumPy staging arrays.
    """
    model_cfg = self.model_config
    sched_cfg = self.scheduler_config
    dp = self.dp_size

    # Sequence and KV-block geometry.
    self.sliding_window = model_cfg.get_sliding_window()
    self.block_size = self.cache_config.block_size
    self.max_model_len = model_cfg.max_model_len
    self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)

    # InputBatch needs to work with sampling tensors greater than padding
    # to avoid dynamic shapes. Also, avoid suboptimal alignment.
    # The total number of requests is dp_size * max_num_seqs
    self.max_num_reqs = max(dp * sched_cfg.max_num_seqs, MIN_NUM_SEQS)
    max_reqs = self.max_num_reqs

    # Extra, user-requested token sizes to add to the padding buckets.
    extra_token_sizes = self.vllm_config.additional_config.get(
        "compilation_sizes", [])

    # A cache dtype of "auto" falls back to the model dtype.
    cache_dtype_name = self.cache_config.cache_dtype
    if cache_dtype_name == "auto":
        cache_dtype_name = self.dtype
    kv_packing = common_utils.get_dtype_packing(
        to_jax_dtype(cache_dtype_name))

    generated_paddings = runner_utils.get_token_paddings(
        min_token_size=max(16, next_power_of_2(dp * kv_packing)),
        max_token_size=sched_cfg.max_num_batched_tokens * dp,
        padding_gap=vllm_envs.VLLM_TPU_BUCKET_PADDING_GAP)
    self.num_tokens_paddings = sorted(generated_paddings +
                                      extra_token_sizes)
    self.num_tokens_paddings_per_dp = [
        size // dp for size in self.num_tokens_paddings
    ]
    # In case `max_num_tokens < max(num_tokens_paddings)` use the actual
    # padded max value to pre-allocate data structures and pre-compile.
    self.max_num_tokens = self.num_tokens_paddings[-1]
    max_tokens = self.max_num_tokens

    # Request states.
    self.requests: dict[str, CachedRequestState] = {}
    # mm_hash -> encoder_output
    self.encoder_cache: dict[str, jax.Array] = {}

    self.input_batch = InputBatch(
        max_num_reqs=max_reqs,
        max_model_len=self.max_model_len,
        max_num_batched_tokens=max_tokens,
        pin_memory=False,
        vocab_size=self.model_config.get_vocab_size(),
        block_sizes=[self.block_size],
        is_spec_decode=bool(self.vllm_config.speculative_config),
    )

    # Per-token staging buffers.
    self.input_ids_cpu = np.zeros(max_tokens, dtype=np.int32)
    self.positions_cpu = np.zeros(max_tokens, dtype=np.int32)
    # Note: self.input_batch and self.block_tables_cpu are both initialized
    # with only 1 block_size. For hybrid kv cache, it will be re-init
    # in kv_cache_manager's maybe_reinitialize_input_batch.
    self.block_tables_cpu = [
        np.zeros((max_reqs, self.max_num_blocks_per_req), dtype=np.int32)
    ]

    # Per-request staging buffers.
    self.query_start_loc_cpu = np.zeros(max_reqs + dp, dtype=np.int32)
    self.seq_lens_cpu = np.zeros(max_reqs, dtype=np.int32)
    self.logits_indices_cpu = np.zeros(max_reqs, dtype=np.int32)

    # Range tensor with values [0 .. self.max_num_tokens - 1].
    # Used to initialize positions / context_lens / seq_lens
    # Keep in int64 to avoid overflow with long context
    self.arange_cpu = np.arange(max_tokens, dtype=np.int64)

    smallest_req_bucket = max(MIN_NUM_SEQS, next_power_of_2(dp))
    self.num_reqs_paddings = runner_utils.get_req_paddings(
        min_req_size=smallest_req_bucket, max_req_size=max_reqs)
    self.num_reqs_paddings_per_dp = [
        size // dp for size in self.num_reqs_paddings
    ]

    # Padding for logits. Without speculative decoding, each request has one position to select from.
    # With speculative decoding, each request has multiple positions to select from.
    if self.speculative_config:
        # Including bonus token
        max_logits_per_req = (
            self.speculative_config.num_speculative_tokens + 1)
        self.num_logits_paddings = runner_utils.get_token_paddings(
            min_token_size=MIN_NUM_SEQS,
            max_token_size=max_reqs * max_logits_per_req,
            padding_gap=0)
    else:
        self.num_logits_paddings = None

    # Sampling parameter buffers.
    self.temperatures_cpu = np.zeros(max_tokens, dtype=np.float32)
    self.top_ps_cpu = np.zeros(max_tokens, dtype=np.float32)
    self.top_ks_cpu = np.zeros(max_tokens, dtype=np.int32)

    # tensors for structured decoding
    self.vocab_size = self.model_config.get_vocab_size()
    # One int32 word per 32 vocabulary entries.
    self.grammar_bitmask_cpu = np.zeros(
        (max_reqs, cdiv(self.vocab_size, 32)),
        dtype=np.int32,
    )
    self.require_structured_out_cpu = np.zeros(
        (max_reqs, 1),
        dtype=np.bool_,
    )
    self.structured_decode_arange = np.arange(0, 32, dtype=np.int32)

    # multi-modal support
    # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
    # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
    # the modality of inputs. For text-only inputs, each dimension has
    # identical position IDs, making M-RoPE functionally equivalent to
    # 1D-RoPE.
    # See page 5 of https://arxiv.org/abs/2409.12191
    self.mrope_positions_cpu = np.zeros((3, max_tokens), dtype=np.int64)
def load_model(self):
    """Load the model through `get_model` and store its entry points.

    Also stores the optional multi-modal helper functions (`None` when not
    provided), loads the drafter when one exists, creates the sampling RNG
    params and decides `self.is_multimodal_model`.
    """
    loaded = get_model(
        self.vllm_config,
        self.rng_key,
        self.mesh,
    )
    (
        self.model_fn,
        self.compute_logits_fn,
        self.pooler_fn,
        self.combine_hidden_states_fn,
        multimodal_fns,
        self.state,
        self.lora_manager,
        self.model,
    ) = loaded

    # A missing multi-modal function table behaves like an empty one.
    mm_fns = multimodal_fns or {}
    self.precompile_vision_encoder_fn = mm_fns.get(
        "precompile_vision_encoder_fn")
    self.embed_multimodal_fn = mm_fns.get("embed_multimodal_fn")
    self.embed_input_ids_fn = mm_fns.get("embed_input_ids_fn")
    self.get_mrope_input_positions_fn = mm_fns.get(
        "get_mrope_input_positions_fn")

    if self.drafter is not None:
        logger.info("Loading drafter model...")
        self.drafter.load_model(self.state)

    sampling_rngs = nnx.Rngs(jax.random.key(self.model_config.seed))
    self.rng_params_for_sampling = sampling_rngs.params()

    # Multi-modal only if the config says so, an embedding function was
    # provided, and the HF config exposes `architectures`.
    self.is_multimodal_model = (
        self.model_config.is_multimodal_model
        and self.embed_multimodal_fn is not None
        and hasattr(self.model_config.hf_config, "architectures"))

    logger.info(f"Init model | "
                f"hbm={common_utils.hbm_usage_gb(self.devices)}GiB")
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
runner_type = self.model_config.runner_type
if runner_type == "generate":
return ("generate", )
if runner_type == "pooling":
return ("embed", )
assert False, f"unsupported runner type: {runner_type}"
def get_kv_cache_spec(self):
return self.kv_cache_manager.get_kv_cache_spec()
def initialize_kv_cache(self,
                        kv_cache_config: KVCacheConfig,
                        topology_order_id: int = 0) -> None:
    """Record the KV cache config and let the manager allocate the cache.

    Args:
        kv_cache_config: Stored as `self.kv_cache_config`; more than one
            cache group marks the cache as hybrid.
        topology_order_id: Stored as `self.topology_order_id`.
    """
    self.topology_order_id = topology_order_id
    self.kv_cache_config = kv_cache_config
    group_count = len(kv_cache_config.kv_cache_groups)
    self.use_hybrid_kvcache = group_count > 1
    self.kv_cache_manager.initialize_kv_cache(kv_cache_config)

    # Register with the KV transfer group only when one exists.
    if not has_kv_transfer_group():
        return
    get_kv_transfer_group().register_runner(self)
def delete_kv_cache(self) -> None:
self.kv_cache_manager.delete_kv_cache()
def reinitialize_kv_cache(self) -> None:
self.kv_cache_manager.reinitialize_kv_cache()
def capture_model(self) -> None:
self.compilation_manager.capture_model()
@time_function
def execute_model(
    self,
    scheduler_output: "VllmSchedulerOutput",
    intermediate_tensors: Optional[JaxIntermediateTensors] = None,
) -> ModelRunnerOutput | JaxIntermediateTensors | None:
    """Run `_execute_model` under the runner's mesh and a profiler trace.

    Args:
        scheduler_output: Scheduler decisions for this step.
        intermediate_tensors: Optional tensors, forwarded unchanged.

    Returns:
        Whatever `_execute_model` returns.

    Raises:
        RuntimeError: If a cached `execute_model_state` is still pending,
            i.e. `sample_tokens()` was not called after the previous step.
    """
    if self.execute_model_state is not None:
        raise RuntimeError("State error: sample_tokens() must be called "
                           "after execute_model() returns None.")

    # Sizes that label the trace annotation.
    reqs = self.input_batch.num_reqs
    toks = scheduler_output.total_num_scheduled_tokens
    with jax.set_mesh(self.mesh):
        with jax.profiler.TraceAnnotation(
                f"execute_model: {reqs} reqs, {toks} toks"):
            return self._execute_model(scheduler_output,
                                       intermediate_tensors)
def sample_tokens(
    self,
    grammar_output: "GrammarOutput | None",
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
    """Sample from the logits cached by the preceding `execute_model()`.

    The cached state is consumed (reset to `None`). When `grammar_output`
    is given, the logits first go through structured decoding.

    Args:
        grammar_output: Grammar information for structured decoding, or
            `None` to sample from the raw logits.

    Returns:
        The result of `_sample_from_logits`, or
        `EMPTY_MODEL_RUNNER_OUTPUT` when no state is cached.
    """
    state = self.execute_model_state
    if state is None:
        # This can happen in pipeline parallel case.
        return EMPTY_MODEL_RUNNER_OUTPUT
    # The state is single-use.
    self.execute_model_state = None

    logits = state.logits
    if grammar_output is not None:
        decoding_manager = self.structured_decoding_manager
        (require_struct_decoding, grammar_bitmask_padded,
         arange) = decoding_manager.prepare_structured_decoding_input(
             logits, grammar_output)
        logits = decoding_manager.structured_decode_fn(
            require_struct_decoding,
            grammar_bitmask_padded,
            logits,
            arange,
        )

    return self._sample_from_logits(
        state.scheduler_output,
        state.attn_metadata,
        state.input_ids,
        state.hidden_states,
        logits,
        state.aux_hidden_states,
        state.spec_decode_metadata,
        state.kv_connector_output,
        state.logits_indices_selector,
        state.padded_num_reqs,
    )
def _modify_prev_results(self):
# If copy to host has not been done, we just wait.
# device_get should return immediately as we have scheduled it in previous function call.
assert self._pre_async_results is not None, "When we call _modify_prev_results(), self._pre_async_results should already exist"
pre_req_ids = self._pre_async_results.req_ids
pre_next_tokens = self._pre_async_results.next_tokens
pre_request_seq_lens = self._pre_async_results.request_seq_lens
pre_discard_sampled_tokens_req_indices = self._pre_async_results.discard_sampled_tokens_req_indices
pre_logits_indices_selector = self._pre_async_results.logits_indices_selector
next_tokens_cpu = np.asarray(jax.device_get(pre_next_tokens))
if pre_logits_indices_selector is not None:
next_tokens_cpu = next_tokens_cpu[pre_logits_indices_selector]
selected_token_ids = np.expand_dims(next_tokens_cpu[:len(pre_req_ids)],
1)
valid_sampled_token_ids = selected_token_ids.tolist()
# Mask out the sampled tokens that should not be sampled.
for i in pre_discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
# Append sampled tokens
for pre_req_idx, req_state, _ in pre_request_seq_lens:
sampled_ids = valid_sampled_token_ids[pre_req_idx]
if not sampled_ids:
continue
# If request not active in the *current* batch (e.g. finished or evicted), skip it.
req_id = pre_req_ids[pre_req_idx]
if req_id not in self.input_batch.req_id_to_index:
continue
req_idx = self.input_batch.req_id_to_index[req_id]
assert req_state is self.requests[
req_id], "The req_state should be valid and identical"
# Updated on previous execute
end_idx = self.input_batch.num_tokens_no_spec[req_idx]
assert len(sampled_ids) == 1, "do not support spec decode yet"
start_idx = end_idx - 1
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
self.input_batch.token_ids_cpu[req_idx,
start_idx:end_idx] = sampled_ids
# Replace previous placeholder
req_state.output_token_ids[-1] = sampled_ids[-1]
def _update_placeholder(self,
discard_sampled_tokens_req_indices,
request_seq_lens,
logits_indices_selector=None):
placeholder_req_id_to_index: dict[str, int] = {}
discard_sampled_tokens_req_indices_set = set(
discard_sampled_tokens_req_indices)
for req_idx, req_state, _ in request_seq_lens:
if req_idx in discard_sampled_tokens_req_indices_set:
continue
start_idx = self.input_batch.num_tokens_no_spec[req_idx]
# Not supporting spec decode yet, assume only 1 new token
end_idx = start_idx + 1
assert end_idx <= self.max_model_len, (
"Sampled token IDs exceed the max model length. "
f"Total number of tokens: {end_idx} > max_model_len: "
f"{self.max_model_len}")
# Update cpu tokens at next execute and prepare input from tpu
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
self.input_batch.num_tokens[req_idx] = end_idx
# For placeholder, should be update on next execute.
req_state.output_token_ids.extend([0])
if logits_indices_selector is None:
placeholder_req_id_to_index[req_state.req_id] = req_idx
else:
placeholder_req_id_to_index[
req_state.req_id] = logits_indices_selector[req_idx]
return placeholder_req_id_to_index
def _execute_model(
    self,
    scheduler_output: "VllmSchedulerOutput",
    intermediate_tensors: Optional[JaxIntermediateTensors] = None,
) -> JaxIntermediateTensors | ModelRunnerOutput | None:
    """Run the model forward pass for one scheduler step and stage logits.

    The persistent batch is first synchronized with ``scheduler_output``.
    When tokens are scheduled, inputs are prepared, the model function is
    invoked, and the results needed for sampling are stored on
    ``self.execute_model_state``.

    Args:
        scheduler_output: The scheduler's decisions for this step.
        intermediate_tensors: Optional tensors forwarded unchanged to
            ``self.model_fn``.

    Returns:
        * The result of ``self.kv_connector_no_forward(...)`` when nothing
          is scheduled and a KV transfer group exists.
        * ``EMPTY_MODEL_RUNNER_OUTPUT`` when nothing is scheduled otherwise.
        * The model's ``JaxIntermediateTensors`` (with the KV connector
          output attached) when this is not the last rank.
        * A ``ModelRunnerOutput`` carrying ``pooler_output`` for pooling
          models.
        * ``None`` in the remaining case, after
          ``self.execute_model_state`` has been populated.
    """
    self.persistent_batch_manager.update_states(
        scheduler_output, self.get_mrope_input_positions_fn)
    if not scheduler_output.total_num_scheduled_tokens:
        if has_kv_transfer_group():
            return self.kv_connector_no_forward(scheduler_output,
                                                self.vllm_config)

        # Return empty ModelRunnerOutput if there's no work to do.
        # TODO(fhzhang): We rely on empty cycles to remove requests in input batch. Fix it to reduce overhead.
        # Lazy %-style arguments: the scheduler output is only rendered
        # when DEBUG logging is actually enabled.
        logger.debug("Nothing scheduled: %s!", scheduler_output)
        # NOTE(pooyam): There is no guarantee that scheduler is not sending empty output: https://github.com/vllm-project/vllm/blob/7cfea0df390c154c1026f77d3682e2733ca4aca8/vllm/v1/engine/core.py#L275
        # Why they are not preventing that is not clear to me.
        if not scheduler_output.finished_req_ids:
            # Deliberately a warning rather than an exception, because the
            # scheduler may send an empty output (see NOTE above).
            logger.warning(
                "Should not schedule a request that does nothing!")
        return EMPTY_MODEL_RUNNER_OUTPUT

    # TODO(pooyam): I guess we can remove returning sampling_metadata in `_prepare_inputs` after https://github.com/njhill/vllm/commit/b7433ca1a47732394b1bdea4099d98389515954b
    (
        input_ids,
        input_positions,
        attn_metadata,
        _,
        logits_indices,
        spec_decode_metadata,
        logits_indices_selector,
        padded_num_reqs,
    ) = self._prepare_inputs(scheduler_output)

    # multi-modal support
    if self.is_multimodal_model:
        # Run the multimodal encoder if any.
        # We have the modality embeds at this time.
        self.mm_manager.execute_mm_encoder(scheduler_output)
        mm_embeds = self.mm_manager.gather_mm_embeddings(
            scheduler_output, input_ids.shape[0])
    else:
        mm_embeds = []

    # NOTE(Wenlong): For multi-modal model,
    # it will embed the text tokens and merge with the existing modality embeds
    # Later, the multi-modality model will take the embedding as the input.
    # For text-only model, this does nothing. It will input the input_ids and
    # leave the embedding job inside the forward pass
    input_ids, inputs_embeds = self._get_input_ids_embeds(
        input_ids, mm_embeds)

    lora_metadata = self.lora_utils.extract_lora_metadata()
    # TODO: make _get_input_ids_embeds within this context
    # NOTE: right now, mm model will use embeddings as the input,
    # but text-only model will use input_ids
    with self.maybe_forbid_compile:

        with set_forward_context(
                None,
                self.vllm_config,
        ), self.maybe_get_kv_connector_output(
                scheduler_output) as kv_connector_output:
            # NOTE(Wenlong): It takes both `input_ids` and `inputs_embeds`,
            # but one of them would be `None`

            (self.kv_caches, hidden_states,
             aux_hidden_states) = self.model_fn(
                 self.state,
                 self.kv_caches,
                 input_ids,
                 attn_metadata,
                 inputs_embeds,
                 input_positions,
                 tuple(self.layer_name_to_kvcache_index.items()),
                 lora_metadata,
                 intermediate_tensors,
                 self.is_first_rank,
                 self.is_last_rank,
             )

        # Non-final ranks hand the intermediate tensors straight back to
        # the caller; no logits are computed on them.
        if not self.is_last_rank:
            assert isinstance(hidden_states, JaxIntermediateTensors)
            hidden_states.kv_connector_output = kv_connector_output
            return hidden_states

        # Pooling models return the pooler output directly and never
        # populate `self.execute_model_state`.
        if self.is_pooling_model:
            seq_lens = self.seq_lens_cpu[:self.input_batch.num_reqs]
            pooling_metadata = self.input_batch.get_pooling_metadata()
            pooler_fn: PoolerFunc = self.pooler_fn
            pooler_output = pooler_fn(
                hidden_states,
                pooling_metadata,
                seq_lens,
            )
            return ModelRunnerOutput(
                req_ids=self.input_batch.req_ids,
                req_id_to_index=self.input_batch.req_id_to_index,
                sampled_token_ids=[],
                logprobs=None,
                prompt_logprobs_dict={},
                pooler_output=pooler_output,
            )

        # Keep only the hidden states at `logits_indices` before
        # computing logits.
        hidden_states = self._select_from_array_fn(hidden_states,
                                                   logits_indices)
        logits = self.compute_logits_fn(
            self.state,
            hidden_states,
            lora_metadata,
        )

    # Stash everything a later sampling step needs from this forward pass.
    self.execute_model_state = ExecuteModelState(
        scheduler_output=scheduler_output,
        attn_metadata=attn_metadata,
        input_ids=input_ids,
        hidden_states=hidden_states,
        logits=logits,
        aux_hidden_states=aux_hidden_states,
        spec_decode_metadata=spec_decode_metadata,
        kv_connector_output=kv_connector_output,
        logits_indices_selector=logits_indices_selector,
        padded_num_reqs=padded_num_reqs)
    return None
def _sample_from_logits(
self,
scheduler_output: "VllmSchedulerOutput",
attn_metadata: AttentionMetadata,
input_ids: Optional[jax.Array],
hidden_states: jax.Array,
logits: jax.Array,
aux_hidden_states: Optional[jax.Array],
spec_decode_metadata: Optional[SpecDecodeMetadata],
kv_connector_output: Optional[KVConnectorOutput],
logits_indices_selector: Optional[List[int]] = None,
padded_num_reqs: Optional[int] = None,
) -> ModelRunnerOutput | AsyncTPUModelRunnerOutput:
if padded_num_reqs is None:
padded_num_reqs = runner_utils.get_padded_num_reqs_with_upper_limit(
self.input_batch.num_reqs, self.max_num_reqs)
sharding = None
if self.dp_size > 1:
sharding = NamedSharding(self.mesh,
PartitionSpec(ShardingAxisName.MLP_DATA))
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.mesh, self.input_batch, padded_num_reqs, sharding=sharding)
# TODO(pooyam): Should we move this to `_prepare_inputs`?
if tpu_sampling_metadata.do_sampling:
self.rng_params_for_sampling, step_rng = jax.random.split(
self.rng_params_for_sampling)
else:
step_rng = self.rng_params_for_sampling
if spec_decode_metadata is None:
next_tokens = sample(
step_rng,
self.mesh,
logits,
tpu_sampling_metadata,
)
else:
if tpu_sampling_metadata.do_sampling:
bonus_rng, rejection_rng = jax.random.split(step_rng)
else:
bonus_rng = step_rng
rejection_rng = step_rng
bonus_logits = self._select_from_array_fn(
logits, spec_decode_metadata.bonus_logits_indices)
bonus_token_ids = sample(
bonus_rng,
self.mesh,
bonus_logits,
tpu_sampling_metadata,
)
target_logits = self._select_from_array_fn(
logits, spec_decode_metadata.target_logits_indices)
next_tokens = self.rejection_sampler(
draft_token_ids=spec_decode_metadata.draft_token_ids,
num_draft_tokens=spec_decode_metadata.draft_lengths,
draft_probs=None,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=tpu_sampling_metadata,
key=rejection_rng,
)
if tpu_sampling_metadata.logprobs:
logprobs = self._compute_and_gather_logprobs(
logits, next_tokens, self.model_config.max_logprobs)
else:
logprobs = None
num_reqs = self.input_batch.num_reqs
# Update the cache state concurrently. Code above will not block until
# We use `selected_token_ids`. Add mark_step if post-processing changes
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
discard_sampled_tokens_req_indices = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
prompt_logprobs_dict = {}
for req_id in self.input_batch.req_ids[:num_reqs]:
prompt_logprobs_dict[req_id] = None
# If async scheduler enabled
if self.scheduler_config.async_scheduling:
# Get previous results from TPU and replace the placeholder.
if self._pre_async_results is not None:
assert not self.speculative_config and spec_decode_metadata is None, "Async scheduler does not support speculative decoding yet."
self._modify_prev_results()
# Set placeholder for next tokens that is not yet generated
placeholder_req_id_to_index: dict[
str, int] = self._update_placeholder(
discard_sampled_tokens_req_indices, request_seq_lens,
logits_indices_selector)
if logprobs is not None:
# Map logprobs back to the pre-dp shuffling order
logprobs_lists = _jax_logprobs_to_lists(
logprobs, logits_indices_selector)
else:
logprobs_lists = None
# Save the previous results
next_tokens = jax.copy_to_host_async(next_tokens)
self._pre_async_results = AsyncPreResults(
req_ids=req_ids,
next_tokens=next_tokens,
request_seq_lens=request_seq_lens,
discard_sampled_tokens_req_indices=
discard_sampled_tokens_req_indices,
placeholder_req_id_to_index=placeholder_req_id_to_index,
logits_indices_selector=logits_indices_selector)
# Return Model output to executor
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,