MTP #885

Closed

wants to merge 14 commits into from

Changes from all commits
226 changes: 94 additions & 132 deletions lightllm/common/basemodel/basemodel.py
@@ -20,6 +20,9 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput


logger = init_logger(__name__)

@@ -71,6 +74,9 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

# Speculative decoding
self.spec_algo = SpeculativeDecodeAlgorithm.from_string(kvargs.get("spec_algo", "NONE"))

self._init_datatype()
self._init_config()
self._verify_must()
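
Note: SpeculativeDecodeAlgorithm is imported from lightllm/common/spec_info.py, which is not included in this diff excerpt. The sketch below is only a guess at a minimal shape for that helper, assuming nothing beyond the calls made in this file (from_string, is_mtp, is_mtp_module); the actual class in the PR may differ.

from enum import Enum


class SpeculativeDecodeAlgorithm(Enum):
    # Hypothetical members inferred from how basemodel.py uses the class.
    NONE = 0
    MTP = 1          # main model driving multi-token-prediction speculation
    MTP_MODULE = 2   # an auxiliary MTP draft module chained after the main model

    @staticmethod
    def from_string(name: str) -> "SpeculativeDecodeAlgorithm":
        # Mirrors kvargs.get("spec_algo", "NONE"): unknown values fall back to NONE.
        return SpeculativeDecodeAlgorithm.__members__.get(
            (name or "NONE").upper(), SpeculativeDecodeAlgorithm.NONE
        )

    def is_none(self) -> bool:
        return self is SpeculativeDecodeAlgorithm.NONE

    def is_mtp(self) -> bool:
        return self is SpeculativeDecodeAlgorithm.MTP

    def is_mtp_module(self) -> bool:
        return self is SpeculativeDecodeAlgorithm.MTP_MODULE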
@@ -226,140 +232,82 @@ def _init_custom(self):
pass

@torch.no_grad()
def forward(
self,
batch_size,
total_token_num,
max_len_in_batch,
input_ids: torch.Tensor,
mem_indexes: torch.Tensor,
b_req_idx: torch.Tensor,
b_seq_len: torch.Tensor,
b_ready_cache_len: torch.Tensor = None,
multimodal_params=None,
is_prefill=True,
):
assert mem_indexes.is_cuda

if is_prefill:
return self._prefill(
batch_size,
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_seq_len,
b_ready_cache_len,
multimodal_params,
)
def forward(self, model_input: ModelInput):
assert model_input.mem_indexes.is_cuda

if model_input.is_prefill:
return self._prefill(model_input)
else:
return self._decode(
batch_size,
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_seq_len,
multimodal_params,
)
return self._decode(model_input)

def _prefill(
self,
batch_size,
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_seq_len,
b_ready_cache_len,
multimodal_params,
):
def _create_inferstate(self, model_input: ModelInput, batch_index: int = 0):
infer_state = self.infer_state_class()
infer_state.is_prefill = True
infer_state.is_prefill = model_input.is_prefill
infer_state.spec_algo = self.spec_algo
infer_state.spec_info = model_input.hidden_states

infer_state.is_token_healing = self.is_token_healing
infer_state.return_all_prompt_logics = self.return_all_prompt_logics
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
assert b_req_idx.shape[0] == b_seq_len.shape[0]
infer_state.b_req_idx = b_req_idx
infer_state.b_seq_len = b_seq_len
if b_ready_cache_len is not None:
infer_state.b_ready_cache_len = b_ready_cache_len
infer_state.batch_size = model_input.batch_size
infer_state.total_token_num = model_input.total_token_num
infer_state.max_len_in_batch = model_input.max_len_in_batch
assert model_input.b_req_idx.shape[0] == model_input.b_seq_len.shape[0]
infer_state.b_req_idx = model_input.b_req_idx
infer_state.b_seq_len = model_input.b_seq_len
if model_input.b_ready_cache_len is not None and model_input.is_prefill:
infer_state.b_ready_cache_len = model_input.b_ready_cache_len
else:
infer_state.b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=b_seq_len.dtype, device=b_seq_len.device)
infer_state.multimodal_params = multimodal_params
infer_state.b_ready_cache_len = torch.zeros_like(
model_input.b_seq_len, dtype=model_input.b_seq_len.dtype, device=model_input.b_seq_len.device
)
infer_state.multimodal_params = model_input.multimodal_params

infer_state.mem_manager = self.mem_manager
infer_state.req_manager = self.req_manager

infer_state.mem_index = mem_indexes
infer_state.mem_index = model_input.mem_indexes
infer_state.kv_buffer_shapedtype = (
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
(model_input.input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
self.data_type,
)
infer_state.dist_group = dist_group_manager.get_default_group()
infer_state.dist_group = dist_group_manager.get_group(batch_index)
return infer_state

def _prefill(
self,
model_input: ModelInput,
):
infer_state = self._create_inferstate(model_input)
init_req_to_token_indexes(
self.req_manager.req_to_token_indexs,
b_req_idx,
b_seq_len,
model_input.b_req_idx,
model_input.b_seq_len,
infer_state.b_ready_cache_len,
max_len_in_batch,
model_input.max_len_in_batch,
infer_state.mem_index,
)

infer_state.init_some_extra_state(self, input_ids)
predict_logics = self._context_forward(input_ids, infer_state)
return predict_logics
infer_state.init_some_extra_state(self, model_input.input_ids)
return self._context_forward(model_input.input_ids, infer_state)

def _decode(
self,
batch_size,
total_token_num,
max_len_in_batch,
input_ids,
mem_indexes,
b_req_idx,
b_seq_len,
multimodal_params,
model_input: ModelInput,
):
infer_state = self.infer_state_class()
infer_state.is_prefill = False
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
assert b_req_idx.shape[0] == b_seq_len.shape[0]
infer_state.b_req_idx = b_req_idx
infer_state.b_seq_len = b_seq_len
infer_state.multimodal_params = multimodal_params

infer_state.mem_manager = self.mem_manager
infer_state.req_manager = self.req_manager

infer_state.mem_index = mem_indexes
infer_state.kv_buffer_shapedtype = (
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
self.data_type,
infer_state = self._create_inferstate(model_input)
copy_kv_index_to_req(
self.req_manager.req_to_token_indexs, model_input.b_req_idx, model_input.b_seq_len, infer_state.mem_index
)
infer_state.dist_group = dist_group_manager.get_default_group()
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)

infer_state.init_some_extra_state(self, input_ids)
if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
if self.graph.need_capture(batch_size):
infer_state.init_some_extra_state(self, model_input.input_ids)
if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
if self.graph.need_capture(model_input.batch_size):
infer_state.is_cuda_graph = True
predict_logics = self.graph.capture_decode(self._token_forward, input_ids, infer_state)
return self.graph.capture_decode(self._token_forward, model_input.input_ids, infer_state)
else:
predict_logics = self.graph.replay(input_ids, infer_state)
else:
predict_logics = self._token_forward(input_ids, infer_state)
return predict_logics
return self.graph.replay(model_input.input_ids, infer_state)

return self._token_forward(model_input.input_ids, infer_state)
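
Note: the decode path above delegates to self.graph (can_run, need_capture, capture_decode, replay), whose implementation is not shown in this diff. The simplified sketch below only illustrates that capture/replay pattern with torch.cuda.CUDAGraph; it is not lightllm's actual CudaGraph class, which also re-binds the per-step infer_state buffers before replaying.

import torch


class SimpleDecodeGraphCache:
    # Illustrative CUDA-graph cache keyed by batch size (placeholder, not the real class).
    def __init__(self, max_batch_size: int = 16, max_len_in_batch: int = 8192):
        self.max_batch_size = max_batch_size
        self.max_len_in_batch = max_len_in_batch
        self.graphs = {}  # batch_size -> (graph, static_input_ids, static_output)

    def can_run(self, batch_size: int, max_len_in_batch: int) -> bool:
        return batch_size <= self.max_batch_size and max_len_in_batch <= self.max_len_in_batch

    def need_capture(self, batch_size: int) -> bool:
        return batch_size not in self.graphs

    def capture_decode(self, decode_func, input_ids, infer_state):
        # Warm-up run, then record the decode kernels into a CUDA graph.
        decode_func(input_ids, infer_state)
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = decode_func(input_ids, infer_state)
        self.graphs[infer_state.batch_size] = (graph, input_ids, static_output)
        return static_output

    def replay(self, input_ids, infer_state):
        graph, static_input_ids, static_output = self.graphs[infer_state.batch_size]
        static_input_ids.copy_(input_ids)  # refresh the captured input buffer in place
        graph.replay()
        return static_output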

@torch.no_grad()
def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
@@ -409,22 +357,22 @@ def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
infer_state.is_cuda_graph = True
infer_state1.is_cuda_graph = True

predict_logics, predict_logics1 = self.graph.capture_decode(
predict_logits, predict_logits1 = self.graph.capture_decode(
self._overlap_tpsp_token_forward,
input_ids,
infer_state,
input_ids1=input_ids1,
infer_state1=infer_state1,
)
else:
predict_logics, predict_logics1 = self.graph.replay(
predict_logits, predict_logits1 = self.graph.replay(
input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
)
else:
predict_logics, predict_logics1 = self._overlap_tpsp_token_forward(
predict_logits, predict_logits1 = self._overlap_tpsp_token_forward(
input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
)
return predict_logics, predict_logics1
return predict_logits, predict_logits1

@torch.no_grad()
def microbatch_overlap_prefill(self, batch: PrefillMicroBatch, batch1: PrefillMicroBatch):
@@ -478,11 +426,11 @@ def create_inferstate(cur_batch: PrefillMicroBatch, batch_index):
infer_state.init_some_extra_state(self, input_ids)
infer_state1.init_some_extra_state(self, input_ids1)

predict_logics, predict_logics1 = self._overlap_tpsp_context_forward(
predict_logits, predict_logits1 = self._overlap_tpsp_context_forward(
input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
)
dist_group_manager.clear_deepep_buffer()
return predict_logics, predict_logics1
return predict_logits, predict_logits1

@final
def _context_forward(self, input_ids, infer_state: InferStateInfo):
@@ -499,10 +447,16 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)

g_cache_manager.cache_env_out()
return predict_logics
is_return_hidden_states = self.spec_algo.is_mtp() or (
self.spec_algo.is_mtp_module() and not self.last_mtp_module
)
return ModelOutput(
logits=predict_logits,
hidden_states=input_embs if is_return_hidden_states else None,
)

@final
def _token_forward(self, input_ids, infer_state: InferStateInfo):
@@ -521,10 +475,16 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)

g_cache_manager.cache_env_out()
return predict_logics
is_return_hidden_states = self.spec_algo.is_mtp() or (
self.spec_algo.is_mtp_module() and not self.last_mtp_module
)
return ModelOutput(
logits=predict_logits,
hidden_states=input_embs if is_return_hidden_states else None,
)

@final
def _overlap_tpsp_token_forward(
@@ -544,12 +504,12 @@ def _overlap_tpsp_token_forward(
input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i]
)

predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward(
predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
)

g_cache_manager.cache_env_out()
return predict_logics, predict_logics1
return predict_logits, predict_logits1

@final
def _overlap_tpsp_context_forward(
@@ -563,11 +523,11 @@ def _overlap_tpsp_context_forward(
input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_context_forward(
input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i]
)
predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward(
predict_logits, predict_logits1 = self.post_infer.overlap_tpsp_token_forward(
input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
)
g_cache_manager.cache_env_out()
return predict_logics, predict_logics1
return predict_logits, predict_logits1

@final
@torch.no_grad()
@@ -587,20 +547,22 @@ def _check_max_len_infer(self):
b_seq_len[:] = self.batch_max_tokens
b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
total_token_num = self.batch_max_tokens
logics = self.forward(
1,
total_token_num,
self.batch_max_tokens,
dummy_input_ids,
mem_indexes,
b_req_idx,
b_seq_len,
b_ready_cache_len=b_ready_cache_len,
model_input = ModelInput(
batch_size=1,
total_token_num=total_token_num,
max_len_in_batch=self.batch_max_tokens,
input_ids=dummy_input_ids,
mem_indexes=mem_indexes,
b_req_idx=b_req_idx,
b_seq_len=b_seq_len,
is_prefill=True,
multimodal_params=[],
b_ready_cache_len=b_ready_cache_len,
)
model_output = self.forward(
model_input,
)
prob_out = torch.softmax(logics, dim=-1)
logics = None
prob_out = torch.softmax(model_output.logits, dim=-1)
del model_output
torch.argmax(prob_out, dim=1, keepdim=True)
prob_out = None
self.req_manager.free_all()
23 changes: 23 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -0,0 +1,23 @@
import torch
from dataclasses import dataclass, field


@dataclass
class ModelInput:
batch_size: int
total_token_num: int
max_len_in_batch: int
input_ids: torch.Tensor
mem_indexes: torch.Tensor
b_req_idx: torch.Tensor
b_seq_len: torch.Tensor
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
multimodal_params: list = field(default_factory=list)
hidden_states: torch.Tensor = None


@dataclass
class ModelOutput:
logits: torch.Tensor
hidden_states: torch.Tensor
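
For context, a minimal usage sketch of the two dataclasses, mirroring the _check_max_len_infer call in basemodel.py above. The model, mem_indexes and b_req_idx names are placeholders for objects produced by the real memory and request managers.

import torch

from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput

# Hypothetical prefill step for a single request with 8 prompt tokens.
batch_max_tokens = 8
dummy_input_ids = torch.ones(batch_max_tokens, dtype=torch.int32, device="cuda")
b_seq_len = torch.full((1,), batch_max_tokens, dtype=torch.int32, device="cuda")

model_input = ModelInput(
    batch_size=1,
    total_token_num=batch_max_tokens,
    max_len_in_batch=batch_max_tokens,
    input_ids=dummy_input_ids,
    mem_indexes=mem_indexes,  # placeholder: KV-cache slots from the mem_manager
    b_req_idx=b_req_idx,      # placeholder: request slots from the req_manager
    b_seq_len=b_seq_len,
    is_prefill=True,
    b_ready_cache_len=torch.zeros(1, dtype=torch.int32, device="cuda"),
)

model_output: ModelOutput = model.forward(model_input)  # placeholder model instance
next_token_ids = torch.argmax(torch.softmax(model_output.logits, dim=-1), dim=-1)
# hidden_states is populated only when spec_algo is MTP (or a non-final MTP module);
# it is fed back to the next draft step via ModelInput.hidden_states / infer_state.spec_info.
draft_hidden_states = model_output.hidden_states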