|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved. |
2 | 2 | import asyncio |
| 3 | +import ctypes |
3 | 4 | import gc |
4 | 5 | import os |
5 | 6 | from dataclasses import dataclass |
|
9 | 10 | import torch |
10 | 11 |
|
11 | 12 | from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig |
| 13 | +from lmdeploy.pytorch import envs as _envs |
12 | 14 | from lmdeploy.pytorch.disagg.config import EngineRole |
13 | 15 | from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection |
14 | 16 | from lmdeploy.pytorch.disagg.conn.protocol import ( |
@@ -188,6 +190,8 @@ def __init__( |
188 | 190 | # infer sleeping from empty_init: empty_init still builds runtime |
189 | 191 | # resources and has its own weight-update workflow. |
190 | 192 | self._sleeping_tags = set() |
| 193 | + self._multimodal_session_trim_count = max(0, _envs.multimodal_session_trim_count) |
| 194 | + self._multimodal_session_end_count = 0 |
191 | 195 |
|
192 | 196 | # create main thread |
193 | 197 | self.req_manager.set_main_loop_func(self.async_loop) |
@@ -318,6 +322,37 @@ def _on_stop_session(self, reqs: list[Request], **kwargs): |
318 | 322 | if resp: |
319 | 323 | self._response(req.resp, resp_type) |
320 | 324 |
|
| 325 | + @staticmethod |
| 326 | + def _try_mem_trim(): |
| 327 | + """Try to trim memory.""" |
| 328 | + try: |
| 329 | + gc.collect() |
| 330 | + ctypes.CDLL('libc.so.6').malloc_trim(0) |
| 331 | + except Exception as e: |
| 332 | + logger.debug(f'Memory trim failed: {e}') |
| 333 | + |
| 334 | + @staticmethod |
| 335 | + def _has_multimodal_session(session) -> bool: |
| 336 | + """Check whether session has multimodal history.""" |
| 337 | + for seq in session.sequences.values(): |
| 338 | + history_multimodals = getattr(seq, 'history_multimodals', None) |
| 339 | + if history_multimodals is not None and not history_multimodals.empty(): |
| 340 | + return True |
| 341 | + return False |
| 342 | + |
| 343 | + def _maybe_trim_multimodal_session(self, has_multimodal: bool): |
| 344 | + """Trim host memory after enough multimodal sessions have ended.""" |
| 345 | + trim_count = getattr(self, '_multimodal_session_trim_count', max(0, _envs.multimodal_session_trim_count)) |
| 346 | + if not has_multimodal or trim_count <= 0: |
| 347 | + return |
| 348 | + |
| 349 | + self._multimodal_session_end_count = getattr(self, '_multimodal_session_end_count', 0) + 1 |
| 350 | + if self._multimodal_session_end_count < trim_count: |
| 351 | + return |
| 352 | + |
| 353 | + self._multimodal_session_end_count = 0 |
| 354 | + self._try_mem_trim() |
| 355 | + |
321 | 356 | def _on_end_session(self, reqs: list[Request], **kwargs): |
322 | 357 | """On end session callback.""" |
323 | 358 | for req in reqs: |
@@ -598,7 +633,9 @@ def start_loop(self): |
598 | 633 | def end_session(self, session_id: int): |
599 | 634 | """End session.""" |
600 | 635 | if session_id in self.scheduler.sessions: |
| 636 | + has_multimodal = self._has_multimodal_session(self.scheduler.sessions[session_id]) |
601 | 637 | self.scheduler.end_session(session_id) |
| 638 | + self._maybe_trim_multimodal_session(has_multimodal) |
602 | 639 | return True |
603 | 640 | return False |
604 | 641 |
|
|
0 commit comments