Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions fastdeploy/spec_decode/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,17 @@ def update_mtp_block_num(self, num_gpu_blocks) -> None:
}
)

def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int):
def insert_tasks_v1(
self, req_dicts: List[Request], num_running_requests: int, target_model_index_to_batch_id: dict = {}
):

if "caches" not in self.model_inputs:
self.initialize_kv_cache()
req_len = len(req_dicts)
self.model_inputs["num_running_requests"] = num_running_requests
self.model_inputs["running_requests_ids"] = range(num_running_requests)
if target_model_index_to_batch_id:
self.model_inputs.index_to_batch_id = dict(target_model_index_to_batch_id)
for i in range(req_len):
request = req_dicts[i]
logger.debug(f"{i}th request-{request.request_id}: {request}")
Expand Down Expand Up @@ -1223,11 +1227,11 @@ def _get_cache_type(self):
raise NotImplementedError
return cache_type

def reorder_inputs(self, target_model_input_batch):
    """
    Reorder the MTP proposer's inputs to split prefill and decode,
    aligned with the target model's batch order.

    Args:
        target_model_input_batch: the target model's request-id ->
            batch-index mapping (``index_to_batch_id``) that the MTP
            batch must be reordered to match.
    """
    reorder_split_prefill_and_decode_form_index_to_batch_id(self.model_inputs, target_model_input_batch)

def _share_external_data(self, cache, cache_name, cache_shape):
if current_platform.is_xpu():
Expand Down
4 changes: 2 additions & 2 deletions fastdeploy/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =

self.share_inputs["seq_lens_this_time"] = self.share_inputs["seq_lens_this_time_buffer"][:num_running_requests]
if self.spec_method == SpecMethod.MTP:
self.proposer.insert_tasks_v1(req_dicts, num_running_requests)
self.proposer.insert_tasks_v1(req_dicts, num_running_requests, self.share_inputs.index_to_batch_id)

def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int):
raise NotImplementedError("GPUs only support KVCACHE SCHEDULER V1 in versions 2.6 and above.")
Expand Down Expand Up @@ -1221,7 +1221,7 @@ def _process_reorder(self) -> None:
reorder_split_prefill_and_decode(input_batch=self.share_inputs)
if self.speculative_decoding:
if self.spec_method == SpecMethod.MTP:
self.proposer.reorder_inputs()
self.proposer.reorder_inputs(self.share_inputs.index_to_batch_id)

def load_model(self) -> None:
"""load or download model"""
Expand Down
66 changes: 23 additions & 43 deletions fastdeploy/worker/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def swap_data(tensor, idx1, idx2):
swap_data(self.step_seq_lens_this_time, i1, i2)
swap_data(self.draft_logits, i1, i2)
swap_data(self.cu_batch_token_offset, i1, i2)
swap_data(self.stop_flags, i1, i2)

if self.enable_mm:
if self.image_features_list is not None:
self.image_features_list[i1], self.image_features_list[i2] = (
Expand Down Expand Up @@ -674,7 +674,6 @@ def __init__(self, fd_config: FDConfig, target_model_input_batch: InputBatch) ->
def init_share_inputs(self):
# share with targe model
self.enable_pd_reorder = getattr(self.target_model_input_batch, "enable_pd_reorder", False)
self.index_to_batch_id = getattr(self.target_model_input_batch, "index_to_batch_id", {})

self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
Expand Down Expand Up @@ -851,55 +850,22 @@ def swap_data(tensor, idx1, idx2):
tensor[idx1] = tensor[idx2].clone()
tensor[idx2] = temp

self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
swap_data(self.block_tables, i1, i2)
swap_data(self.input_ids, i1, i2)
swap_data(self.input_ids_cpu, i1, i2)
swap_data(self.seq_lens_this_time_buffer, i1, i2)
swap_data(self.seq_lens_encoder, i1, i2)
swap_data(self.seq_lens_decoder, i1, i2)
swap_data(self.step_idx, i1, i2)
swap_data(self.stop_flags, i1, i2)
swap_data(self.not_need_stop, i1, i2)
swap_data(self.pre_ids, i1, i2)
if current_platform.is_cuda():
swap_data(self.cu_seqlens_q_output, i1, i2)
swap_data(self.batch_id_per_token_output, i1, i2)
swap_data(self.token_ids_all, i1, i2)
else:
swap_data(self.output_cum_offsets, i1, i2)
swap_data(self.output_padding_offset, i1, i2)
swap_data(self.ids_remove_padding, i1, i2)
swap_data(self.batch_id_per_token, i1, i2)
swap_data(self.cu_seqlens_q, i1, i2)
swap_data(self.cu_seqlens_k, i1, i2)

swap_data(self.target_hidden_states, i1, i2)

swap_data(self.draft_tokens, i1, i2)
swap_data(self.encoder_block_lens, i1, i2)

swap_data(self.is_block_step, i1, i2)
swap_data(self.batch_drop, i1, i2)
swap_data(self.used_list_len, i1, i2)

if self.num_model_steps > 1:
swap_data(self.last_seq_lens_this_time, i1, i2)

swap_data(self.input_ids_len, i1, i2)
swap_data(self.first_token_hidden_states, i1, i2)

swap_data(self.batch_token_num, i1, i2)
swap_data(self.next_token_num, i1, i2)
swap_data(self.cu_batch_token_offset, i1, i2)
swap_data(self.cu_next_token_offset, i1, i2)
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)

if self.enable_mm:
swap_data(self.attn_mask_offsets, i1, i2)
swap_data(self.attn_mask_offsets_full, i1, i2)
swap_data(self.attn_mask_offsets_decoder, i1, i2)
swap_data(self.decode_states, i1, i2)

def reset_model_inputs(self) -> None:
"""
Expand Down Expand Up @@ -1041,14 +1007,28 @@ def reset_model_inputs(self) -> None:
logger.error(f"Resetting model inputs failed, skipping reset, error message is {e}")


def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: InputBatch):
swapped = set()
for i, target in input_batch.index_to_batch_id.items():
if i in swapped or target in swapped or i == target:
def reorder_split_prefill_and_decode_form_index_to_batch_id(input_batch: "InputBatch", target_model_input_batch: dict):
    """
    Reorder ``input_batch`` (the MTP proposer batch) so every request sits at
    the same batch index as in the target model's batch, then drop requests
    that no longer exist in the target model's batch.

    Args:
        input_batch: speculative (MTP) input batch; must expose
            ``index_to_batch_id`` (request-id -> batch-index) and
            ``swap_states(id1, id2)``, which swaps all per-request state of
            the two requests — including their ``index_to_batch_id`` entries.
            (NOTE(review): correctness of later iterations relies on
            ``swap_states`` updating ``index_to_batch_id`` — confirm.)
        target_model_input_batch: the target model's request-id ->
            batch-index mapping to align to.
    """
    # Reverse map: batch-index -> request-id, kept in sync manually after swaps.
    index_to_request_id = {v: k for k, v in input_batch.index_to_batch_id.items()}
    for request_id, wanted_index in target_model_input_batch.items():
        if input_batch.index_to_batch_id[request_id] == wanted_index:
            continue
        # The request currently occupying the slot this request should move to.
        occupant_id = index_to_request_id[wanted_index]
        old_index = input_batch.index_to_batch_id[request_id]
        input_batch.swap_states(request_id, occupant_id)
        # Keep the local reverse map consistent with the swap.
        index_to_request_id[old_index] = occupant_id
        index_to_request_id[wanted_index] = request_id

    # Drop requests absent from the target model's batch. keys() - keys()
    # materializes a set first, so deleting while looping is safe. The local
    # reverse map is never read after this point, so (unlike the original
    # inner linear scan) it does not need to be maintained here.
    for stale_id in input_batch.index_to_batch_id.keys() - target_model_input_batch.keys():
        del input_batch.index_to_batch_id[stale_id]


def reorder_split_prefill_and_decode(input_batch: InputBatch):
Expand Down
19 changes: 19 additions & 0 deletions tests/e2e/test_pd_reorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,24 @@ def test_model_against_baseline(
quantization,
"dummy",
prompts,
{}, # speculative_config
),
)

mtp_model_path = os.path.join(model_path, "mtp")
speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path}
_ = run_with_timeout(
target=form_model_get_output_topp0,
args=(
fd_runner,
model_path,
tensor_parallel_size,
max_num_seqs,
max_model_len,
max_tokens,
quantization,
"dummy",
prompts,
speculative_config,
),
)
2 changes: 2 additions & 0 deletions tests/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def form_model_get_output_topp0(
quantization,
load_choices,
prompts,
speculative_config,
result_queue,
):
try:
Expand All @@ -96,6 +97,7 @@ def form_model_get_output_topp0(
max_model_len=max_model_len,
load_choices=load_choices,
quantization=quantization,
speculative_config=speculative_config,
) as fd_model:
fd_outputs = fd_model.generate_topp0(prompts, max_tokens=max_tokens)
result_queue.put(fd_outputs)
Expand Down
Loading