Skip to content

Commit 2cd0d9f

Browse files
Merge branch 'InternLM:main' into add_k4v2
2 parents ee6cdc7 + 73a1121 commit 2cd0d9f

5 files changed

Lines changed: 64 additions & 2 deletions

File tree

lmdeploy/pytorch/disagg/conn/engine_conn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def handle_zmq_recv(self, remote_engine_id: str):
8080
if isinstance(req, DistServeCacheFreeRequest):
8181
session_id = req.remote_session_id
8282
if session_id in self.engine.scheduler.sessions:
83-
self.engine.scheduler.end_session(session_id=session_id)
83+
self.engine.end_session(session_id=session_id)
8484
else:
8585
logger.error(f'invalid free, {remote_engine_id}, {session_id}')
8686
else:

lmdeploy/pytorch/engine/engine.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import asyncio
3+
import ctypes
34
import gc
45
import os
56
from dataclasses import dataclass
@@ -9,6 +10,7 @@
910
import torch
1011

1112
from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig
13+
from lmdeploy.pytorch import envs as _envs
1214
from lmdeploy.pytorch.disagg.config import EngineRole
1315
from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection
1416
from lmdeploy.pytorch.disagg.conn.protocol import (
@@ -188,6 +190,8 @@ def __init__(
188190
# infer sleeping from empty_init: empty_init still builds runtime
189191
# resources and has its own weight-update workflow.
190192
self._sleeping_tags = set()
193+
self._multimodal_session_trim_count = max(0, _envs.multimodal_session_trim_count)
194+
self._multimodal_session_end_count = 0
191195

192196
# create main thread
193197
self.req_manager.set_main_loop_func(self.async_loop)
@@ -318,6 +322,37 @@ def _on_stop_session(self, reqs: list[Request], **kwargs):
318322
if resp:
319323
self._response(req.resp, resp_type)
320324

325+
@staticmethod
326+
def _try_mem_trim():
327+
"""Try to trim memory."""
328+
try:
329+
gc.collect()
330+
ctypes.CDLL('libc.so.6').malloc_trim(0)
331+
except Exception as e:
332+
logger.debug(f'Memory trim failed: {e}')
333+
334+
@staticmethod
335+
def _has_multimodal_session(session) -> bool:
336+
"""Check whether session has multimodal history."""
337+
for seq in session.sequences.values():
338+
history_multimodals = getattr(seq, 'history_multimodals', None)
339+
if history_multimodals is not None and not history_multimodals.empty():
340+
return True
341+
return False
342+
343+
def _maybe_trim_multimodal_session(self, has_multimodal: bool):
344+
"""Trim host memory after enough multimodal sessions have ended."""
345+
trim_count = getattr(self, '_multimodal_session_trim_count', max(0, _envs.multimodal_session_trim_count))
346+
if not has_multimodal or trim_count <= 0:
347+
return
348+
349+
self._multimodal_session_end_count = getattr(self, '_multimodal_session_end_count', 0) + 1
350+
if self._multimodal_session_end_count < trim_count:
351+
return
352+
353+
self._multimodal_session_end_count = 0
354+
self._try_mem_trim()
355+
321356
def _on_end_session(self, reqs: list[Request], **kwargs):
322357
"""On end session callback."""
323358
for req in reqs:
@@ -598,7 +633,9 @@ def start_loop(self):
598633
def end_session(self, session_id: int):
599634
"""End session."""
600635
if session_id in self.scheduler.sessions:
636+
has_multimodal = self._has_multimodal_session(self.scheduler.sessions[session_id])
601637
self.scheduler.end_session(session_id)
638+
self._maybe_trim_multimodal_session(has_multimodal)
602639
return True
603640
return False
604641

lmdeploy/pytorch/engine/mp_engine/zmq_engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ def _mp_proc(
9797

9898
from .zmq_rpc import AsyncRPCServer
9999

100+
# try rename the process
101+
try:
102+
import ctypes
103+
ctypes.CDLL(None).prctl(15, b'ZMQMPEngine', 0, 0, 0)
104+
except Exception as e:
105+
logger.debug(f'Failed to rename MPEngine process: {e}')
106+
100107
logger.setLevel(log_level)
101108

102109
# create an async rpc server

lmdeploy/pytorch/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ def _patched_get_env(
157157
# model agent
158158
skip_warmup = env_to_bool('LMDEPLOY_SKIP_WARMUP', False)
159159

160+
# memory trim
161+
multimodal_session_trim_count = env_to_int('LMDEPLOY_MULTIMODAL_SESSION_TRIM_COUNT', 128)
162+
160163
# model format
161164
scale_fmt = os.getenv('LMDEPLOY_SCALE_FMT', None)
162165

lmdeploy/turbomind/models/internvl.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ def _cfg_get(cfg, name: str, default=None):
1515
return getattr(cfg, name, default)
1616

1717

18+
def map_interns1_hf_keys(name: str) -> str:
19+
"""Map Intern-S1 HF VLM checkpoint keys to the Qwen3 text loader layout."""
20+
language_model_prefix = 'model.language_model.'
21+
if name.startswith(language_model_prefix):
22+
suffix = name[len(language_model_prefix):]
23+
return f'language_model.model.{suffix}'
24+
if name.startswith('lm_head.'):
25+
return f'language_model.{name}'
26+
return name
27+
28+
1829
@INPUT_MODELS.register_module(name='internvl')
1930
class InternVLModel:
2031
"""Aggregate source model for InternVL checkpoints with any registered text
@@ -42,6 +53,10 @@ def __init__(self, cfg: PretrainedConfig, *, resolver):
4253

4354
text_model_cls = INPUT_MODELS.get(text_model_registered_name)
4455
self.text_model = text_model_cls(llm_cfg, resolver=resolver)
56+
archs = _cfg_get(cfg, 'architectures') or []
57+
self._checkpoint_mappings = []
58+
if archs and archs[0] == 'InternS1ForConditionalGeneration':
59+
self._checkpoint_mappings.append(map_interns1_hf_keys)
4560
self.vision_model = None
4661

4762
def bind_runtime(self, *, ctx, root_handles,
@@ -60,7 +75,7 @@ def _vocab_size(self):
6075

6176
@property
6277
def _loader_mappings(self):
63-
return list(getattr(type(self.text_model), '_loader_mappings', []))
78+
return self._checkpoint_mappings + list(getattr(type(self.text_model), '_loader_mappings', []))
6479

6580
def model(self, pfx):
6681
self.text_model.model(pfx + 'language_model')

0 commit comments

Comments
 (0)