|
@@ -37,11 +37,12 @@
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.sampling_params import SamplingType
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LazyLoader, check_use_alibi,
                         is_pin_memory_available, make_tensor_with_pad)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
                                         SlidingWindowSpec)
|
@@ -60,7 +61,6 @@
 if TYPE_CHECKING:
     import xgrammar as xgr
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-    from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
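`SchedulerOutput` moves out of the `TYPE_CHECKING` block because the new `warmup_model` below constructs real `SchedulerOutput` objects at runtime; a type-only import would leave the name unbound outside of type checking. For reference, a minimal sketch of the type-only-import-plus-lazy-loader pattern this file uses (a generic stand-in, not vLLM's actual `LazyLoader` from `vllm.utils`):

```python
import importlib
import types
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed at runtime.
    import xgrammar as xgr
else:
    class _LazyModule(types.ModuleType):
        """Proxy that imports the real module on first attribute access."""

        def __getattr__(self, attr):
            module = importlib.import_module(self.__name__)
            return getattr(module, attr)

    xgr = _LazyModule("xgrammar")
```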
|
@@ -253,6 +253,8 @@ def __init__( |
         self.max_num_batched_tokens = (
             self.scheduler_config.max_num_batched_tokens)
 
+        self._accumulative_compilation_count = 0
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -759,6 +761,111 @@ def compute_logits( |
 
         return self.model.compute_logits(hidden_states, sampling_metadata)
 
+    @torch.inference_mode()
+    def warmup_model(self) -> None:
+        # compile prefill graph
+        prefill_seq_len = (self.scheduler_config.max_num_batched_tokens
+                           if self.scheduler_config.chunked_prefill_enabled
+                           else self.scheduler_config.max_model_len)
+        dummy_prefill_schedule = SchedulerOutput(
+            scheduled_new_reqs=[
+                NewRequestData(
+                    req_id="dummy_prefill",
+                    prompt_token_ids=list(range(prefill_seq_len)),
+                    mm_hashes=[],
+                    mm_inputs=[],
+                    mm_positions=[],
+                    sampling_params=SamplingParams(temperature=0.0),
+                    block_ids=([0], ),
+                    num_computed_tokens=0,
+                    lora_request=None,
+                )
+            ],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={"dummy_prefill": prefill_seq_len},
+            total_num_scheduled_tokens=prefill_seq_len,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        dummy_prefill_cleanup = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[1],
+            finished_req_ids={
+                "dummy_prefill",
+            },
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        self.execute_model(dummy_prefill_schedule)
+        self.execute_model(dummy_prefill_cleanup)
+
+        num_prefill_graphs = self._accumulative_compilation_count
+        logger.info("Compiled %d graph(s) for prefill", num_prefill_graphs)
+
+        # compile decode graph
+        decode_max_batch_size = self.scheduler_config.max_num_seqs
+        decode_max_seq_len = self.scheduler_config.max_model_len
+        decode_max_num_blocks = (decode_max_seq_len + self.block_size -
+                                 1) // self.block_size
+        dummy_decode_schedule = SchedulerOutput(
+            scheduled_new_reqs=[
+                NewRequestData(
+                    req_id=f"dummy_decode_{i}",
+                    prompt_token_ids=list(range(decode_max_seq_len - 1)),
+                    mm_hashes=[],
+                    mm_inputs=[],
+                    mm_positions=[],
+                    sampling_params=SamplingParams(temperature=0.0),
+                    block_ids=([0] * decode_max_num_blocks, ),
+                    num_computed_tokens=decode_max_seq_len - 1,
+                    lora_request=None,
+                ) for i in range(decode_max_batch_size)
+            ],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={
+                f"dummy_decode_{i}": 1
+                for i in range(decode_max_batch_size)
+            },
+            total_num_scheduled_tokens=decode_max_batch_size,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        dummy_decode_cleanup = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[1],
+            finished_req_ids=set(f"dummy_decode_{i}"
+                                 for i in range(decode_max_batch_size)),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        self.execute_model(dummy_decode_schedule)
+        self.execute_model(dummy_decode_cleanup)
+
+        logger.info("Compiled %d graph(s) for decode",
+                    self._accumulative_compilation_count - num_prefill_graphs)
+
     @torch.inference_mode()
     def execute_model(
         self,
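`warmup_model` drives the two shape extremes through the regular `execute_model` path: one maximum-length prefill request, then a full batch of single-token decodes, each followed by a cleanup schedule that marks the dummy requests finished so their cached state and blocks are released. A hedged sketch of where such a warmup would typically be invoked (the wiring below is illustrative, not vLLM's actual worker API):

```python
# Hypothetical startup sequence (illustrative names): warm up after the
# KV cache exists and before serving, so no user request pays compile cost.
def start_model_runner(runner, kv_cache_config):
    runner.initialize_kv_cache(kv_cache_config)  # allocates KV blocks
    runner.warmup_model()  # compiles prefill + decode graphs up front
```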
@@ -1155,6 +1262,22 @@ def execute_model( |
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
 
+        compilation_metrics = torch._dynamo.utils.get_compilation_metrics()
+        if len(compilation_metrics) > self._accumulative_compilation_count:
+            new_compilation_metrics = compilation_metrics[
+                self._accumulative_compilation_count:]
+            reasons = ", ".join([
+                cm.recompile_reason or "initial compilation"
+                for cm in new_compilation_metrics
+            ])
+            logger.debug(
+                "graph compilation(s) triggered due to the following "
+                "reason(s): %s", reasons)
+            self._accumulative_compilation_count += len(
+                new_compilation_metrics)
+            logger.debug("accumulative compilation count: %s",
+                         self._accumulative_compilation_count)
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
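The block above diffs `torch._dynamo`'s global compilation-metrics list against a running counter, so each `execute_model` step logs only the compilations it newly triggered. A standalone sketch of the same counting idea (PyTorch 2.x; the exact `CompilationMetrics` fields, including `recompile_reason`, vary across versions):

```python
import torch
from torch._dynamo.utils import get_compilation_metrics

@torch.compile
def f(x):
    return torch.relu(x) * 2

f(torch.randn(4))   # first call: initial compilation
f(torch.randn(16))  # new input size: may trigger a dynamic-shape recompile

seen = 0
metrics = get_compilation_metrics()
for cm in metrics[seen:]:
    # recompile_reason is None for a frame's initial compilation
    print(getattr(cm, "recompile_reason", None) or "initial compilation")
seen = len(metrics)
print("cumulative compilation count:", seen)
```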
@@ -1633,6 +1756,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: |
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
 
+        self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.cache_config.num_cpu_blocks = 0
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
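Recording `kv_cache_config.num_blocks` onto `cache_config.num_gpu_blocks` exposes the profiled block count through the legacy config fields (there is no CPU swap space in this path, hence `num_cpu_blocks = 0`). As background, the per-block arithmetic behind such a block count usually looks like the sketch below (illustrative values and a simplified formula, not the code vLLM runs):

```python
# Simplified KV-cache sizing (assumed formula: 2 tensors per token slot,
# key and value, per layer). Real profiling also subtracts activation memory.
block_size = 16   # token slots per KV block
num_layers = 32
num_kv_heads = 8
head_size = 128
dtype_bytes = 2   # fp16 / bf16

bytes_per_block = (2 * block_size * num_layers * num_kv_heads * head_size *
                   dtype_bytes)
free_kv_bytes = 40 * 1024**3          # memory left for KV cache, e.g. 40 GiB
num_gpu_blocks = free_kv_bytes // bytes_per_block
print(num_gpu_blocks)                 # -> 20480 blocks of 2 MiB each
```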
|