diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 09f649943..b4efa85b2 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -338,8 +338,8 @@ def update_finish_status(self, eos_ids):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif (
             self.cur_output_len > 0
-            and self.get_last_gen_token() in eos_ids
             and self.sampling_param.shm_param.ignore_eos is False
+            and self.get_last_gen_token() in eos_ids
         ):
             self.finish_status.set_status(FinishStatus.FINISHED_STOP)
         elif self.cur_output_len >= self.sampling_param.shm_param.max_new_tokens:
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index e4db0aca7..69b96db26 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -4,6 +4,7 @@
 import rpyc
 import torch
 import socket
+import time
 from datetime import timedelta
 from typing import Dict, List, Tuple, Callable, Optional
 from transformers.configuration_utils import PretrainedConfig
@@ -250,6 +251,81 @@ def _post_handle(
         is_chuncked_mode: bool,
         do_filter_finished_reqs: bool,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+    ) -> List[int]:
+        """
+        extra_post_req_handle_func provides extra post-processing for a request once its output
+        token is determined; constrained-output modes use it to update the request's internal
+        state machine and to add extra stop conditions.
+        """
+        if not hasattr(self, "_post_handle_impl"):
+            try:
+                finished_req_ids = self._fast_post_handle(
+                    run_reqs,
+                    next_token_ids,
+                    next_token_logprobs,
+                    is_chuncked_mode,
+                    do_filter_finished_reqs,
+                    extra_post_req_handle_func,
+                )
+                self._post_handle_impl = self._fast_post_handle
+                self.logger.info("use _fast_post_handle")
+                return finished_req_ids
+            except Exception:
+                finished_req_ids = self._python_post_handle(
+                    run_reqs,
+                    next_token_ids,
+                    next_token_logprobs,
+                    is_chuncked_mode,
+                    do_filter_finished_reqs,
+                    extra_post_req_handle_func,
+                )
+                self.logger.info("use _python_post_handle")
+                self._post_handle_impl = self._python_post_handle
+                return finished_req_ids
+        else:
+            return self._post_handle_impl(
+                run_reqs,
+                next_token_ids,
+                next_token_logprobs,
+                is_chuncked_mode,
+                do_filter_finished_reqs,
+                extra_post_req_handle_func,
+            )
+
+    def _fast_post_handle(
+        self,
+        run_reqs: List[InferReq],
+        next_token_ids,
+        next_token_logprobs,
+        is_chuncked_mode: bool,
+        do_filter_finished_reqs: bool,
+        extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+    ):
+        from . import cython_fast_impl
+
+        start = time.time()
+        finished_req_ids = cython_fast_impl.fast_post_handle(
+            self,
+            run_reqs,
+            next_token_ids,
+            next_token_logprobs,
+            is_chuncked_mode,
+            do_filter_finished_reqs,
+            extra_post_req_handle_func,
+        )
+        cost_time = time.time() - start
+        if self.is_master_in_dp and cost_time > 0.001:
+            self.logger.info(f"post handle cost time {cost_time} s, batch_size: {len(run_reqs)}")
+        return finished_req_ids
+
+    # Some reusable generic helper functions.
+    def _python_post_handle(
+        self,
+        run_reqs: List[InferReq],
+        next_token_ids,
+        next_token_logprobs,
+        is_chuncked_mode: bool,
+        do_filter_finished_reqs: bool,
+        extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
     ) -> List[int]:
         """
         extra_post_req_handle_func provides extra post-processing for a request once its output
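The `_post_handle` change above resolves its implementation exactly once, on the first call: it tries the compiled Cython fast path, and if that raises (for example because the extension was never built, or the inputs do not match the typed memoryviews), it permanently falls back to the pure-Python path and caches the choice on the backend object. A minimal standalone sketch of this pick-once dispatch, with hypothetical `_fast`/`_slow` placeholders standing in for the real handlers:

    # Standalone sketch of the pick-once dispatch used by _post_handle.
    # _fast and _slow are illustrative placeholders, not names from the diff.
    class PickOnceDispatcher:
        def handle(self, *args):
            if not hasattr(self, "_impl"):
                try:
                    result = self._fast(*args)  # may raise if the extension is missing
                    self._impl = self._fast     # cache only after one successful call
                    return result
                except Exception:
                    result = self._slow(*args)
                    self._impl = self._slow
                    return result
            return self._impl(*args)

        def _fast(self, *args):
            raise ImportError("compiled extension unavailable")  # simulate a missing build

        def _slow(self, *args):
            return list(args)  # pure-Python fallback

Caching only after one successful call means a single bad batch at startup pins the slow path for the process lifetime, which is the same trade-off the diff makes in exchange for a branch-free steady state.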
diff --git a/lightllm/server/router/model_infer/mode_backend/cython_fast_impl.pyx b/lightllm/server/router/model_infer/mode_backend/cython_fast_impl.pyx
new file mode 100644
index 000000000..031c57ad7
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/cython_fast_impl.pyx
@@ -0,0 +1,117 @@
+import cython
+from typing import List, Optional, Callable
+from ..infer_batch import InferReq, FinishStatus
+from .base_backend import ModeBackend
+
+
+def __update_finish_status(self: InferReq, gen_new_token_id: int, eos_ids: List[int]):
+    # stop way 1
+    for stop_token_ids in self.stop_sequences:
+        stop_len = len(stop_token_ids)
+        output_len = self.cur_output_len
+        if stop_len > 0 and output_len >= stop_len:
+            total_len = self.shm_req.input_len + output_len
+            tail_token_ids = self.shm_req.shm_prompt_ids.arr[(total_len - stop_len) : total_len]
+            if all(tail_token_ids[i] == stop_token_ids[i] for i in range(stop_len)):
+                self.finish_status.set_status(FinishStatus.FINISHED_STOP)
+                return
+
+    # stop way 2
+    shm_param = self.sampling_param.shm_param
+    if (self.cur_output_len > 0
+        and shm_param.ignore_eos is False
+        and gen_new_token_id in eos_ids
+    ):
+        self.finish_status.set_status(FinishStatus.FINISHED_STOP)
+        return
+
+    # stop way 3
+    if self.cur_output_len >= shm_param.max_new_tokens:
+        self.finish_status.set_status(FinishStatus.FINISHED_LENGTH)
+        return
+
+
+# @cython.boundscheck(False)
+# @cython.wraparound(False)
+def fast_post_handle(
+    self: ModeBackend,
+    run_reqs: List[InferReq],
+    next_token_ids_,
+    next_token_logprobs_,
+    is_chuncked_mode: bool,
+    do_filter_finished_reqs: bool,
+    extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+) -> List[int]:
+    """
+    extra_post_req_handle_func provides extra post-processing for a request once its output
+    token is determined; constrained-output modes use it to update the request's internal
+    state machine and to add extra stop conditions.
+    """
+    from lightllm.server.router.model_infer.infer_batch import g_infer_context
+
+    finished_req_ids = [0 for _ in range(len(run_reqs))]
+    finished_req_ids.clear()
+    next_token_ids: cython.longlong[:] = cython.declare(cython.longlong[:], next_token_ids_)
+    next_token_logprobs: cython.float[:] = cython.declare(cython.float[:], next_token_logprobs_)
+    is_master_in_dp: cython.bint = self.is_master_in_dp
+    is_chuncked_mode: cython.bint = is_chuncked_mode
+
+    i: cython.Py_ssize_t
+    for i in range(len(run_reqs)):
+        req_obj: InferReq = run_reqs[i]
+        shm_req = req_obj.shm_req
+        next_token_id: cython.int = next_token_ids[i]
+        next_token_logprob: cython.float = next_token_logprobs[i]
+        cur_total_len = shm_req.input_len + req_obj.cur_output_len
+
+        if is_chuncked_mode:
+            new_kv_len = min(cur_total_len, req_obj.cur_kv_len + shm_req.chunked_prefill_size)
+        else:
+            new_kv_len = cur_total_len
+
+        req_obj.cur_kv_len = new_kv_len
+        if is_master_in_dp:
+            shm_req.shm_cur_kv_len = req_obj.cur_kv_len
+
+        # Check up front whether the request has been aborted; if so, put it
+        # straight into the finished queue.
+        if shm_req.router_aborted:
+            finished_req_ids.append(shm_req.request_id)
+            continue
+
+        # Skip requests that have not yet reached the token-output stage.
+        if req_obj.cur_kv_len < cur_total_len:
+            continue
+
+        # Write the newly generated token into the shared management object.
+        gen_token_index = cur_total_len
+        shm_req.shm_prompt_ids.arr[gen_token_index] = next_token_id
+        shm_req.shm_logprobs.arr[gen_token_index] = next_token_logprob
+        req_obj.cur_output_len += 1
+
+        req_obj.out_token_id_count[next_token_id] += 1
+        __update_finish_status(req_obj, next_token_id, self.eos_id)
+
+        if extra_post_req_handle_func is not None:
+            extra_post_req_handle_func(req_obj, next_token_id, next_token_logprob)
+
+        # Check whether the stop conditions for generation are met.
+        is_finished = req_obj.finish_status.is_finished()
+        if is_finished or shm_req.router_aborted:
+            finished_req_ids.append(shm_req.request_id)
+
+        if is_master_in_dp:
+            # shm_cur_kv_len and shm_cur_output_len are read by the router scheduler
+            # process; finish_token_index, finish_status and candetoken_out_len are
+            # read by the detokenization process. Mind the write order of these
+            # fields to avoid async coordination problems.
+            shm_req.shm_cur_output_len = req_obj.cur_output_len
+
+            if is_finished:
+                shm_req.finish_token_index = gen_token_index
+                shm_req.finish_status = req_obj.finish_status
+
+            shm_req.candetoken_out_len = req_obj.cur_output_len
+
+    if do_filter_finished_reqs:
+        g_infer_context.filter(finished_req_ids)
+
+    return finished_req_ids
\ No newline at end of file
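The typed memoryviews above (`cython.longlong[:]`, `cython.float[:]`) only bind to contiguous int64/float32 buffers, which is presumably why `sample()` in the next hunk now returns the int64 view of the token ids. A hedged sketch of what a caller has to guarantee before invoking `fast_post_handle`; the tensor values and variable names are illustrative, not from the diff:

    # Illustrative only: fast_post_handle's memoryviews require contiguous
    # int64 ids and float32 logprobs, e.g. CPU arrays of exactly those dtypes.
    import numpy as np
    import torch

    batch_next_token_ids = torch.tensor([5, 17, 2], dtype=torch.int64)  # int64, per the sample() fix
    batch_next_token_probs = torch.rand(3, dtype=torch.float32)

    next_token_ids = batch_next_token_ids.detach().cpu().numpy()
    next_token_logprobs = batch_next_token_probs.detach().cpu().numpy()
    assert next_token_ids.dtype == np.int64
    assert next_token_logprobs.dtype == np.float32

If the dtypes do not match, the memoryview assignment raises, which is one of the failure modes the `_post_handle` dispatch catches before settling on the pure-Python path.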
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
index a3c58b355..6c1ed7511 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py
@@ -65,7 +65,7 @@ def sample(logits, reqs, eos_id: List[int] = [2]):
         int64_batch_next_token_ids = torch.empty_like(batch_next_token_ids, dtype=torch.int64)
         int64_batch_next_token_ids[:] = batch_next_token_ids
         batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1))
-        return batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
+        return int64_batch_next_token_ids.view(-1), batch_next_token_probs.view(-1)
     else:
         assert False, "dead path"
diff --git a/requirements.txt b/requirements.txt
index 47688ddad..439a45f5f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -88,3 +88,4 @@ flashinfer-python==0.2.4
 sgl-kernel
 httpx==0.28.1
 librosa==0.11.0
+Cython
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 1fcaa7ac0..810a9fd6a 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
 from setuptools import setup, find_packages
+from Cython.Build import cythonize
 
 package_data = {"lightllm": ["common/all_kernel_configs/*/*.json"]}
 setup(
@@ -28,4 +29,10 @@
         "triton",
     ],
     package_data=package_data,
+    ext_modules=cythonize(
+        [
+            "lightllm/server/router/model_infer/mode_backend/cython_fast_impl.pyx",
+        ]
+    ),
+    zip_safe=False,
 )
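With `ext_modules=cythonize(...)` wired into setup.py and Cython added to requirements.txt, the extension is compiled by the normal install path; if the compiled module is absent at runtime, `_post_handle` logs "use _python_post_handle" and stays on the pure-Python path. Assuming a standard setuptools/Cython toolchain with a C compiler available, a development build looks like:

    pip install -e .                      # compiles the .pyx during install
    # or rebuild in place after editing the .pyx:
    python setup.py build_ext --inplace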