MTP #885

Open · wants to merge 2 commits into main

23 changes: 20 additions & 3 deletions lightllm/common/basemodel/basemodel.py
@@ -20,6 +20,8 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch, PrefillMicroBatch
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


logger = init_logger(__name__)

@@ -71,6 +73,10 @@ def __init__(self, kvargs):
self.tp_world_size_ = get_dp_world_size()
self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

# Speculative decoding
self.spec_algo = SpeculativeDecodeAlgorithm.from_string(kvargs.get("spec_algo", "NONE"))
self.spec_info = None

self._init_datatype()
self._init_config()
self._verify_must()
@@ -279,6 +285,8 @@ def _prefill(
):
infer_state = self.infer_state_class()
infer_state.is_prefill = True
infer_state.spec_algo = self.spec_algo
infer_state.spec_info = self.spec_info
infer_state.is_token_healing = self.is_token_healing
infer_state.return_all_prompt_logics = self.return_all_prompt_logics
infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
@@ -330,6 +338,8 @@ def _decode(
):
infer_state = self.infer_state_class()
infer_state.is_prefill = False
infer_state.spec_algo = self.spec_algo
infer_state.spec_info = self.spec_info
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
@@ -343,12 +353,13 @@
infer_state.req_manager = self.req_manager

infer_state.mem_index = mem_indexes
decode_len = self.spec_algo.decode_len()
infer_state.kv_buffer_shapedtype = (
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
(batch_size * decode_len, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
self.data_type,
)
infer_state.dist_group = dist_group_manager.get_default_group()
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index, decode_len)

infer_state.init_some_extra_state(self, input_ids)
if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
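With MTP, each decode step carries `decode_len` tokens per request, so both the KV buffer and the `mem_indexes` handed to `copy_kv_index_to_req` scale by `decode_len`. A tiny numeric illustration with placeholder head sizes (not values taken from this PR):

```python
batch_size, decode_len = 8, 2                        # decode_len == spec_algo.decode_len()
tp_k_head_num, tp_v_head_num, head_dim = 1, 1, 576   # placeholder head config
kv_buffer_shape = (batch_size * decode_len, tp_k_head_num + tp_v_head_num, head_dim)
print(kv_buffer_shape)                               # (16, 2, 576) -> 16 mem_index entries expected
```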
@@ -498,6 +509,9 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

if self.spec_algo.is_mtp():
self.spec_info = input_embs

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)

@@ -519,7 +533,10 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
layer = self.layers_infer[i]
layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])


if self.spec_algo.is_mtp():
self.spec_info = input_embs

post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)

7 changes: 6 additions & 1 deletion lightllm/common/basemodel/cuda_graph.py
@@ -5,6 +5,8 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for the type hint on warmup(); a runtime import here risks a
    # circular import, since basemodel.py imports this module.
    from lightllm.common.basemodel.basemodel import TpPartBaseModel

logger = init_logger(__name__)

@@ -126,7 +128,7 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
return self._replay(input_ids, infer_state)

@torch.no_grad()
def warmup(self, model):
def warmup(self, model: "TpPartBaseModel"):
logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
for batch_size in range(self.max_batch_size, 0, -1):
# dummy prefill
@@ -160,6 +162,9 @@ def warmup(self, model):
torch.cuda.empty_cache()

# dummy decoding, capture the cudagraph
decode_len = model.spec_algo.decode_len()
predict_ids = predict_ids.repeat(decode_len)
b_start_loc = b_start_loc + torch.arange(0, batch_size*decode_len, decode_len, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
mem_indexes = model.mem_manager.alloc(len(predict_ids)).cuda()
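For intuition, a minimal CPU-side sketch of the two adjusted warmup lines above, with made-up values (the real warmup uses CUDA tensors): `predict_ids` is tiled once per speculative token, and request `i`'s start location is shifted forward by `i * decode_len`.

```python
import torch

batch_size, decode_len = 4, 2
predict_ids = torch.tensor([11, 12, 13, 14])
b_start_loc = torch.tensor([0, 8, 16, 24], dtype=torch.int32)

predict_ids = predict_ids.repeat(decode_len)  # tile: one copy per speculative token
b_start_loc = b_start_loc + torch.arange(0, batch_size * decode_len, decode_len, dtype=torch.int32)

print(predict_ids.tolist())  # [11, 12, 13, 14, 11, 12, 13, 14]
print(b_start_loc.tolist())  # [0, 10, 20, 30]
```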
5 changes: 5 additions & 0 deletions lightllm/common/basemodel/infer_struct.py
@@ -5,6 +5,7 @@
from typing import Tuple, Any
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm


class InferStateInfo:
@@ -53,6 +54,10 @@ def __init__(self):
self.position_ids: torch.Tensor = None
self.max_q_seq_len: int = None
self.max_kv_seq_len: int = None

# Speculative decoding
self.spec_algo = SpeculativeDecodeAlgorithm.NONE
self.spec_info = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
48 changes: 36 additions & 12 deletions lightllm/common/basemodel/triton_kernel/copy_kv_index_to_req.py
@@ -2,7 +2,7 @@

import triton
import triton.language as tl

import copy

@triton.jit
def _fwd_kernel_copy_kv_index_to_req(
@@ -19,16 +19,40 @@ def _fwd_kernel_copy_kv_index_to_req(


@torch.no_grad()
def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex):
seq_len = b_seq_len.shape[0]
assert b_seq_len.shape[0] == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
grid = (seq_len,)
def copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len, memindex, decode_len=1):
batch_size = b_seq_len.shape[0]
assert b_seq_len.shape[0] * decode_len == memindex.shape[0] and b_req_idx.shape[0] == b_seq_len.shape[0]
grid = (batch_size, )
num_warps = 1

_fwd_kernel_copy_kv_index_to_req[grid](
req_to_token_indexs, b_req_idx, b_seq_len, memindex,
req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
num_warps=num_warps,
num_stages=1,
)
_b_seq_len = copy.deepcopy(b_seq_len)
for i in range(decode_len):
_fwd_kernel_copy_kv_index_to_req[grid](
req_to_token_indexs, b_req_idx, _b_seq_len, memindex[batch_size * i: batch_size * (i + 1)],
req_to_token_indexs.stride(0), req_to_token_indexs.stride(1),
num_warps=num_warps,
num_stages=1,
)
_b_seq_len = _b_seq_len + 1
return


if __name__ == '__main__':
for decode_len in [1,2]:
max_request_num = 100
max_sequence_length = 1000
req_to_token_indexs = torch.zeros(
(max_request_num + 1, max_sequence_length), dtype=torch.int32, device="cuda"
)
bs = 8
b_req_idx = torch.randint(low=0, high=max_request_num-1, size=(bs,)).cuda()
b_seq_len = torch.randint(low=1, high=max_sequence_length, size=(bs,)).cuda()
memindex = torch.randint(low=0, high=10000, size=(bs*decode_len,)).cuda()
copy_kv_index_to_req(req_to_token_indexs, b_req_idx, b_seq_len,memindex,decode_len)

for i in range(bs):
for j in range(decode_len):
if req_to_token_indexs[b_req_idx[i]][b_seq_len[i]+j-1] != memindex[j*bs+i]:
print("ERROR")
exit(1)

print("PASS")
32 changes: 32 additions & 0 deletions lightllm/common/spec_info.py
@@ -0,0 +1,32 @@
from enum import IntEnum, auto


class SpeculativeDecodeAlgorithm(IntEnum):
NONE = auto()
MTP = auto()
MTP_MOUDLE = auto()

def is_none(self):
return self == SpeculativeDecodeAlgorithm.NONE

def is_mtp(self):
return self == SpeculativeDecodeAlgorithm.MTP

@staticmethod
def from_string(name: str):
name_map = {
"MTP": SpeculativeDecodeAlgorithm.MTP,
"MTP_MOUDLE": SpeculativeDecodeAlgorithm.MTP_MOUDLE,
"NONE": SpeculativeDecodeAlgorithm.NONE,
}
if name is not None:
name = name.upper()
return name_map[name]

def decode_len(self):
if self == SpeculativeDecodeAlgorithm.NONE:
return 1
if self == SpeculativeDecodeAlgorithm.MTP:
return 2
if self == SpeculativeDecodeAlgorithm.MTP_MOUDLE:
return 2
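A minimal usage sketch of the new enum, mirroring the `kvargs.get("spec_algo", "NONE")` lookup in `basemodel.py` (assumes this PR's `lightllm/common/spec_info.py` is importable):

```python
from lightllm.common.spec_info import SpeculativeDecodeAlgorithm

algo = SpeculativeDecodeAlgorithm.from_string("mtp")  # case-insensitive
assert algo.is_mtp() and not algo.is_none()
assert algo.decode_len() == 2  # each decode step handles 2 tokens per request

assert SpeculativeDecodeAlgorithm.from_string("NONE").decode_len() == 1
```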
15 changes: 13 additions & 2 deletions lightllm/models/deepseek2/infer_struct.py
@@ -3,7 +3,7 @@
import numpy as np
import torch.distributed as dist
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

from lightllm.common.spec_info import SpeculativeDecodeAlgorithm

class Deepseek2InferStateInfo(LlamaInferStateInfo):
def __init__(self):
@@ -18,4 +18,15 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
self.b1_kv_start_loc = self.b1_cu_kv_seq_len
self.max_value_in_b_seq_len = self.b_seq_len.max().item()
return

if not self.is_prefill and not self.spec_algo.is_none():
b_seq_len_numpy = self.b_seq_len.cpu().numpy()
position_ids = torch.from_numpy(
np.concatenate(
[np.arange(b_seq_len_numpy[i] - 2, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
axis=0,
)
).cuda()
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
return
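A small self-contained illustration of the MTP decode branch above, with made-up sequence lengths: each request contributes its last two positions, matching `decode_len == 2`.

```python
import numpy as np

b_seq_len_numpy = np.array([5, 9, 12])
position_ids = np.concatenate(
    [np.arange(b_seq_len_numpy[i] - 2, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
    axis=0,
)
print(position_ids)  # [ 3  4  7  8 10 11]
```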
44 changes: 44 additions & 0 deletions lightllm/models/deepseek_mtp/deepseek3_mtp_mem_manager.py
@@ -0,0 +1,44 @@
import torch
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt

logger = init_logger(__name__)


class Deepseek3MTPMemoryManager(Deepseek2MemoryManager):
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
# profile the max total token num if the size is None
self.profile_size(mem_fraction)

self.mem_state = torch.arange(
0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
)
self.mark_start = 0
self.mark_end = self.size

self.can_use_mem_size = self.size

rank_in_node = get_current_rank_in_node()
self.shared_can_use_token_num = SharedInt(
f"MTP_mem_manger_can_use_token_num_{rank_in_node}"
)

self.shared_can_use_token_num.set_value(self.can_use_mem_size)

self._init_buffers(
self.size,
dtype,
head_num,
head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
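A hypothetical construction sketch (illustrative only; the sizes below are placeholders rather than values used in this PR, and the manager is normally created inside a lightllm worker process where the rank and shared-memory helpers are available). The draft model gets its own KV pool and registers its own `MTP_mem_manger_can_use_token_num_*` shared counter.

```python
import torch
from lightllm.models.deepseek_mtp.deepseek3_mtp_mem_manager import Deepseek3MTPMemoryManager

mtp_mem_manager = Deepseek3MTPMemoryManager(
    size=None,            # None -> profile_size() derives the max total token num
    dtype=torch.bfloat16,
    head_num=1,           # placeholder
    head_dim=576,         # placeholder
    layer_num=1,          # placeholder: a single draft layer is assumed
    mem_fraction=0.9,
)
```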
