Skip to content

Commit d8da3b2

Browse files
authored
Merge pull request #144 from jaylfc/fix/worker-lifecycle
Fix worker lifecycle bugs causing NPU memory leaks
2 parents cc316e1 + 946ed13 commit d8da3b2

2 files changed

Lines changed: 251 additions & 33 deletions

File tree

src/rkllama/api/worker.py

Lines changed: 238 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,77 @@
44
import time
55
from datetime import datetime, timedelta
66
import os
7+
import signal
8+
import sys
79
import threading
810
import random
11+
import ctypes
12+
import atexit
913
from multiprocessing import Process, Pipe, Value
14+
15+
16+
# --- Orphan-safe worker helpers -------------------------------------------
17+
#
18+
# Worker processes are multiprocessing.Process children. If the parent dies
19+
# uncleanly (SIGKILL, crash, power loss), these children survive as orphans
20+
# owned by init (PID 1) and keep holding NPU memory until manually killed.
21+
# Three mitigations run together to prevent this:
22+
#
23+
# 1. Each worker calls prctl(PR_SET_PDEATHSIG, SIGTERM) on Linux so the
24+
# kernel sends it SIGTERM the moment its parent dies.
25+
# 2. The parent installs signal handlers + atexit to tear down all workers
26+
# on clean shutdown.
27+
# 3. On startup, the parent scans for rkllama_server processes whose PPID
28+
# is 1 (orphaned from a previous run) and kills them.
29+
30+
PR_SET_PDEATHSIG = 1
31+
32+
33+
def _set_parent_death_signal():
    """On Linux, ask the kernel to SIGTERM us if our parent dies.

    Safe no-op on other platforms. Never raises: any failure (libc cannot
    be loaded, or prctl itself reports an error) is logged as a warning and
    the caller continues without the parent-death signal.
    """
    if sys.platform != "linux":
        return
    try:
        libc = ctypes.CDLL("libc.so.6", use_errno=True)
        # ctypes does not raise when a C function fails. prctl() signals
        # failure through its return value and errno, so it has to be
        # checked explicitly or the failure would go unnoticed.
        result = libc.prctl(PR_SET_PDEATHSIG, int(signal.SIGTERM), 0, 0, 0)
        if result != 0:
            errno_value = ctypes.get_errno()
            raise OSError(errno_value, os.strerror(errno_value))
    except Exception as exc:
        logger.warning("Could not set parent-death signal: %s", exc)
45+
46+
47+
def _kill_orphaned_workers():
    """Kill rkllama_server processes that were orphaned by an earlier run.

    A process is treated as an orphan when its parent PID is 1 and at least
    one element of its command line contains ``rkllama_server``. The current
    process is never touched. Processes that vanish or cannot be accessed
    while being inspected are skipped. Intended to run once at parent
    startup.
    """
    own_pid = os.getpid()
    reaped = 0

    for candidate in psutil.process_iter(["pid", "ppid", "cmdline"]):
        try:
            details = candidate.info

            # Never target ourselves, and only consider children of PID 1.
            is_orphan = details["pid"] != own_pid and details["ppid"] == 1
            if not is_orphan:
                continue

            # cmdline may be None or contain empty entries; normalise both.
            argv = details.get("cmdline") or []
            if all("rkllama_server" not in (arg or "") for arg in argv):
                continue

            logger.warning(
                "Killing orphaned rkllama worker pid=%s (cmd=%s)",
                details["pid"], " ".join(argv[:3]),
            )
            candidate.kill()
            reaped += 1
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    if reaped:
        logger.info("Reaped %d orphaned worker(s) from previous run", reaped)
77+
1078
from datetime import datetime, timedelta
1179
from.model_utils import get_model_size, get_encoder_model_path, get_property_modelfile, is_rkllm_model, get_rknn_onnx_files_from_model
1280
from .classes import *
@@ -147,7 +215,9 @@ def run_translation_generator(model_runtime, model_input):
147215

148216
# RKLLM Worker
149217
def run_rkllm_worker(name, worker_pipe, abort_flag, model_path, model_dir, options=None, lora_model_path = None, prompt_cache_path = None, base_domain_id = 0):
150-
218+
# Die with the parent to prevent NPU memory leaks on parent crash
219+
_set_parent_death_signal()
220+
151221
# Initialize individual callback for each worker to prevent error from RKLLM
152222
from .callback import callback_impl, global_text, last_embeddings, global_metrics
153223
from .rkllm import RKLLM
@@ -285,8 +355,10 @@ def run_rkllm_worker(name, worker_pipe, abort_flag, model_path, model_dir, optio
285355

286356

287357
# RKNN Worker
288-
def run_rknn_worker(name, worker_pipe, model_dir, options=None):
289-
358+
def run_rknn_worker(name, worker_pipe, model_dir, options=None):
359+
# Die with the parent to prevent NPU memory leaks on parent crash
360+
_set_parent_death_signal()
361+
290362
from rknnlite.api.rknn_lite import RKNNLite
291363
import onnxruntime
292364

@@ -422,10 +494,35 @@ def run_rknn_process(name, task, model_input):
422494
class WorkerManager:
423495
def __init__(self):
    """
    Set up the worker registry, clean up leftovers from a previous run,
    install shutdown hooks and start the models monitor.
    """
    # Registry of loaded models (name -> Worker)
    self.workers = {}
    # Serialises add_worker calls
    self._add_worker_lock = threading.Lock()

    # Kill workers orphaned by a previous rkllama run before loading
    # anything new, so they stop holding on to NPU memory.
    _kill_orphaned_workers()

    # Stop every worker on interpreter exit and on SIGTERM/SIGINT.
    atexit.register(self.stop_all)
    try:
        for shutdown_signal in (signal.SIGTERM, signal.SIGINT):
            signal.signal(shutdown_signal, self._handle_shutdown_signal)
    except (ValueError, OSError):
        # signal.signal only works in the main thread; when we are created
        # elsewhere the atexit hook above is the remaining safety net.
        pass

    # Start the monitor of running models
    self.start_models_monitor()
428518

519+
def _handle_shutdown_signal(self, signum, frame):
    """
    Signal handler: stop all workers, then terminate the process.

    Args:
        signum: Number of the received signal (only used for logging).
        frame: Current stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit status 0, even if stopping the
            workers raised.
    """
    logger.info("Received signal %s, stopping all workers...", signum)
    try:
        self.stop_all()
    finally:
        # Exit regardless of whether the teardown succeeded.
        raise SystemExit(0)
525+
429526
def start_models_monitor(self, interval=60):
430527
"""
431528
Start a thread to monitor expired models to unload them from memory
@@ -439,6 +536,12 @@ def execute():
439536
# Wait for the next execution
440537
time.sleep(interval) # Check every 60 seconds expired models
441538

539+
# Reap any workers whose process died unexpectedly (e.g.
540+
# segfault, OOM kill) while the parent is still alive.
541+
# Without this, api/ps would keep reporting a "loaded"
542+
# model that has no real process backing it.
543+
self.reap_dead_workers()
544+
442545
# Call the process to unload expired models
443546
self.unload_expired_models()
444547

@@ -454,17 +557,79 @@ def execute():
454557
logger.info("Models Monitor running.")
455558

456559

560+
def reap_dead_workers(self) -> None:
561+
"""
562+
Remove worker entries whose subprocess has died unexpectedly.
563+
564+
A worker can die while the parent is alive (segfault in RKLLM C++,
565+
kernel OOM, external SIGKILL). ``multiprocessing.Process.is_alive()``
566+
returns False for such zombies. Without this reaper, ``self.workers``
567+
would keep the stale entry forever, causing ``api/ps`` to lie and
568+
preventing a fresh load of the same model.
569+
"""
570+
dead = []
571+
for model_name, worker in list(self.workers.items()):
572+
proc = worker.process
573+
if proc is None:
574+
continue
575+
if not proc.is_alive() and proc.exitcode is not None:
576+
dead.append((model_name, proc.exitcode))
577+
578+
for model_name, exitcode in dead:
579+
logger.warning(
580+
"Worker for model '%s' died unexpectedly (exitcode=%s); "
581+
"cleaning up stale entry.",
582+
model_name, exitcode,
583+
)
584+
try:
585+
worker = self.workers[model_name]
586+
# Reap the zombie if it's still in the process table
587+
try:
588+
worker.process.join(timeout=1)
589+
except Exception:
590+
pass
591+
try:
592+
worker.manager_pipe.close()
593+
except Exception:
594+
pass
595+
del self.workers[model_name]
596+
try:
597+
import rkllama.api.variables as variables
598+
variables.remove_model_lock(model_name)
599+
except Exception:
600+
pass
601+
except KeyError:
602+
pass
603+
457604
def unload_expired_models(self) -> int | None:
458605
"""
459606
Unload/stop workers for expired models
460607
"""
461608
# Get all expired models
462609
expired_models = [ model for model in self.workers.keys() if datetime.now() > self.workers[model].worker_model_info.expires_at ]
463-
610+
464611
# Unload/stop the expired model
465612
for model_name in expired_models:
466613
logger.info(f"Detected expired model: {model_name}")
467-
self.stop_worker(model_name)
614+
try:
615+
self.stop_worker(model_name)
616+
except Exception as e:
617+
logger.error(f"Failed to stop expired worker {model_name}: {e}")
618+
# Ensure the worker entry is cleaned up even if stop_worker
619+
# raised, so we don't leak dictionary entries for dead processes.
620+
if model_name in self.workers:
621+
process = self.workers[model_name].process
622+
if process is not None and process.is_alive():
623+
try:
624+
process.kill()
625+
process.join(timeout=5)
626+
except Exception:
627+
pass
628+
try:
629+
self.workers[model_name].manager_pipe.close()
630+
except Exception:
631+
pass
632+
del self.workers[model_name]
468633

469634

470635
def clear_old_cache_prompts(self) -> int | None:
@@ -547,13 +712,18 @@ def exists_model_loaded(self, model_name: str) -> bool:
547712
return model_name in self.workers.keys()
548713

549714

550-
def add_worker(self, model_name, model_path, model_dir, options=None, lora_model_path = None, prompt_cache_path = None) -> bool:
715+
def add_worker(self, model_name, model_path, model_dir, options=None, lora_model_path = None, prompt_cache_path = None, loaded_by=None) -> bool:
551716
"""
552717
Add a process worker to run inferences call from a specific model
553-
718+
554719
Args:
555720
model_name (str): model name to load in memory
721+
loaded_by (str): identifier of the client/process that triggered the load
556722
"""
723+
with self._add_worker_lock:
724+
return self._add_worker_locked(model_name, model_path, model_dir, options, lora_model_path, prompt_cache_path, loaded_by)
725+
726+
def _add_worker_locked(self, model_name, model_path, model_dir, options=None, lora_model_path = None, prompt_cache_path = None, loaded_by=None) -> bool:
557727
if model_name not in self.workers.keys():
558728

559729
if is_rkllm_model(model_name):
@@ -564,7 +734,7 @@ def add_worker(self, model_name, model_path, model_dir, options=None, lora_model
564734
base_domain_id = 0
565735

566736
# Add the worker to the dictionary of workers
567-
worker_model = Worker(model_name,base_domain_id)
737+
worker_model = Worker(model_name,base_domain_id,loaded_by=loaded_by)
568738

569739
# Check if there is enough available memory on the server
570740
if not self.is_memory_available_for_model(worker_model.worker_model_info.size):
@@ -722,34 +892,72 @@ def get_result(self, model_name):
722892
return None
723893

724894

725-
def stop_worker(self, model_name):
895+
def stop_worker(self, model_name, timeout=30):
726896
"""
727-
Stop/Unload a model worker
728-
897+
Stop/Unload a model worker. Sends an unload command and waits up to
898+
``timeout`` seconds for the worker process to exit. If the process
899+
does not exit in time it is forcefully killed so that NPU memory is
900+
always reclaimed.
901+
729902
Args:
730903
model_name (str): Workers to unload.
904+
timeout (int): Seconds to wait for a graceful shutdown before
905+
force-killing the worker process. Default 30.
731906
732907
"""
733-
if model_name in self.workers.keys():
734-
if is_rkllm_model(model_name):
735-
# RKLLM
736-
# Send the abort task of the model if currently is running some inference
737-
self.workers[model_name].manager_pipe.send((WORKER_TASK_ABORT_INFERENCE,None,None,None))
908+
if model_name not in self.workers:
909+
return
738910

739-
# Send the unload task of the model
740-
self.workers[model_name].manager_pipe.send((WORKER_TASK_UNLOAD_MODEL,None,None,None))
911+
process = self.workers[model_name].process
912+
pipe = self.workers[model_name].manager_pipe
913+
914+
# --- 1. Request graceful shutdown via pipe -------------------------
915+
try:
916+
if is_rkllm_model(model_name):
917+
# RKLLM – abort any running inference first
918+
pipe.send((WORKER_TASK_ABORT_INFERENCE, None, None, None))
919+
pipe.send((WORKER_TASK_UNLOAD_MODEL, None, None, None))
741920
else:
742921
# RKNN
743-
# Send the unload task of the model
744-
self.workers[model_name].manager_pipe.send((WORKER_TASK_UNLOAD_MODEL,None))
745-
922+
pipe.send((WORKER_TASK_UNLOAD_MODEL, None))
923+
except (BrokenPipeError, OSError) as e:
924+
# Pipe already broken – the worker may have crashed earlier.
925+
logger.warning(f"Could not send unload command to worker {model_name}: {e}")
926+
927+
# --- 2. Wait for the process to exit gracefully -------------------
928+
if process is not None and process.is_alive():
929+
process.join(timeout=timeout)
930+
931+
# --- 3. Force-kill if still alive ---------------------------------
932+
if process is not None and process.is_alive():
933+
logger.warning(
934+
f"Worker {model_name} did not exit within {timeout}s, "
935+
"sending SIGKILL to reclaim resources."
936+
)
937+
try:
938+
process.kill() # SIGKILL on Unix
939+
process.join(timeout=5) # reap the zombie
940+
except Exception as e:
941+
logger.error(f"Failed to kill worker {model_name}: {e}")
942+
943+
logger.info(f"Worker {model_name} stopped.")
944+
945+
# --- 4. Cleanup bookkeeping ---------------------------------------
946+
# Close our end of the pipe to avoid resource leaks
947+
try:
948+
pipe.close()
949+
except Exception:
950+
pass
746951

747-
# Wait for unload
748-
self.workers[model_name].process.join()
749-
logger.info(f"Worker {model_name} stopped...")
952+
# Remove the worker from the dictionary
953+
del self.workers[model_name]
750954

751-
# Remove the worker from the dictionary
752-
del self.workers[model_name]
955+
# Clean up per-model lock from variables module
956+
try:
957+
import rkllama.api.variables as variables
958+
variables.remove_model_lock(model_name)
959+
except Exception:
960+
pass # variables module may not be fully initialized during shutdown
753961

754962
def stop_all(self):
755963
"""
@@ -1067,19 +1275,20 @@ def get_finished_inference_token(self):
10671275

10681276
# Class to manage the information for running RKLLM models
10691277
class WorkerModelInfo:
    """Bookkeeping record describing one loaded model."""

    def __init__(self, model_name, base_domain_id, loaded_by=None):
        """
        Args:
            model_name: Name of the model; also used to look up its size.
            base_domain_id: Stored unchanged on the record.
            loaded_by: Identifier of whoever triggered the load; stored as
                "unknown" when missing or empty.
        """
        self.model = model_name
        self.size = get_model_size(model_name)

        # Lifetime of the model in memory, taken from the configuration
        minutes_in_memory = int(rkllama.config.get("model", "max_minutes_loaded_in_memory"))
        self.expires_at = datetime.now() + timedelta(minutes=minutes_in_memory)

        self.loaded_at = datetime.now()
        self.base_domain_id = base_domain_id
        self.last_call = datetime.now()
        self.loaded_by = loaded_by if loaded_by else "unknown"
10771286

10781287

10791288
# Class to manage the information for running RKLLM models
10801289
class Worker:
1081-
def __init__(self, model_name, base_domain_id):
1082-
self.worker_model_info = WorkerModelInfo(model_name=model_name, base_domain_id=base_domain_id)
1290+
def __init__(self, model_name, base_domain_id, loaded_by=None):
1291+
self.worker_model_info = WorkerModelInfo(model_name=model_name, base_domain_id=base_domain_id, loaded_by=loaded_by)
10831292
self.process = None
10841293
self.manager_pipe, self.worker_pipe = Pipe()
10851294
self.abort_flag = Value('b', False)

0 commit comments

Comments
 (0)