diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 5461ac2833c..e86169529fe 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -453,13 +453,17 @@ def update_mtp_block_num(self, num_gpu_blocks) -> None: } ) - def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): + def insert_tasks_v1( + self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = None + ): if "caches" not in self.model_inputs: self.initialize_kv_cache() req_len = len(req_dicts) self.model_inputs["num_running_requests"] = num_running_requests self.model_inputs["running_requests_ids"] = range(num_running_requests) + if target_model_index_to_batch_id: + self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id) for i in range(req_len): request = req_dicts[i] logger.debug(f"{i}th request-{request.request_id}: {request}") @@ -1223,11 +1227,11 @@ def _get_cache_type(self): raise NotImplementedError return cache_type - def reorder_inputs(self): + def reorder_inputs(self, target_model_input_batch): """ Reorder inputs to split prefill and decode. 
""" - reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs) + reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch) def _share_external_data(self, cache, cache_name, cache_shape): if current_platform.is_xpu(): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 554e464bf53..0bed801a17b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -977,7 +977,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests] if self.spec_method == SpecMethod.MTP: - self.proposer.insert_tasks_v1(req_dicts, num_running_requests) + self.proposer.insert_tasks_v1(req_dicts, num_running_requests, self.share_inputs.index_to_batch_id) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): raise NotImplementedError("GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above.") @@ -1221,7 +1221,7 @@ def _process_reorder(self) -> None: reorder_split_prefill_and_decode(input_batch=self.share_inputs) if self.speculative_decoding: if self.spec_method == SpecMethod.MTP: - self.proposer.reorder_inputs() + self.proposer.reorder_inputs(self.share_inputs.index_to_batch_id) def load_model(self) -> None: """load or download model""" diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 4c6c646d27b..cdd7c60b95c 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -428,7 +428,7 @@ def swap_data(tensor, idx1, idx2): swap_data(self.step_seq_lens_this_time, i1, i2) swap_data(self.draft_logits, i1, i2) swap_data(self.cu_batch_token_offset, i1, i2) - swap_data(self.stop_flags, i1, i2) + if self.enable_mm: if self.image_features_list is not None: self.image_features_list[i1], self.image_features_list[i2] = ( @@ -674,7 
+674,6 @@ def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) -> def init_share_inputs(self): # share with targe model self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False) - self.index_to_batch_id = getattr(self.target_model_input_batch, "index_to_batch_id", {}) self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"]) self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"]) @@ -851,6 +850,7 @@ def swap_data(tensor, idx1, idx2): tensor[idx1] = tensor[idx2].clone() tensor[idx2] = temp + self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1] swap_data(self.block_tables, i1, i2) swap_data(self.input_ids, i1, i2) swap_data(self.input_ids_cpu, i1, i2) @@ -858,48 +858,14 @@ def swap_data(tensor, idx1, idx2): swap_data(self.seq_lens_encoder, i1, i2) swap_data(self.seq_lens_decoder, i1, i2) swap_data(self.step_idx, i1, i2) - swap_data(self.stop_flags, i1, i2) - swap_data(self.not_need_stop, i1, i2) swap_data(self.pre_ids, i1, i2) - if current_platform.is_cuda(): - swap_data(self.cu_seqlens_q_output, i1, i2) - swap_data(self.batch_id_per_token_output, i1, i2) - swap_data(self.token_ids_all, i1, i2) - else: - swap_data(self.output_cum_offsets, i1, i2) - swap_data(self.output_padding_offset, i1, i2) - swap_data(self.ids_remove_padding, i1, i2) - swap_data(self.batch_id_per_token, i1, i2) - swap_data(self.cu_seqlens_q, i1, i2) - swap_data(self.cu_seqlens_k, i1, i2) - - swap_data(self.target_hidden_states, i1, i2) - - swap_data(self.draft_tokens, i1, i2) swap_data(self.encoder_block_lens, i1, i2) - - swap_data(self.is_block_step, i1, i2) - swap_data(self.batch_drop, i1, i2) - swap_data(self.used_list_len, i1, i2) - - if self.num_model_steps > 1: - swap_data(self.last_seq_lens_this_time, i1, i2) - swap_data(self.input_ids_len, i1, i2) - swap_data(self.first_token_hidden_states, i1, i2) - - swap_data(self.batch_token_num, 
i1, i2) - swap_data(self.next_token_num, i1, i2) - swap_data(self.cu_batch_token_offset, i1, i2) - swap_data(self.cu_next_token_offset, i1, i2) swap_data(self.mask_rollback, i1, i2) swap_data(self.recompute_token_num, i1, i2) - if self.enable_mm: - swap_data(self.attn_mask_offsets, i1, i2) swap_data(self.attn_mask_offsets_full, i1, i2) swap_data(self.attn_mask_offsets_decoder, i1, i2) - swap_data(self.decode_states, i1, i2) def reset_model_inputs(self) -> None: """ @@ -1041,14 +1007,28 @@ def reset_model_inputs(self) -> None: logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}") -def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch): - swapped = set() - for i, target in input_batch.index_to_batch_id.items(): - if i in swapped or target in swapped or i == target: +def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch, target_model_input_batch: dict): + mtp_index_2_mtp_id = {v: k for k, v in input_batch.index_to_batch_id.items()} + for target_model_id in target_model_input_batch: + target_model_index = target_model_input_batch[target_model_id] + if input_batch.index_to_batch_id[target_model_id] == target_model_index: continue - input_batch.swap_states(i, target) - swapped.add(i) - swapped.add(target) + mtp_id = mtp_index_2_mtp_id[target_model_index] + v1 = input_batch.index_to_batch_id[target_model_id] + v2 = input_batch.index_to_batch_id[mtp_id] + input_batch.swap_states(target_model_id, mtp_id) + # update mapping + mtp_index_2_mtp_id[v1] = mtp_id + mtp_index_2_mtp_id[v2] = target_model_id + + keys_to_remove = input_batch.index_to_batch_id.keys() - target_model_input_batch.keys() + + for key in keys_to_remove: + del input_batch.index_to_batch_id[key] + for k, v in mtp_index_2_mtp_id.items(): + if v == key: + del mtp_index_2_mtp_id[k] + break def reorder_split_prefill_and_decode(input_batch: InputBatch): diff --git a/tests/e2e/test_pd_reorder.py b/tests/e2e/test_pd_reorder.py 
index 90e95a7606c..ce6ad6824e4 100644 --- a/tests/e2e/test_pd_reorder.py +++ b/tests/e2e/test_pd_reorder.py @@ -104,5 +104,24 @@ def test_model_against_baseline( quantization, "dummy", prompts, + {}, # speculative_config + ), + ) + + mtp_model_path = os.path.join(model_path, "mtp") + speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path} + _ = run_with_timeout( + target=form_model_get_output_topp0, + args=( + fd_runner, + model_path, + tensor_parallel_size, + max_num_seqs, + max_model_len, + max_tokens, + quantization, + "dummy", + prompts, + speculative_config, ), ) diff --git a/tests/model_loader/utils.py b/tests/model_loader/utils.py index 9a41422fe98..977124d915f 100644 --- a/tests/model_loader/utils.py +++ b/tests/model_loader/utils.py @@ -86,6 +86,7 @@ def form_model_get_output_topp0( quantization, load_choices, prompts, + speculative_config, result_queue, ): try: @@ -96,6 +97,7 @@ def form_model_get_output_topp0( max_model_len=max_model_len, load_choices=load_choices, quantization=quantization, + speculative_config=speculative_config, ) as fd_model: fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens) result_queue.put(fd_outputs)