Add metrics endpoint #1423

Open · wants to merge 10 commits into base: dev
2 changes: 2 additions & 0 deletions lmdeploy/cli/serve.py
@@ -125,6 +125,7 @@ def add_parser_api_server():
'engine’s tasks once the maximum number of concurrent requests is '
'reached, regardless of any additional requests sent by clients '
'concurrently during that time. Default to None.')
parser.add_argument('--metrics', action='store_true', default=False, help='Whether to log stats to Prometheus')
# common args
ArgumentHelper.backend(parser)
ArgumentHelper.log_level(parser)
@@ -337,6 +338,7 @@ def api_server(args):
proxy_url=args.proxy_url,
max_log_len=args.max_log_len,
disable_fastapi_docs=args.disable_fastapi_docs,
metrics=args.metrics,
max_concurrent_requests=args.max_concurrent_requests,
reasoning_parser=args.reasoning_parser,
tool_call_parser=args.tool_call_parser)
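
For reference, a minimal sketch of how the new option would be enabled, both on the CLI and when constructing the engine directly. The `lmdeploy serve api_server <model> --metrics` invocation and the positional model-path argument reflect lmdeploy's usual interface and are not shown in this hunk; only the `metrics` keyword itself comes from this PR.

```python
# Sketch only: enable metrics programmatically. The model path is a
# placeholder; `metrics=True` mirrors `lmdeploy serve api_server <model> --metrics`.
from lmdeploy.serve.async_engine import AsyncEngine

engine = AsyncEngine(
    'path/to/model',      # placeholder model path
    backend='turbomind',  # backend is an existing AsyncEngine parameter
    metrics=True,         # the new flag added by this PR
)
```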
147 changes: 87 additions & 60 deletions lmdeploy/serve/async_engine.py
@@ -22,6 +22,7 @@
from lmdeploy.logger import RequestLogger
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, ResponseType, TurbomindEngineConfig
from lmdeploy.model import MODELS, BaseChatTemplate, ChatTemplateConfig, best_match_model
from lmdeploy.serve.metrics import IterTimer, Metrics
from lmdeploy.serve.utils import LogitsMixin
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_hf_gen_cfg, get_logger
@@ -89,6 +90,11 @@ def _append_response(dst: Response, src: Response):
return dst


def is_error(status: ResponseType):
"""Whether is an error response."""
return status not in [ResponseType.SUCCESS, ResponseType.FINISH]


class Session:
"""Session for AsyncEngine.chat.

@@ -246,6 +252,7 @@ class AsyncEngine(LogitsMixin):
Default to None.
max_log_len (int): Max number of prompt characters or prompt tokens
being printed in log. Default: Unlimited
metrics (bool): Whether to log the stats to Prometheus.
"""

def __init__(self,
@@ -255,6 +262,7 @@ def __init__(self,
backend_config: Optional[Union[TurbomindEngineConfig, PytorchEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
max_log_len: int = None,
metrics: bool = False,
**kwargs) -> None:
logger.info(f'input backend={backend}, backend_config={backend_config}')
logger.info(f'input chat_template_config={chat_template_config}')
@@ -297,6 +305,8 @@ def __init__(self,
self.request_logger = RequestLogger(max_log_len)
self.internal_thread = _EventLoopThread(daemon=True)
self.limiter: asyncio.Semaphore = None
self.metrics = Metrics(metrics)
self.metrics.info(self.backend_config)

def close(self):
self.internal_thread.close()
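
`Metrics` (and `IterTimer`, used further down) are imported from the new `lmdeploy/serve/metrics.py`, which is not included in the hunks shown here. Below is a minimal sketch of the interface implied by the call sites in this file, assuming `prometheus_client` underneath; only the method names and the `applied` attribute come from the diff, while the metric names and internals are assumptions.

```python
# Sketch only: mirrors the Metrics methods AsyncEngine calls
# (insert_frame, update_preprocess, update_queue_waiting, update_FTL,
# failure_frame, last_token_frame, info, applied). Not the PR's implementation.
import time

from prometheus_client import Counter, Histogram, Info


class Metrics:

    def __init__(self, applied: bool = False):
        self.applied = applied
        self.backend_info = Info('lmdeploy_backend', 'backend configuration')
        self.failures = Counter('lmdeploy_request_failure_total', 'failed requests')
        self.preprocess_time = Histogram('lmdeploy_preprocess_seconds', 'prompt preprocessing time')
        self.queue_time = Histogram('lmdeploy_queue_wait_seconds', 'wait time for a free model instance')
        self.first_token_latency = Histogram('lmdeploy_first_token_seconds', 'time from arrival to first token')
        self.generation_time = Histogram('lmdeploy_generation_seconds', 'total generation time per request')

    def info(self, backend_config) -> None:
        # expose the backend config as a static info metric
        if self.applied:
            self.backend_info.info({'backend_config': str(backend_config)})

    def insert_frame(self) -> float:
        # a "frame" is just a timestamp that later update_* calls measure against
        return time.perf_counter()

    def update_preprocess(self, frame: float) -> None:
        if self.applied:
            self.preprocess_time.observe(time.perf_counter() - frame)

    def update_queue_waiting(self, frame: float) -> None:
        if self.applied:
            self.queue_time.observe(time.perf_counter() - frame)

    def update_FTL(self, frame: float) -> None:
        if self.applied:
            self.first_token_latency.observe(time.perf_counter() - frame)

    def failure_frame(self) -> None:
        if self.applied:
            self.failures.inc()

    def last_token_frame(self, gen) -> None:
        # `gen` may be the IterTimer wrapper created in generate()
        if self.applied and hasattr(gen, 'get_duration'):
            self.generation_time.observe(gen.get_duration())
```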
@@ -585,6 +595,10 @@ async def model_inst(self, session_id: int):
inst._active.set()
free_insts.put_nowait(inst)

async def handle_exception(self, session_id: int):
self.metrics.failure_frame()
await self.stop_session(session_id)

@asynccontextmanager
async def safe_run(self, inst, session_id, **kwargs):
generator = inst.async_stream_infer(session_id, **kwargs)
@@ -593,7 +607,7 @@ async def safe_run(self, inst, session_id, **kwargs):
except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa
logger.error(f'[safe_run] exception caught: {type(e).__name__} {e}')
# TODO: remove session_id from async cancel
await inst.async_cancel(session_id)
await self.handle_exception(session_id)
finally:
await generator.aclose()

@@ -629,61 +643,70 @@ async def generate(
"""
if (messages is not None) ^ (input_ids is None):
raise ValueError('You must specify exactly one of messages or input_ids')
if session_id not in self.id2step:
self.id2step[session_id] = 0
if step != 0:
self.id2step[session_id] = step
if gen_config is None:
gen_config = GenerationConfig()
else:
gen_config = deepcopy(gen_config)
gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)
if not gen_config.do_sample:
logger.warning(f'GenerationConfig: {gen_config}')
logger.warning('Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
' decoding is needed')
# greedy decode
gen_config.top_k = 1
# avoid unnecessary process
gen_config.temperature = 1.0
gen_config.repetition_penalty = 1.0
# set random if it is not set and sequence_start is True
elif gen_config.random_seed is None and sequence_start:
gen_config.random_seed = random.getrandbits(64)
if gen_config.n > 1:
logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. "
f'Fallback to 1')
gen_config.n = 1
if messages:
prompt = messages
self.request_logger.log_prompt(session_id=session_id, prompt=prompt)
prompt_input = await self._get_prompt_input(prompt,
do_preprocess,
sequence_start,
adapter_name,
tools=tools)
prompt = prompt_input['prompt']
input_ids = prompt_input['input_ids']
self.request_logger.log_inputs(session_id=session_id,
prompt=prompt,
prompt_token_ids=input_ids,
gen_config=gen_config,
adapter_name=adapter_name)
logger.info(f'session={session_id}, '
f'history_tokens={self.id2step[session_id]}, '
f'input_tokens={len(input_ids)}, '
f'max_new_tokens={gen_config.max_new_tokens}, '
f'seq_start={sequence_start}, seq_end={sequence_end}, '
f'step={step}, prep={do_preprocess}')
else:
# TODO(lvhan) VLM doesn't support input_ids as an argument.
# Figure out a graceful way to handle the invalid input
prompt_input = dict(input_ids=input_ids)

async def get_inputs_genconfig(gen_config):
Collaborator: Do we really need to wrap it into a function?

if session_id not in self.id2step:
self.id2step[session_id] = 0
if step != 0:
self.id2step[session_id] = step
if gen_config is None:
gen_config = GenerationConfig()
else:
gen_config = deepcopy(gen_config)
gen_config.convert_stop_bad_words_to_ids(self.tokenizer)
if gen_config.stop_token_ids is None:
gen_config.stop_token_ids = self.stop_words
gen_config.update_from_hf_gen_cfg(self.hf_gen_cfg, self.tokenizer.eos_token_id)
if not gen_config.do_sample:
logger.warning(f'GenerationConfig: {gen_config}')
logger.warning('Since v0.6.0, lmdeploy add `do_sample` in '
'GenerationConfig. It defaults to False, meaning greedy '
'decoding. Please set `do_sample=True` if sampling '
' decoding is needed')
# greedy decode
gen_config.top_k = 1
# avoid unnecessary process
gen_config.temperature = 1.0
gen_config.repetition_penalty = 1.0
# set random if it is not set and sequence_start is True
elif gen_config.random_seed is None and sequence_start:
gen_config.random_seed = random.getrandbits(64)
if gen_config.n > 1:
logger.error(f"n({gen_config.n}) > 1 hasn't been supported yet. "
f'Fallback to 1')
gen_config.n = 1
if messages:
prompt = messages
self.request_logger.log_prompt(session_id=session_id, prompt=prompt)
prompt_input = await self._get_prompt_input(prompt,
do_preprocess,
sequence_start,
adapter_name,
tools=tools)
prompt = prompt_input['prompt']
input_ids = prompt_input['input_ids']
self.request_logger.log_inputs(session_id=session_id,
prompt=prompt,
prompt_token_ids=input_ids,
gen_config=gen_config,
adapter_name=adapter_name)
logger.info(f'session={session_id}, '
f'history_tokens={self.id2step[session_id]}, '
f'input_tokens={len(input_ids)}, '
f'max_new_tokens={gen_config.max_new_tokens}, '
f'seq_start={sequence_start}, seq_end={sequence_end}, '
f'step={step}, prep={do_preprocess}')
else:
# TODO(lvhan) VLM doesn't support input_ids as an argument.
# Figure out a graceful way to handle the invalid input
prompt_input = dict(input_ids=input_ids)
return prompt_input, gen_config

arrival_frame = self.metrics.insert_frame()
prompt_input, gen_config = await get_inputs_genconfig(gen_config)
input_ids = prompt_input['input_ids']
self.metrics.update_preprocess(arrival_frame)

if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(128, self.session_len - self.id2step[session_id] - len(input_ids))
@@ -697,15 +720,14 @@ async def generate(
await self.end_session(session_id)
return

def is_error(status):
return status not in [ResponseType.SUCCESS, ResponseType.FINISH]

# used to skip / rewind stop words in interactive mode
stop_ids = []
if skip_stop_tokens and not gen_config.ignore_eos:
stop_ids = gen_config.stop_token_ids or []

start_frame = self.metrics.insert_frame()
async with self.model_inst(session_id) as inst:
self.metrics.update_queue_waiting(start_frame)
token_ids = input_ids.copy()
history_len = self.id2step[session_id]
input_len = len(input_ids)
@@ -723,9 +745,12 @@ def is_error(status):
sequence_start=sequence_start,
sequence_end=sequence_end,
step=history_len) as gen:
if self.metrics.applied is True:
gen = IterTimer(gen)
prev_len = 0
hit_stop_token = 0
async for outputs in gen:
is_first_token = prev_len == 0
# decode res
if is_error(outputs.status):
break
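
`IterTimer` (also from the new metrics module, not shown in this diff) wraps the streaming generator only when metrics are enabled. Here is a sketch of what such a wrapper could look like, assuming it accumulates the time spent awaiting each output so that `last_token_frame(gen)` can read it back; apart from the class name, everything below is an assumption.

```python
# Sketch only: an async-iterator wrapper matching the `gen = IterTimer(gen)`
# and `self.metrics.last_token_frame(gen)` call sites in generate().
import time


class IterTimer:

    def __init__(self, iterable):
        self._iterable = iterable
        self._duration = 0.0

    def __aiter__(self):
        return self

    async def __anext__(self):
        # time only the await on the underlying generator, not the caller's work
        start = time.perf_counter()
        item = await self._iterable.__anext__()
        self._duration += time.perf_counter() - start
        return item

    def get_duration(self) -> float:
        return self._duration
```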
@@ -771,7 +796,8 @@ def is_error(status):
out.logits = outputs.logits
if hit_stop_token:
out.logits = out.logits[:-hit_stop_token]

if is_first_token:
self.metrics.update_FTL(arrival_frame)
yield out
# end of generator loop

@@ -807,6 +833,7 @@ def is_error(status):
# rewind the step to the token before the stop token
output_len = gen_len
self.id2step[session_id] += input_len + output_len
self.metrics.last_token_frame(gen)

def _run(self, fn=None, coro=None, loop=None):
assert (fn or coro) and not (fn and coro)
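
Since the PR adds a metrics endpoint to the API server, a hypothetical smoke test is sketched below. The `/metrics` path follows the Prometheus convention and the port is lmdeploy's usual default; neither is confirmed by the hunks shown above.

```python
# Hypothetical check: scrape metrics from a server started with
# `lmdeploy serve api_server <model> --metrics`. Path and port are assumptions.
import requests

resp = requests.get('http://127.0.0.1:23333/metrics', timeout=5)
resp.raise_for_status()
# print only sample lines, skipping the `# HELP` / `# TYPE` comments
for line in resp.text.splitlines():
    if line and not line.startswith('#'):
        print(line)
```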