diff --git a/.github/workflows/pr_test_full.yaml b/.github/workflows/pr_test_full.yaml index 308a87128ba..7895be30530 100644 --- a/.github/workflows/pr_test_full.yaml +++ b/.github/workflows/pr_test_full.yaml @@ -74,7 +74,7 @@ jobs: name: e2e-full strategy: matrix: - vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] + vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0] needs: [changes] if: ${{ needs.changes.outputs.e2e_tracker == 'true' }} uses: ./.github/workflows/_e2e_test.yaml diff --git a/.github/workflows/pr_test_light.yaml b/.github/workflows/pr_test_light.yaml index 0d5287a9fa8..b627b09d23e 100644 --- a/.github/workflows/pr_test_light.yaml +++ b/.github/workflows/pr_test_light.yaml @@ -42,7 +42,7 @@ jobs: lint: uses: ./.github/workflows/_pre_commit.yml with: - vllm: ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 + vllm: 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e changes: runs-on: linux-aarch64-a2-0 outputs: @@ -90,7 +90,7 @@ jobs: SOC_VERSION: ascend910b1 strategy: matrix: - vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] + vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0] steps: - name: Free up disk space @@ -154,7 +154,7 @@ jobs: name: e2e-light strategy: matrix: - vllm_version: [ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0] + vllm_version: [97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0] # Note (yikun): If CI resource are limited we can split job into two chain jobs needs: [lint, changes] # only trigger e2e test after lint passed and the change is e2e related with pull request. diff --git a/docs/source/community/versioning_policy.md b/docs/source/community/versioning_policy.md index 27c445beafb..80dbab6a734 100644 --- a/docs/source/community/versioning_policy.md +++ b/docs/source/community/versioning_policy.md @@ -45,7 +45,7 @@ The table below is the release compatibility matrix for vLLM Ascend release. For main branch of vLLM Ascend, we usually make it compatible with the latest vLLM release and a newer commit hash of vLLM. Please note that this table is usually updated. Please check it regularly. 
| vLLM Ascend | vLLM | Python | Stable CANN | PyTorch/torch_npu | |-------------|--------------|------------------|-------------|--------------------| -| main | ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 | +| main | 97f2f160fda2805f9149b0e44da76b5d3b1f7c7e, v0.12.0 tag | >= 3.10, < 3.12 | 8.3.RC2 | 2.8.0 / 2.8.0 | ## Release cadence diff --git a/tests/ut/compilation/test_acl_graph.py b/tests/ut/compilation/test_acl_graph.py index a8142fc66e4..5df8c29130d 100644 --- a/tests/ut/compilation/test_acl_graph.py +++ b/tests/ut/compilation/test_acl_graph.py @@ -803,7 +803,9 @@ def test_update_mla_dcp_pcp_params(self, _mock_graph_task_end): (q_nope, q_pe, k_nope, k_pe, block_table, seq_lens, num_heads, scale, num_kv_heads, out, lse)) - update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, 4) + with patch("torch_npu._C._npu_setStream", return_value=None): + update_mla_attn_dcp_pcp_params(self.update_stream, forward_context, + 4) _mock_graph_task_end.assert_called_once() @@ -842,6 +844,7 @@ def test_update_attn_dcp_pcp_params(self, _mock_graph_task_end): block_table, 128, actual_seq_lengths_kv, actual_seq_lengths_q, out, lse, 2, 0, 0)) - update_attn_dcp_pcp_params(self.update_stream, forward_context, 4) + with patch("torch_npu._C._npu_setStream", return_value=None): + update_attn_dcp_pcp_params(self.update_stream, forward_context, 4) _mock_graph_task_end.assert_called_once() diff --git a/tests/ut/spec_decode/test_eagle_proposer.py b/tests/ut/spec_decode/test_eagle_proposer.py index 094ca78aee2..745be6e564d 100644 --- a/tests/ut/spec_decode/test_eagle_proposer.py +++ b/tests/ut/spec_decode/test_eagle_proposer.py @@ -95,6 +95,8 @@ def test_load_model_pp1(self, mock_pp_group, mock_get_model, mock_model = MagicMock() mock_model.model.embed_tokens = MagicMock() mock_model.lm_head = MagicMock() + mock_model.multimodal_cpu_fields = None + mock_model.merge_by_field_config = None mock_get_model.return_value = MagicMock() self.proposer.name = SpecDcodeType.EAGLE @@ -117,6 +119,8 @@ def test_load_model_pp_gt1(self, mock_pp_group, mock_get_model, mock_model = MagicMock() original_embed = MagicMock() + mock_model.multimodal_cpu_fields = None + mock_model.merge_by_field_config = None mock_get_model.return_value = MagicMock(model=MagicMock( embed_tokens=original_embed)) diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py deleted file mode 100644 index 5b2cd82bfe7..00000000000 --- a/tests/ut/worker/test_input_batch.py +++ /dev/null @@ -1,375 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. 
-# -import inspect -from collections.abc import Sequence -from typing import Optional - -import numpy as np -import pytest -import torch -from vllm.sampling_params import SamplingParams -from vllm.utils.torch_utils import make_tensor_with_pad -from vllm.v1.pool.metadata import PoolingMetadata -from vllm.v1.sample.logits_processor import LogitsProcessors -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.utils import CpuGpuBuffer - -from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable -from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch - -VOCAB_SIZE = 1024 -NUM_OUTPUT_TOKENS = 20 -MAX_PROMPT_SIZE = 100 -MAX_NUM_PROMPT_TOKENS = 64 - - -def _compare_objs(obj1, - obj2, - skip: Sequence = ("logitsprocs", "batch_update_builder")): - attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) - attr_names = set([ - a[0] for a in attrs - if not (a[0].startswith('__') and a[0].endswith('__')) - ]) - for attr_name in attr_names: - if attr_name in skip: - continue - - a = getattr(obj1, attr_name) - b = getattr(obj2, attr_name) - - is_same = False - if isinstance(a, torch.Tensor): - if (a.numel() == 0 or b.numel() == 0): - is_same = (a.numel() == 0 and b.numel() == 0) - elif torch.allclose(a, b): - is_same = True - elif isinstance(a, np.ndarray): - if np.allclose(a, b): - is_same = True - elif isinstance(a, MultiGroupBlockTable): - for a_i, b_i in zip(a.block_tables, b.block_tables): - _compare_objs(a_i, b_i) - is_same = True - elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)): - _compare_objs(a, b) - is_same = True # if we make it here must be same - elif a == b: - is_same = True - elif isinstance(a, CpuGpuBuffer): - is_same = np.allclose(a.np, b.np) and torch.allclose(a.gpu, b.gpu) - assert is_same, f"Attribute {attr_name} is different"\ - f" in {obj1} and {obj2}: {a} != {b}" - - -def _remove_requests(input_batch: InputBatch, batch_size: int, - reqs: list[CachedRequestState]) -> set[str]: - """ - Remove some requests randomly from the batch and returns - set of request removed - """ - - num_reqs_to_remove = np.random.randint(0, batch_size) - req_indices_to_remove: set[int] = set() - for _ in range(num_reqs_to_remove): - req_index_to_remove = np.random.randint(0, batch_size) - req_indices_to_remove.add(req_index_to_remove) - - req_ids_to_remove: set[str] = set() - for index in req_indices_to_remove: - input_batch.remove_request(reqs[index].req_id) - req_ids_to_remove.add(reqs[index].req_id) - return req_ids_to_remove - - -def _construct_expected_sampling_metadata( - reqs: list[CachedRequestState], - req_ids_retained: set[int], - req_id_index_in_input_batch: dict[str, int], - device: torch.device, -) -> SamplingMetadata: - """ - Constructs and returns the expected SamplingMetadata for this - batch. 
- """ - num_reqs = len(req_ids_retained) - output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)] - prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)] - presence_penalties = [0.0 for _ in range(num_reqs)] - frequency_penalties = [0.0 for _ in range(num_reqs)] - repetition_penalties = [1.0 for _ in range(num_reqs)] - top_k = [0 for _ in range(num_reqs)] - top_p = [0.0 for _ in range(num_reqs)] - temperature = [0.0 for _ in range(num_reqs)] - min_tokens = {} - logit_bias = [None] * num_reqs - allowed_token_ids_mask = torch.zeros(num_reqs, - VOCAB_SIZE, - dtype=torch.bool, - device=device) - bad_words_token_ids = {} - for req in reqs: - if req.req_id not in req_ids_retained: - continue - index_in_input_batch = req_id_index_in_input_batch[req.req_id] - output_token_ids[index_in_input_batch] = req.output_token_ids - prompt_token_ids[index_in_input_batch] = req.prompt_token_ids - presence_penalties[ - index_in_input_batch] = req.sampling_params.presence_penalty - frequency_penalties[index_in_input_batch] = ( - req.sampling_params.frequency_penalty) - repetition_penalties[index_in_input_batch] = ( - req.sampling_params.repetition_penalty) - top_k[index_in_input_batch] = req.sampling_params.top_k - top_p[index_in_input_batch] = req.sampling_params.top_p - temperature[index_in_input_batch] = req.sampling_params.temperature - min_tokens[index_in_input_batch] = ( - req.sampling_params.min_tokens, - req.sampling_params.all_stop_token_ids) - logit_bias[index_in_input_batch] = req.sampling_params.logit_bias - if req.sampling_params.allowed_token_ids: - allowed_token_ids_mask[index_in_input_batch][ - req.sampling_params.allowed_token_ids] = True - if req.sampling_params.bad_words_token_ids: - bad_words_token_ids[ - index_in_input_batch] = req.sampling_params.bad_words_token_ids - - return SamplingMetadata( - temperature=torch.tensor(temperature, dtype=torch.float, - device=device), - all_greedy=False, - all_random=True, - top_p=None if all(x == 1.0 for x in top_p) else torch.tensor( - top_p, dtype=torch.float, device=device), - top_k=None if all(x == 0 for x in top_k) else torch.tensor( - top_k, dtype=torch.int, device=device), - generators={}, - max_num_logprobs=0, - prompt_token_ids=make_tensor_with_pad( - prompt_token_ids, - pad=VOCAB_SIZE, - device=torch.device(device), - dtype=torch.int64, - ), - frequency_penalties=torch.tensor(frequency_penalties, - dtype=torch.float, - device=device), - presence_penalties=torch.tensor(presence_penalties, - dtype=torch.float, - device=device), - repetition_penalties=torch.tensor(repetition_penalties, - dtype=torch.float, - device=device), - output_token_ids=output_token_ids, - no_penalties=(all(x == 0 for x in presence_penalties) - and all(x == 0 for x in frequency_penalties) - and all(x == 1 for x in repetition_penalties)), - allowed_token_ids_mask=allowed_token_ids_mask, - bad_words_token_ids=bad_words_token_ids, - logitsprocs=LogitsProcessors(), - ) - - -def _create_sampling_params(): - return SamplingParams( - top_k=np.random.randint(1, 10), - top_p=np.random.uniform(0.0, 1.0), - presence_penalty=np.random.uniform(-2.0, 2.0), - repetition_penalty=np.random.uniform(0.0, 2.0), - frequency_penalty=np.random.uniform(-2.0, 2.0), - min_tokens=np.random.randint(1, 10), - stop_token_ids=[ - np.random.randint(0, VOCAB_SIZE) - for _ in range(np.random.randint(10)) - ], - logit_bias={0: np.random.uniform(-3.0, 3.0)}, - ) - - -def _construct_cached_request_state(req_id_suffix: int): - prompt_token_ids = [ - np.random.randint(0, 
VOCAB_SIZE) - for _ in range(np.random.randint(0, MAX_PROMPT_SIZE)) - ] - output_token_ids = [ - np.random.randint(0, VOCAB_SIZE) - for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS)) - ] - return CachedRequestState( - req_id=f"req_id_{req_id_suffix}", - prompt_token_ids=prompt_token_ids, - sampling_params=_create_sampling_params(), - pooling_params=None, - mm_kwargs=[], - mm_positions=[], - block_ids=([], ), - generator=None, - num_computed_tokens=len(output_token_ids), - output_token_ids=output_token_ids, - mm_hashes=None, - ) - - -@pytest.mark.parametrize("device", ["cpu"]) -@pytest.mark.parametrize("batch_size", [1, 2, 32, 64]) -def test_sampling_metadata_in_input_batch(device: str, batch_size: int): - """ - Tests the logic for managing sampling metadata in the InputBatch. - - This test involves adding a set of requests to the InputBatch, - followed by removing a subset of them. Afterward, the batch is compacted, - and the `make_sampling_metadata` method is invoked on the batch. The - output of `make_sampling_metadata` is then compared against the expected - results to ensure correctness. - - Note: Ignore logits processor logic, which is tested separately - """ - input_batch: InputBatch = InputBatch( - max_num_reqs=batch_size, - max_model_len=1024, - max_num_batched_tokens=1024, - device=torch.device(device), - pin_memory=False, - vocab_size=1024, - block_sizes=[1], - ) - reqs: list[CachedRequestState] = [] - req_id_reqs = {} - req_id_output_token_ids = {} - - # Add requests - for req_index in range(batch_size): - req: CachedRequestState = _construct_cached_request_state(req_index) - assigned_req_index = input_batch.add_request(req) - assert req_index == assigned_req_index - reqs.append(req) - req_id_reqs[req.req_id] = req - req_id_output_token_ids[req.req_id] = req.output_token_ids - - # Remove some requests - req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs) - req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove - - # Compact the input batch - input_batch.condense() - - # Generate the sampling metadata - sampling_metadata = input_batch._make_sampling_metadata() - - # Create expected output. - expected_sampling_metadata = _construct_expected_sampling_metadata( - reqs, - req_ids_retained, - input_batch.req_id_to_index, - device=torch.device(device)) - - def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool: - return (t1 is None - and t2 is None) or (t1 is not None and t2 is not None - and torch.allclose(t1, t2)) - - # Assert the actual and expected output. 
- assert torch.allclose(expected_sampling_metadata.temperature, - sampling_metadata.temperature) - assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p) - assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k) - assert torch.allclose( - expected_sampling_metadata.frequency_penalties, - sampling_metadata.frequency_penalties, - ) - assert torch.allclose( - expected_sampling_metadata.presence_penalties, - sampling_metadata.presence_penalties, - ) - assert torch.allclose( - expected_sampling_metadata.repetition_penalties, - sampling_metadata.repetition_penalties, - ) - assert torch.allclose(expected_sampling_metadata.prompt_token_ids, - sampling_metadata.prompt_token_ids) - assert (expected_sampling_metadata.output_token_ids == - sampling_metadata.output_token_ids) - assert expected_sampling_metadata.no_penalties == \ - sampling_metadata.no_penalties - if sampling_metadata.allowed_token_ids_mask: - assert torch.allclose( - expected_sampling_metadata.allowed_token_ids_mask, - sampling_metadata.allowed_token_ids_mask) - assert expected_sampling_metadata.bad_words_token_ids == \ - sampling_metadata.bad_words_token_ids - - -@pytest.mark.parametrize("device", ["cpu"]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("swap_list", [((0, 1), )]) -def test_swap_states_in_input_batch(device: str, batch_size: int, - swap_list: list): - """ - Tests the logic for managing sampling metadata in the InputBatch. - - This test involves adding a set of requests to the InputBatch, - followed by removing a subset of them. Afterward, the batch is compacted, - and the `make_sampling_metadata` method is invoked on the batch. The - output of `make_sampling_metadata` is then compared against the expected - results to ensure correctness. 
- - Note: Ignore logits processor logic, which is tested separately - """ - input_batch: InputBatch = InputBatch( - max_num_reqs=batch_size, - max_model_len=1024, - max_num_batched_tokens=1024, - device=torch.device(device), - pin_memory=False, - vocab_size=1024, - block_sizes=[1], - ) - ref_input_batch: InputBatch = InputBatch( - max_num_reqs=batch_size, - max_model_len=1024, - max_num_batched_tokens=1024, - device=torch.device(device), - pin_memory=False, - vocab_size=1024, - block_sizes=[1], - ) - - reqs: list[CachedRequestState] = [] - req_id_reqs = {} - req_id_output_token_ids = {} - # Add requests - for req_index in range(batch_size): - req: CachedRequestState = _construct_cached_request_state(req_index) - assigned_req_index = input_batch.add_request(req) - assert assigned_req_index == req_index - reqs.append(req) - req_id_reqs[req.req_id] = req - req_id_output_token_ids[req.req_id] = req.output_token_ids - - reordered_reqs = reqs.copy() - for swap_pair in swap_list: - reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \ - reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]] - input_batch.swap_states(swap_pair[0], swap_pair[1]) - - for req_index in range(batch_size): - req = reordered_reqs[req_index] - assigned_req_index = ref_input_batch.add_request(req) - assert assigned_req_index == req_index - - input_batch.refresh_metadata() - ref_input_batch.refresh_metadata() - - _compare_objs(input_batch, ref_input_batch) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e206ae84045..29643c0505b 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -41,7 +41,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, flashcomm2_o_shared_enabled, is_enable_nz, weak_ref_tensors) -from vllm_ascend.worker.npu_input_batch import InputBatch +from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -280,7 +280,7 @@ def __init__(self, dtype=torch.uint8, device=device) - def reorder_batch(self, input_batch: "InputBatch", + def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at # the front and the "prefill" requests are at the using the least amount diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index bc1de75fc4f..f6b338a807c 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -32,7 +32,7 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, _round_up, dispose_layer, enable_sp, is_enable_nz, replace_layer) -from vllm_ascend.worker.npu_input_batch import InputBatch +from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -149,7 +149,7 @@ def __init__(self, self.enable_sfa_cp = enable_sp() and \ hasattr(self.model_config.hf_config, "index_topk") - def reorder_batch(self, input_batch: "InputBatch", + def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool: # No need to reorder for Ascend SFA return False diff --git a/vllm_ascend/eplb/utils.py b/vllm_ascend/eplb/utils.py index 8dfaf56293a..7099c25fb49 100644 --- a/vllm_ascend/eplb/utils.py +++ b/vllm_ascend/eplb/utils.py @@ -24,7 +24,7 @@ def get_expert_map(self, layer_id): - return self.model.layers[layer_id].mlp.experts.get_map() + return 
self.model.layers[layer_id].mlp.experts.expert_map def get_log2phy_map(self, layer_id): diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 4ce526e22ff..f4396992495 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -153,7 +153,7 @@ def __init__(self, *args, **kwargs): AscendFusedMoE.moe_counter += 1 self.moe_instance_id = AscendFusedMoE.moe_counter - self.expert_map = None + self._expert_map = None self.log2phy = None if self.quant_config is None: @@ -184,7 +184,7 @@ def __init__(self, *args, **kwargs): dtype=vllm_config.model_config.dtype) # init moe. - self.local_num_experts, self.expert_map, _ = determine_expert_map( + self.local_num_experts, self._expert_map, _ = determine_expert_map( self.ep_size, self.ep_rank, self.global_num_experts) # TODO: Temporary flag to indicate if static EPLB is enabled. This is a # workaround to bypass a quantization check that fails with float weights. @@ -200,7 +200,7 @@ def __init__(self, *args, **kwargs): self.expert_load_balancer.get_global_redundant_expert_num()) self.global_num_experts = num_experts + self.global_redundant_expert_num try: - self.local_num_experts, self.expert_map = ( + self.local_num_experts, self._expert_map = ( self.expert_load_balancer.get_rank_placement_map( self.moe_instance_id, self.ep_rank)) self.log2phy = self.expert_load_balancer.get_rank_log2phy_map( @@ -216,16 +216,16 @@ def __init__(self, *args, **kwargs): if self.dynamic_eplb: self.log2phy = determine_default_log2phy_map( self.global_num_experts, self.ep_size, self.ep_rank).npu() - if self.expert_map is not None and isinstance(self.expert_map, - torch.Tensor): + if self._expert_map is not None and isinstance(self._expert_map, + torch.Tensor): logger.info_once( "[EP Rank %s/%s] Expert parallelism is enabled. Local/global" " number of experts: %s/%s. 
Experts local to global index map:" " %s.", self.ep_rank, self.ep_size, self.local_num_experts, self.global_num_experts, - get_compressed_expert_map(self.expert_map)) + get_compressed_expert_map(self._expert_map)) local_num_experts = (torch.sum( - self.expert_map != -1) if self.expert_map is not None else + self._expert_map != -1) if self._expert_map is not None else self.global_num_experts) if self.dynamic_eplb: self.moe_load = torch.zeros(local_num_experts, @@ -276,10 +276,16 @@ def _get_quant_type(self) -> QuantType: return QuantType.NONE def update_expert_map(self, new_expert_map): - self.expert_map = new_expert_map + self._expert_map = new_expert_map - def get_map(self): - return self.expert_map + @property + def expert_map(self) -> torch.Tensor | None: + return self._expert_map + + @expert_map.setter + def expert_map(self, new_expert_map): + # TODO(Potabk): Remove this once we drop vllm v0.12.0(This makes backward compatibility with vllm v0.12.0) + self._expert_map = new_expert_map def get_log2phy_map(self): return self.log2phy diff --git a/vllm_ascend/patch/platform/__init__.py b/vllm_ascend/patch/platform/__init__.py index 26c4dc86771..0dff139ff84 100644 --- a/vllm_ascend/patch/platform/__init__.py +++ b/vllm_ascend/patch/platform/__init__.py @@ -17,10 +17,15 @@ import os import vllm_ascend.patch.platform.patch_distributed # noqa -import vllm_ascend.patch.platform.patch_ec_connector # noqa import vllm_ascend.patch.platform.patch_mamba_config # noqa import vllm_ascend.patch.platform.patch_sched_yield # noqa +from vllm_ascend.utils import vllm_version_is if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv( "EXPERT_MAP_RECORD", "false") == "true": import vllm_ascend.patch.platform.patch_multiproc_executor # noqa + +if vllm_version_is("0.12.0"): + import vllm_ascend.patch.platform.patch_ec_connector012 # noqa +else: + import vllm_ascend.patch.platform.patch_ec_connector # noqa diff --git a/vllm_ascend/patch/platform/patch_ec_connector.py b/vllm_ascend/patch/platform/patch_ec_connector.py index f0464b75e91..61ca8535052 100644 --- a/vllm_ascend/patch/platform/patch_ec_connector.py +++ b/vllm_ascend/patch/platform/patch_ec_connector.py @@ -1,16 +1,15 @@ -import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector +import vllm.distributed.ec_transfer.ec_connector.example_connector from safetensors.torch import load_file -from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata -from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( - ECSharedStorageConnector, ECSharedStorageConnectorMetadata) +from vllm.distributed.ec_transfer.ec_connector.example_connector import ( + ECConnectorMetadata, ECExampleConnector) from vllm.logger import logger -class AscendECSharedStorageConnector(ECSharedStorageConnector): +class AscendECExampleConnector(ECExampleConnector): def start_load_caches(self, encoder_cache, **kwargs) -> None: metadata: ECConnectorMetadata = self._get_connector_metadata() - assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert isinstance(metadata, ECConnectorMetadata) assert encoder_cache is not None if metadata is None: logger.warning(( @@ -29,4 +28,4 @@ def start_load_caches(self, encoder_cache, **kwargs) -> None: mm_data.mm_hash) -vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector +vllm.distributed.ec_transfer.ec_connector.example_connector.ECExampleConnector = AscendECExampleConnector diff --git 
a/vllm_ascend/patch/platform/patch_ec_connector012.py b/vllm_ascend/patch/platform/patch_ec_connector012.py new file mode 100644 index 00000000000..f0015738fb2 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_ec_connector012.py @@ -0,0 +1,33 @@ +import vllm.distributed.ec_transfer.ec_connector.shared_storage_connector # type: ignore[import-not-found] # noqa +from safetensors.torch import load_file +from vllm.distributed.ec_transfer.ec_connector.base import \ + ECConnectorMetadata # type: ignore[import-not-found] # noqa +from vllm.distributed.ec_transfer.ec_connector.shared_storage_connector import ( # type: ignore[import-not-found] # noqa + ECSharedStorageConnector, ECSharedStorageConnectorMetadata) +from vllm.logger import logger + + +class AscendECSharedStorageConnector(ECSharedStorageConnector): + + def start_load_caches(self, encoder_cache, **kwargs) -> None: + metadata: ECConnectorMetadata = self._get_connector_metadata() + assert isinstance(metadata, ECSharedStorageConnectorMetadata) + assert encoder_cache is not None + if metadata is None: + logger.warning(( + "In connector.start_load_caches, ", + "but the connector metadata is None", + )) + return + # Load the EC for each mm data + for mm_data in metadata.mm_datas: + if mm_data.mm_hash in encoder_cache: + continue + filename = self._generate_filename_debug(mm_data.mm_hash) + ec_cache = load_file(filename)["ec_cache"].npu() + encoder_cache[mm_data.mm_hash] = ec_cache + logger.debug("Success load encoder cache for hash %s", + mm_data.mm_hash) + + +vllm.distributed.ec_transfer.ec_connector.shared_storage_connector.ECSharedStorageConnector = AscendECSharedStorageConnector diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 45c874776e0..0eefec29f22 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -365,6 +365,10 @@ def get_attn_backend_cls( use_mla, has_sink=False, use_sparse=False, + # NOTE: Please pay special attention to the order of these parameters. + # Although we are only using some of them so far + # vllm passes them in sequence when using this interface. + use_mm_prefix: bool = False, attn_type: str | None = None, ): # choose attention backend based on use_mla diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 80299a05797..dc02e1bce84 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -476,9 +476,10 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None: # Calculate maximum supported batch sizes considering model architecture resources_per_graph = num_hidden_layers + 1 - if vllm_config.speculative_config is not None: - draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config - resources_per_graph += draft_model_hf_config.num_hidden_layers + 1 + # For suffix decoding, use the suffix path when no draft_model_config is provided. 
+ if (spec := vllm_config.speculative_config) and \ + (draft := spec.draft_model_config): + resources_per_graph += draft.hf_config.num_hidden_layers + 1 # TODO: Find out whether we need to take into account the pp_size num_comm_groups = sum(size > 1 for size in [ diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c56a4562518..4933d85f9d0 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -121,8 +121,8 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendDeviceType, ProfileExecuteDuration, enable_sp, get_ascend_device_type, is_enable_nz, - is_moe_model, lmhead_tp_enable) -from vllm_ascend.worker.npu_input_batch import InputBatch + is_moe_model, lmhead_tp_enable, vllm_version_is) +from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: import xgrammar as xgr # type: ignore[import-untyped] @@ -249,13 +249,24 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # Set up Attention self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, "index_topk") - self.attn_backend = get_attn_backend(0, - self.dtype, - None, - self.block_size, - use_mla=self.model_config.use_mla, - use_sparse=self.use_sparse) - + if vllm_version_is('0.12.0'): + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse) + else: + self.attn_backend = get_attn_backend( + 0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse, + use_mm_prefix=self.model_config is not None + and self.model_config.is_mm_prefix_lm) self.attn_mask_builder = AttentionMaskBuilder(self.device) self._set_up_drafter() @@ -353,7 +364,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): # solution, we initialize the input batch here, and re-initialize it # in `initialize_kv_cache` if the block_sizes here is different from # the block_sizes in the kv cache config. - self.input_batch = InputBatch( + self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, max_num_batched_tokens=self.max_num_tokens, @@ -1995,19 +2006,36 @@ def _build_dummy_attn_metadata( self.speculative_config.method == "mtp": attn_state = AscendAttentionState.SpecDecoding - common_metadata = CommonAttentionMetadata( - query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs + + if vllm_version_is("0.12.0"): + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + 1], - seq_lens_cpu=self.seq_lens.cpu[:num_reqs], - seq_lens=self.seq_lens.cpu[:num_reqs], - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - block_table_tensor=block_table_tensor[:num_reqs], - slot_mapping=slot_mapping.gpu, - num_computed_tokens_cpu=num_computed_tokens_cpu, - max_query_len=max_query_len, - max_seq_len=seq_lens) + query_start_loc_cpu=self.query_start_loc. 
+ cpu[:num_reqs + 1], + seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + seq_lens=self.seq_lens.cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=slot_mapping.gpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + max_seq_len=seq_lens) + else: + common_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc.gpu[:num_reqs + + 1], + query_start_loc_cpu=self.query_start_loc. + cpu[:num_reqs + 1], + _seq_lens_cpu=self.seq_lens.cpu[:num_reqs], + seq_lens=self.seq_lens.cpu[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + block_table_tensor=block_table_tensor[:num_reqs], + slot_mapping=slot_mapping.gpu, + _num_computed_tokens_cpu=num_computed_tokens_cpu, + max_query_len=max_query_len, + max_seq_len=seq_lens) for attn_group in self.attn_groups[kv_cache_group_id]: builder = attn_group.get_metadata_builder() @@ -2773,7 +2801,7 @@ def may_reinitialize_input_batch(self, "Cannot re-initialize the input batch when CPU weight " "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "for more details.") - self.input_batch = InputBatch( + self.input_batch = NPUInputBatch( max_num_reqs=self.max_num_reqs, max_model_len=self.model_config.max_model_len, max_num_batched_tokens=self.max_num_tokens, diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py index d9db156640f..846c0d83fca 100644 --- a/vllm_ascend/worker/npu_input_batch.py +++ b/vllm_ascend/worker/npu_input_batch.py @@ -17,92 +17,29 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py # -from dataclasses import dataclass -from typing import Optional, cast - import numpy as np import torch -from typing_extensions import deprecated from vllm.lora.request import LoRARequest -from vllm.multimodal.inputs import (MultiModalFeatureSpec, - MultiModalKwargsItem, - MultiModalKwargsItems, PlaceholderRange) from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams, SamplingType -from vllm.utils import length_from_prompt_token_ids_or_embeds -from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors -from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import (BatchUpdateBuilder, - LogitsProcessors, - MoveDirectionality) -from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import is_spec_decode_unsupported -from vllm.v1.utils import copy_slice + LogitsProcessors) +from vllm.v1.worker.gpu_input_batch import InputBatch from vllm_ascend.worker.block_table import MultiGroupBlockTable -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: Optional[list[int]] - sampling_params: Optional[SamplingParams] - pooling_params: Optional[PoolingParams] - generator: Optional[torch.Generator] - - block_ids: tuple[list[int], ...] 
- num_computed_tokens: int - output_token_ids: list[int] - - mrope_positions: Optional[torch.Tensor] = None - mrope_position_delta: Optional[int] = None - - mm_features: Optional[list[MultiModalFeatureSpec]] = None - # for back-compatibility, will be removed in next major release - mm_kwargs: Optional[list[MultiModalKwargsItem]] = None - mm_positions: Optional[list[PlaceholderRange]] = None - mm_hashes: Optional[list[PlaceholderRange]] = None - - lora_request: Optional[LoRARequest] = None - prompt_embeds: Optional[torch.Tensor] = None - - prev_num_draft_len: int = 0 # previous number of draft tokens +class PoolingStates: + # NOTE: This should be removed after we drop support of vLLM v0.12.0 + def __init__(self): + # for chunked prefill with ALL pooling + self.hidden_states_cache: list[torch.Tensor] = [] - def __post_init__(self): - self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - self.prompt_token_ids, self.prompt_embeds) + def clean(self): + self.hidden_states_cache.clear() - @property - def num_tokens(self) -> int: - return self.num_prompt_tokens + len(self.output_token_ids) - # Temporary back-compatibility for plugins that define model runner - @property - @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be " - "removed in v0.13. Please use `mm_kwargs` instead.") - def mm_inputs(self) -> list[MultiModalKwargsItems]: - assert self.mm_features is not None - return [ - MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features - if f.data is not None - ] - - def get_token_id(self, idx: int) -> int: - if idx < self.num_prompt_tokens: - if self.prompt_token_ids is None: - raise ValueError( - f"Tried to access token index {idx}, but that token was " - "provided via prompt_embeds, and its ID is unknown.") - return self.prompt_token_ids[idx] - elif idx - self.num_prompt_tokens < len(self.output_token_ids): - return self.output_token_ids[idx - self.num_prompt_tokens] - else: - return -1 - - -class InputBatch: +class NPUInputBatch(InputBatch): def __init__( self, @@ -113,12 +50,12 @@ def __init__( pin_memory: bool, vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group - logitsprocs: Optional[LogitsProcessors] = None, + kernel_block_sizes: list[list[int]], + logitsprocs: LogitsProcessors | None = None, logitsprocs_need_output_token_ids: bool = False, is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, - kernel_block_sizes: Optional[list[list[int]]] = None, cp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model @@ -130,12 +67,12 @@ def __init__( self.pin_memory = pin_memory self.vocab_size = vocab_size - self._req_ids: list[Optional[str]] = [] + self._req_ids: list[str | None] = [] self.req_id_to_index: dict[str, int] = {} # TODO(woosuk): This buffer could be too large if max_model_len is big. # Find a way to reduce the CPU memory usage. - # This buffer is not directly transferred to the NPU, so it does not + # This buffer is not directly transferred to the GPU, so it does not # need to be pinned. self.token_ids_cpu_tensor = torch.zeros( (max_num_reqs, max_model_len), @@ -162,8 +99,8 @@ def __init__( dtype=torch.int32, pin_memory=pin_memory, ) - self.num_computed_tokens_cpu = \ - self.num_computed_tokens_cpu_tensor.numpy() + self.num_computed_tokens_cpu = self.num_computed_tokens_cpu_tensor.numpy( + ) # Block table. 
self.block_table = MultiGroupBlockTable( @@ -222,8 +159,8 @@ def __init__( dtype=torch.float, device="cpu", pin_memory=pin_memory) - self.frequency_penalties_cpu = \ - self.frequency_penalties_cpu_tensor.numpy() + self.frequency_penalties_cpu = self.frequency_penalties_cpu_tensor.numpy( + ) self.frequency_penalties_reqs: set[str] = set() # Presence penalty related data structures @@ -247,8 +184,8 @@ def __init__( dtype=torch.float, device="cpu", pin_memory=pin_memory) - self.repetition_penalties_cpu = \ - self.repetition_penalties_cpu_tensor.numpy() + self.repetition_penalties_cpu = self.repetition_penalties_cpu_tensor.numpy( + ) self.repetition_penalties_reqs: set[str] = set() # Speculative decoding @@ -256,12 +193,12 @@ def __init__( dtype=torch.int64, device="cpu", pin_memory=pin_memory) - self.num_accepted_tokens_cpu = \ - self.num_accepted_tokens_cpu_tensor.numpy() + self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy( + ) # lora related self.request_lora_mapping = np.zeros((self.max_num_reqs, ), - dtype=np.int32) + dtype=np.int64) self.lora_id_to_request_ids: dict[int, set[str]] = {} self.lora_id_to_lora_request: dict[int, LoRARequest] = {} @@ -271,9 +208,6 @@ def __init__( self.generators: dict[int, torch.Generator] = {} self.num_logprobs: dict[str, int] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs - # that are currently in the prefill phase. - self.num_prompt_logprobs: dict[str, int] = {} # To accumulate prompt logprobs tensor chunks across prefill steps. self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {} @@ -287,8 +221,8 @@ def __init__( self.has_allowed_token_ids: set[str] = set() # NOTE(lufang): In the mask tensor, if the corresponding token allowed, # the value is False. Since we use masked_fill_ to set -inf. - self.allowed_token_ids_mask: Optional[torch.Tensor] = None - self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None + self.allowed_token_ids_mask: torch.Tensor | None = None + self.allowed_token_ids_mask_cpu_tensor: torch.Tensor | None = None # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} @@ -296,7 +230,7 @@ def __init__( self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, dtype=bool) - self.req_output_token_ids: list[Optional[list[int]]] = [] + self.req_output_token_ids: list[list[int] | None] = [] # Store provided logitsprocs. If none are provided, initialize empty # data structure @@ -310,673 +244,15 @@ def __init__( # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() + # for pooling models self.pooling_params: dict[str, PoolingParams] = {} + self.pooling_states: dict[str, PoolingStates] = {} # Cached reference to the GPU tensor of previously sampled tokens self.prev_sampled_token_ids: torch.Tensor | None = None - self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None self.prev_req_id_to_index: dict[str, int] | None = None # These are used to update output_token_ids with real sampled # ids from prior step, if required by current sampling params # (e.g. penalties). self.sampled_token_ids_cpu: torch.Tensor | None = None self.async_copy_ready_event: torch.Event | None = None - - @property - def req_ids(self) -> list[str]: - # None elements should only be present transiently - # while performing state updates to the batch. 
- return cast(list[str], self._req_ids) - - def _register_add_request(self, request: "CachedRequestState") -> int: - """Track add-request operations for logits processors. - Not applicable to pooling models. - """ - - # Detailed added request metadata is only required for non-pooling - # models, to support logitsprocs - assert request.sampling_params - - # Fill the next empty index if there is one. - if (new_req_index := self.batch_update_builder.pop_removed()) is None: - # Append to end otherwise. - new_req_index = self.num_reqs - - assert new_req_index < self.max_num_reqs - self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, request.prompt_token_ids, - request.output_token_ids)) - return new_req_index - - def add_request( - self, - request: "CachedRequestState", - ) -> int: - if not self.is_pooling_model: - # New request index bookkeeping for autoregressive models. - req_index = self._register_add_request(request) - else: - req_index = self.num_reqs - - req_id = request.req_id - if req_index == len(self._req_ids): - self._req_ids.append(req_id) - self.req_output_token_ids.append(request.output_token_ids) - self.spec_token_ids.append([]) - else: - self._req_ids[req_index] = req_id - self.req_output_token_ids[req_index] = request.output_token_ids - self.spec_token_ids[req_index].clear() - - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. - num_prompt_tokens = length_from_prompt_token_ids_or_embeds( - request.prompt_token_ids, request.prompt_embeds) - self.num_prompt_tokens[req_index] = num_prompt_tokens - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - if request.prompt_token_ids is not None: - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - self.is_token_ids[req_index, :num_prompt_tokens] = True - else: - self.is_token_ids[req_index, :num_prompt_tokens] = False - if request.prompt_embeds is not None: - self.req_prompt_embeds[req_index] = request.prompt_embeds - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - self.is_token_ids[req_index, start_idx:end_idx] = True - # Number of token ids in prompt (token_ids_cpu or prompt_embeds). - # NOTE(woosuk): This may include spec decode tokens. - self.num_tokens[req_index] = request.num_tokens - # Number of tokens without spec decode tokens. - self.num_tokens_no_spec[req_index] = request.num_tokens - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - self.block_table.add_row(request.block_ids, req_index) - - if sampling_params := request.sampling_params: - if (self.is_spec_decode - and is_spec_decode_unsupported(sampling_params)): - self.spec_decode_unsupported_reqs.add(req_id) - if sampling_params.sampling_type == SamplingType.GREEDY: - # Avoid later division by zero. 
- self.temperature_cpu[req_index] = -1.0 - self.greedy_reqs.add(req_id) - else: - self.temperature_cpu[req_index] = sampling_params.temperature - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - top_k = sampling_params.top_k - if 0 < top_k < self.vocab_size: - self.top_k_reqs.add(req_id) - else: - top_k = self.vocab_size - self.top_k_cpu[req_index] = top_k - self.frequency_penalties_cpu[ - req_index] = sampling_params.frequency_penalty - if sampling_params.frequency_penalty != 0.0: - self.frequency_penalties_reqs.add(req_id) - self.presence_penalties_cpu[ - req_index] = sampling_params.presence_penalty - if sampling_params.presence_penalty != 0.0: - self.presence_penalties_reqs.add(req_id) - self.repetition_penalties_cpu[ - req_index] = sampling_params.repetition_penalty - if sampling_params.repetition_penalty != 1.0: - self.repetition_penalties_reqs.add(req_id) - - # NOTE(woosuk): self.generators should not include the requests that - # do not have their own generator. - if request.generator is not None: - self.generators[req_index] = request.generator - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = (self.vocab_size - if sampling_params.logprobs == -1 - else sampling_params.logprobs) - if sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[ - req_id] = sampling_params.prompt_logprobs - - if sampling_params.allowed_token_ids: - self.has_allowed_token_ids.add(req_id) - if self.allowed_token_ids_mask_cpu_tensor is None: - # Lazy allocation for this tensor, which can be large. - # False means we don't fill with -inf. - self.allowed_token_ids_mask = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device=self.device) - self.allowed_token_ids_mask_cpu_tensor = torch.zeros( - self.max_num_reqs, - self.vocab_size, - dtype=torch.bool, - device="cpu") - self.allowed_token_ids_mask_cpu_tensor[req_index] = True - # False means we don't fill with -inf. - self.allowed_token_ids_mask_cpu_tensor[req_index][ - sampling_params.allowed_token_ids] = False - - if sampling_params.bad_words_token_ids: - self.bad_words_token_ids[ - req_index] = sampling_params.bad_words_token_ids - elif pooling_params := request.pooling_params: - self.pooling_params[req_id] = pooling_params - self.logits_processing_needs_token_ids[req_index] = ( - pooling_params.requires_token_ids) - else: - raise NotImplementedError(request) - - # Speculative decoding: by default 1 token is generated. - self.num_accepted_tokens_cpu[req_index] = 1 - - # Add request lora ID - if request.lora_request: - lora_id = request.lora_request.lora_int_id - if lora_id not in self.lora_id_to_request_ids: - self.lora_id_to_request_ids[lora_id] = set() - - self.request_lora_mapping[req_index] = lora_id - self.lora_id_to_request_ids[lora_id].add(request.req_id) - self.lora_id_to_lora_request[lora_id] = request.lora_request - else: - # No LoRA - self.request_lora_mapping[req_index] = 0 - - return req_index - - def remove_request(self, req_id: str) -> Optional[int]: - """This method must always be followed by a call to condense(). - - Args: - req_id: request to remove - - Returns: - Removed request index, or `None` if `req_id` not recognized - """ - - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - if not self.is_pooling_model: - # Autoregressive models require bookkeeping of removed requests to - # support logitsprocs. 
- self.batch_update_builder.removed_append(req_index) - self._req_ids[req_index] = None - self.req_output_token_ids[req_index] = None - self.spec_token_ids[req_index].clear() - - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - lora_req_ids = self.lora_id_to_request_ids[lora_id] - lora_req_ids.discard(req_id) - if not lora_req_ids: - del self.lora_id_to_request_ids[lora_id] - del self.lora_id_to_lora_request[lora_id] - self.request_lora_mapping[req_index] = 0 - - if self.is_pooling_model: - self.pooling_params.pop(req_id, None) - return req_index - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.spec_decode_unsupported_reqs.discard(req_id) - self.frequency_penalties_reqs.discard(req_id) - self.presence_penalties_reqs.discard(req_id) - self.repetition_penalties_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - self.in_progress_prompt_logprobs_cpu.pop(req_id, None) - - if self.prev_req_id_to_index is not None: - self.prev_req_id_to_index.pop(req_id, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - self.lora_id_to_request_ids[lora_id].discard(req_id) - if len(self.lora_id_to_request_ids[lora_id]) == 0: - self.lora_id_to_request_ids.pop(lora_id) - self.lora_id_to_lora_request.pop(lora_id) - self.request_lora_mapping[req_index] = 0 - - self.has_allowed_token_ids.discard(req_id) - if self.allowed_token_ids_mask_cpu_tensor is not None: - # False means we don't fill with -inf. - self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) - self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) - return req_index - - def swap_states(self, i1: int, i2: int) -> None: - # For autoregressive models, track detailed request reordering info - # to support logitsprocs - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) - old_id_i1 = self._req_ids[i1] - old_id_i2 = self._req_ids[i2] - self._req_ids[i1], self._req_ids[i2] =\ - self._req_ids[i2], self._req_ids[i1] # noqa - self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\ - self.req_output_token_ids[i2], self.req_output_token_ids[i1] - self.spec_token_ids[i1], self.spec_token_ids[i2] = ( - self.spec_token_ids[i2], - self.spec_token_ids[i1], - ) - assert old_id_i1 is not None and old_id_i2 is not None - self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\ - self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] - self.num_tokens[i1], self.num_tokens[i2] =\ - self.num_tokens[i2], self.num_tokens[i1] - self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ - self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] - self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ - self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] - self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ - self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - 
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] - self.num_accepted_tokens_cpu[i1], self.num_accepted_tokens_cpu[i2] =\ - self.num_accepted_tokens_cpu[i2], self.num_accepted_tokens_cpu[i1] - - # NOTE: the following is unsafe - # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ - # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] - # instead, we need to temporiarily copy the data for one of the indices - # TODO(lucas): optimize this by only copying valid indices - tmp = self.token_ids_cpu[i1, ...].copy() - self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] - self.token_ids_cpu[i2, ...] = tmp - - self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] - - # Swap prompt embeddings if they exist - embeds_i1 = self.req_prompt_embeds.get(i1) - embeds_i2 = self.req_prompt_embeds.get(i2) - if embeds_i1 is not None: - self.req_prompt_embeds[i2] = embeds_i1 - else: - self.req_prompt_embeds.pop(i2, None) - if embeds_i2 is not None: - self.req_prompt_embeds[i1] = embeds_i2 - else: - self.req_prompt_embeds.pop(i1, None) - - swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.bad_words_token_ids, i1, i2) - - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ - self.request_lora_mapping[i2], self.request_lora_mapping[i1] - - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[i1], \ - self.allowed_token_ids_mask_cpu_tensor[i2] =\ - self.allowed_token_ids_mask_cpu_tensor[i2], \ - self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) - - def condense(self) -> None: - """Slide non-empty requests down into lower, empty indices. - - Any consecutive empty indices at the very end of the list are not - filled. - - Args: - empty_req_indices: empty indices which may be filled. - - Returns: - swaps: list of (from,to) swap tuples for moved requests - empty_req_indices: indices not filled by condensation - """ - num_reqs = self.num_reqs - - if self.is_pooling_model: - # Will be contiguous in pooling case, just trim the lists. - del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - return - - if not (empty_req_indices := self.batch_update_builder.removed): - # All removed requests were replaced by added requests, or else no - # requests were removed at all. No condense() needed - return - if num_reqs == 0: - # The batched states are empty. - self._req_ids.clear() - self.req_output_token_ids.clear() - self.spec_token_ids.clear() - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = self.batch_update_builder.peek_removed() - assert empty_index is not None - if empty_index >= last_req_index: - break - - # Move active request down into empty request - # index. 
- self.batch_update_builder.pop_removed() - # Autoregressive models require detailed tracking of condense - # operations to support logitsprocs - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) - req_id = self._req_ids[last_req_index] - output_token_ids = self.req_output_token_ids[last_req_index] - assert req_id is not None - self._req_ids[empty_index] = req_id - self._req_ids[last_req_index] = None - self.req_output_token_ids[empty_index] = output_token_ids - self.req_output_token_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - if last_req_index != empty_index: - ( - self.spec_token_ids[last_req_index], - self.spec_token_ids[empty_index], - ) = ( - self.spec_token_ids[empty_index], - self.spec_token_ids[last_req_index], - ) - self.spec_token_ids[last_req_index].clear() - - num_tokens = self.num_tokens[last_req_index] - self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ - last_req_index, :num_tokens] - self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ - last_req_index, :num_tokens] - if last_req_index in self.req_prompt_embeds: - self.req_prompt_embeds[ - empty_index] = self.req_prompt_embeds.pop(last_req_index) - self.num_tokens[empty_index] = num_tokens - self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ - last_req_index] - self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table.move_row(last_req_index, empty_index) - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - self.frequency_penalties_cpu[ - empty_index] = self.frequency_penalties_cpu[last_req_index] - self.presence_penalties_cpu[ - empty_index] = self.presence_penalties_cpu[last_req_index] - self.repetition_penalties_cpu[ - empty_index] = self.repetition_penalties_cpu[last_req_index] - self.num_accepted_tokens_cpu[ - empty_index] = self.num_accepted_tokens_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - - # TODO convert these to LogitsProcessors - if self.allowed_token_ids_mask_cpu_tensor is not None: - self.allowed_token_ids_mask_cpu_tensor[ - empty_index] = self.allowed_token_ids_mask_cpu_tensor[ - last_req_index] - - bad_words_token_ids = self.bad_words_token_ids.pop( - last_req_index, None) - if bad_words_token_ids is not None: - self.bad_words_token_ids[empty_index] = bad_words_token_ids - - # Decrement last_req_index since it is now empty. - last_req_index -= 1 - - # Trim lists to the batch size. - del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - del self.spec_token_ids[num_reqs:] - - def refresh_metadata(self): - """Apply any batch updates to sampling metadata.""" - - if self.is_pooling_model: - # Batch changes every step for pooling models. - self.sampling_metadata = self._make_sampling_metadata() - return - - # For non-pooling models - generate and apply logitsprocs update; - # reset batch update tracking. - # Update sampling metadata if batch state is changed. 
-        batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
-        for logit_proc in self.logitsprocs.all:
-            logit_proc.update_state(batch_update)
-        if batch_update:
-            self.sampling_metadata = self._make_sampling_metadata()
-
-    def _make_sampling_metadata(self) -> SamplingMetadata:
-        num_reqs = self.num_reqs
-        if not self.all_greedy:
-            temperature = copy_slice(self.temperature_cpu_tensor,
-                                     self.temperature, num_reqs)
-        else:
-            temperature = None
-        if not self.no_top_p:
-            copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
-        if not self.no_top_k:
-            copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
-
-        if not self.no_penalties:
-            # Since syncing these tensors is expensive only copy them
-            # if necessary i.e. if there are requests which require
-            # penalties to be applied during sampling.
-            copy_slice(self.frequency_penalties_cpu_tensor,
-                       self.frequency_penalties, num_reqs)
-            copy_slice(self.presence_penalties_cpu_tensor,
-                       self.presence_penalties, num_reqs)
-            copy_slice(self.repetition_penalties_cpu_tensor,
-                       self.repetition_penalties, num_reqs)
-
-        needs_prompt_token_ids = (
-            not self.no_penalties
-            or self.logits_processing_needs_token_ids[:num_reqs].any())
-        if needs_prompt_token_ids:
-            # The prompt tokens are used only for applying penalties or
-            # step pooling during the sampling/pooling process.
-            # Hence copy these tensors only when there are requests which
-            # need penalties/step_pooler to be applied.
-            prompt_token_ids = self._make_prompt_token_ids_tensor()
-        else:
-            prompt_token_ids = None
-
-        allowed_token_ids_mask: Optional[torch.Tensor] = None
-        if not self.no_allowed_token_ids:
-            assert self.allowed_token_ids_mask is not None
-            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
-                       self.allowed_token_ids_mask, num_reqs)
-            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
-
-        return SamplingMetadata(
-            temperature=temperature,
-            all_greedy=self.all_greedy,
-            all_random=self.all_random,
-            top_p=None if self.no_top_p else self.top_p[:num_reqs],
-            top_k=None if self.no_top_k else self.top_k[:num_reqs],
-            generators=self.generators,
-            max_num_logprobs=self.max_num_logprobs,
-            prompt_token_ids=prompt_token_ids,
-            frequency_penalties=self.frequency_penalties[:num_reqs],
-            presence_penalties=self.presence_penalties[:num_reqs],
-            repetition_penalties=self.repetition_penalties[:num_reqs],
-            output_token_ids=cast(list[list[int]], self.req_output_token_ids),
-            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
-            no_penalties=self.no_penalties,
-            allowed_token_ids_mask=allowed_token_ids_mask,
-            bad_words_token_ids=self.bad_words_token_ids,
-            logitsprocs=self.logitsprocs,
-        )
-
-    def get_pooling_params(self) -> list[PoolingParams]:
-        assert len(self.req_ids) == len(self.pooling_params)
-        return [self.pooling_params[req_id] for req_id in self.req_ids]
-
-    def get_pooling_metadata(self) -> PoolingMetadata:
-        pooling_params = self.get_pooling_params()
-
-        return PoolingMetadata(
-            prompt_lens=torch.from_numpy(
-                self.num_prompt_tokens[:self.num_reqs]),
-            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
-            pooling_params=pooling_params,
-        )
-
-    def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
-        prompt_token_ids_cpu_tensor = torch.empty(
-            (self.num_reqs, max_prompt_len),
-            device="cpu",
-            dtype=torch.int64,
-            pin_memory=self.pin_memory,
-        )
-        prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
-        # Use the value of vocab_size as a pad since we don't have a
-        # token_id of this value.
-        for i in range(self.num_reqs):
-            prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
-        return prompt_token_ids_cpu_tensor.to(device=self.device,
-                                              non_blocking=True)
-
-    def make_lora_inputs(
-        self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
-    ) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
-        """
-        Given the num_scheduled_tokens for each request in the batch, return
-        data structures used to activate the current LoRAs.
-        Returns:
-            1. prompt_lora_mapping: A tuple of size self.num_reqs where
-               prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
-            2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
-               where token_lora_mapping[i] is the LoRA id to use for the ith token.
-            3. lora_requests: Set of relevant LoRA requests.
-        """
-
-        req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
-        prompt_lora_mapping = tuple(req_lora_mapping)
-        token_lora_mapping = tuple(
-            req_lora_mapping.repeat(num_scheduled_tokens))
-        active_lora_requests: set[LoRARequest] = set(
-            self.lora_id_to_lora_request.values())
-
-        return prompt_lora_mapping, token_lora_mapping, active_lora_requests
-
-    def set_async_sampled_token_ids(
-        self,
-        sampled_token_ids_cpu: torch.Tensor,
-        async_copy_ready_event: torch.Event,
-    ) -> None:
-        """
-        In the async scheduling case, store a ref to the sampled_token_ids_cpu
-        tensor and the corresponding copy-ready event. Used to repair
-        output_token_ids prior to sampling, if needed by logits processors.
-        """
-        if self.sampling_metadata.output_token_ids:
-            self.sampled_token_ids_cpu = sampled_token_ids_cpu
-            self.async_copy_ready_event = async_copy_ready_event
-        else:
-            self.sampled_token_ids_cpu = None
-            self.async_copy_ready_event = None
-
-    def update_async_output_token_ids(self) -> None:
-        """
-        In the async scheduling case, update output_token_ids in the sampling
-        metadata from the prior step's sampled token ids once they have
-        finished copying to the CPU. This is called right before they are
-        needed by the logits processors.
-        """
-        output_token_ids = self.sampling_metadata.output_token_ids
-        if self.sampled_token_ids_cpu is None or not output_token_ids:
-            # Output token ids not needed or not async scheduling.
-            return
-
-        assert self.prev_req_id_to_index is not None
-        sampled_token_ids = None
-        for index, req_id in enumerate(self.req_ids):
-            prev_index = self.prev_req_id_to_index.get(req_id)
-            if prev_index is None:
-                continue
-            req_output_token_ids = output_token_ids[index]
-            if not req_output_token_ids or req_output_token_ids[-1] != -1:
-                # Final output id is not a placeholder; some tokens must have
-                # been discarded after a kv-load failure.
-                continue
-            if sampled_token_ids is None:
-                assert self.async_copy_ready_event is not None
-                self.async_copy_ready_event.synchronize()
-                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
-                    -1).tolist()
-            # Replace placeholder token id with actual sampled id.
-            req_output_token_ids[-1] = sampled_token_ids[prev_index]
-
-    @property
-    def num_reqs(self) -> int:
-        return len(self.req_id_to_index)
-
-    @property
-    def all_greedy(self) -> bool:
-        return len(self.random_reqs) == 0
-
-    @property
-    def all_random(self) -> bool:
-        return len(self.greedy_reqs) == 0
-
-    @property
-    def no_top_p(self) -> bool:
-        return len(self.top_p_reqs) == 0
-
-    @property
-    def no_top_k(self) -> bool:
-        return len(self.top_k_reqs) == 0
-
-    @property
-    def no_penalties(self) -> bool:
-        return (len(self.presence_penalties_reqs) == 0
-                and len(self.frequency_penalties_reqs) == 0
-                and len(self.repetition_penalties_reqs) == 0)
-
-    @property
-    def max_num_logprobs(self) -> Optional[int]:
-        return max(self.num_logprobs.values()) if self.num_logprobs else None
-
-    @property
-    def no_prompt_logprob(self) -> bool:
-        return not self.num_prompt_logprobs
-
-    @property
-    def no_allowed_token_ids(self) -> bool:
-        return len(self.has_allowed_token_ids) == 0
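# A minimal standalone sketch (not from the deleted module) of why the NOTE in
# swap_states() above rules out swapping two numpy rows with plain tuple
# assignment: row indexing returns views, so the second assignment reads data
# that the first assignment already overwrote. Copying one row first avoids it.
import numpy as np

token_ids = np.array([[1, 2, 3], [4, 5, 6]])

# Unsafe: both rows end up holding the original contents of row 1.
broken = token_ids.copy()
broken[0, ...], broken[1, ...] = broken[1, ...], broken[0, ...]
assert (broken[0] == broken[1]).all()

# Safe: temporarily copy one row, as the deleted code does.
safe = token_ids.copy()
tmp = safe[0, ...].copy()
safe[0, ...] = safe[1, ...]
safe[1, ...] = tmp
assert (safe[0] == [4, 5, 6]).all() and (safe[1] == [1, 2, 3]).all()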
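# A toy sketch (not from the deleted module) of the compaction loop in
# condense() above. `slots` stands in for the per-request state arrays and
# `removed` for batch_update_builder.removed, kept in descending order so the
# smallest empty index sits at the end.
def condense_slots(slots: list, removed: list[int]) -> list:
    num_active = sum(s is not None for s in slots)
    last = len(slots) - 1
    while removed:
        while slots[last] is None:  # find the largest non-empty index
            last -= 1
        empty = removed[-1]         # peek the smallest empty index
        if empty >= last:
            break
        slots[empty] = slots[last]  # move the active entry down
        slots[last] = None
        removed.pop()
        last -= 1
    del slots[num_active:]          # trim trailing empties
    return slots


assert condense_slots(["a", None, "b", "c", None, "d"],
                      removed=[4, 1]) == ["a", "d", "b", "c"]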
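# A hedged sketch (not from the deleted module) of the copy_slice(...) pattern
# used throughout _make_sampling_metadata() above: copy only the first
# num_reqs entries of a pinned CPU tensor into the persistent device tensor
# and hand back the device-side slice. The real helper lives in vLLM; this
# stand-in only assumes it behaves roughly like this.
import torch


def copy_slice_sketch(from_cpu: torch.Tensor, to_dev: torch.Tensor,
                      length: int) -> torch.Tensor:
    # non_blocking only pays off with pinned CPU memory and a real device,
    # but the call is valid either way.
    to_dev[:length].copy_(from_cpu[:length], non_blocking=True)
    return to_dev[:length]


temperature_cpu = torch.tensor([0.5, 1.0, 0.25, 0.0])
temperature_dev = torch.zeros(4)  # would live on the NPU in the real batch
assert copy_slice_sketch(temperature_cpu, temperature_dev, 2).tolist() == [0.5, 1.0]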
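# An illustrative sketch (values made up) of how make_lora_inputs() above
# expands the per-request LoRA ids into a per-token mapping with numpy.repeat.
import numpy as np

request_lora_mapping = np.array([0, 2, 1])  # LoRA id per request (0 = no LoRA)
num_scheduled_tokens = np.array([3, 1, 2])  # tokens scheduled per request

prompt_lora_mapping = tuple(request_lora_mapping)
token_lora_mapping = tuple(request_lora_mapping.repeat(num_scheduled_tokens))

assert prompt_lora_mapping == (0, 2, 1)
assert token_lora_mapping == (0, 0, 0, 2, 1, 1)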
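# A toy sketch (names and values made up) of the placeholder repair performed
# by update_async_output_token_ids() above: with async scheduling, the
# previous step's sampled token is recorded as -1 until the device-to-CPU copy
# finishes, then patched in before the logits processors read it.
prev_req_id_to_index = {"req-0": 0, "req-1": 1}
output_token_ids = [[5, 7, -1], [9, -1]]  # -1 marks the not-yet-copied token
sampled_token_ids_cpu = [[42], [13]]      # prior step's sampled ids, per old index

for index, req_id in enumerate(["req-0", "req-1"]):
    prev_index = prev_req_id_to_index.get(req_id)
    tokens = output_token_ids[index]
    if prev_index is not None and tokens and tokens[-1] == -1:
        tokens[-1] = sampled_token_ids_cpu[prev_index][0]

assert output_token_ids == [[5, 7, 42], [9, 13]]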