Skip to content

Commit 73e4de5

Browse files
committed
fix(trainer): ensure proper resource cleanup during evaluation cycle and improve process termination handling in inference server
1 parent 1abb78a commit 73e4de5

File tree

2 files changed

+67
-20
lines changed

2 files changed

+67
-20
lines changed

grail/neurons/trainer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,12 +722,19 @@ async def _run_server_evaluation(
722722
server_model_name=server.model_name,
723723
)
724724

725-
metrics = await self._run_evaluation_cycle(
726-
plan=plan,
727-
window_number=window_number,
728-
env_factory=env_factory,
729-
evaluator=evaluator,
730-
)
725+
try:
726+
metrics = await self._run_evaluation_cycle(
727+
plan=plan,
728+
window_number=window_number,
729+
env_factory=env_factory,
730+
evaluator=evaluator,
731+
)
732+
finally:
733+
# Explicitly shutdown evaluator before server context exits
734+
# to ensure all resources are released before vLLM process is killed
735+
evaluator.shutdown()
736+
del tokenizer
737+
gc.collect()
731738

732739
logger.info("Server shutdown complete")
733740
return metrics

grail/trainer/inference_server.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -235,31 +235,69 @@ async def _terminate_process(
235235
proc: subprocess.Popen[bytes] | None,
236236
wait_for_gpu: bool = True,
237237
) -> None:
238-
"""Terminate subprocess gracefully and optionally wait for GPU memory release."""
238+
"""Terminate subprocess and all children, optionally wait for GPU memory release.
239+
240+
Uses process group kill to ensure vLLM/SGLang worker processes are terminated.
241+
These workers hold GPU memory and must be killed for clean shutdown.
242+
"""
239243
if proc is None:
240244
logger.debug("_terminate_process: proc is None, skipping")
241245
return
242246

243247
import asyncio
248+
import signal
244249
import time
245250

246-
logger.info("_terminate_process: sending SIGTERM to pid=%s", proc.pid)
251+
pid = proc.pid
252+
logger.info("_terminate_process: terminating process group for pid=%s", pid)
253+
247254
try:
248-
proc.terminate()
255+
# Get process group ID (same as pid when start_new_session=True)
256+
try:
257+
pgid = os.getpgid(pid)
258+
except ProcessLookupError:
259+
logger.info("Process %s already exited, skipping termination", pid)
260+
return
261+
262+
# Send SIGTERM to entire process group (including vLLM worker children)
263+
logger.info("_terminate_process: sending SIGTERM to process group pgid=%s", pgid)
249264
try:
250-
logger.debug("_terminate_process: waiting for process to exit (timeout=10s)...")
265+
os.killpg(pgid, signal.SIGTERM)
266+
except ProcessLookupError:
267+
logger.info("Process group %s already exited", pgid)
268+
return
269+
270+
# Wait for parent process to exit
271+
try:
272+
logger.debug("_terminate_process: waiting for process group to exit (timeout=10s)")
251273
proc.wait(timeout=10)
252-
logger.info("Server process terminated (pid=%s)", proc.pid)
274+
logger.info("Server process group terminated (pgid=%s)", pgid)
253275
except subprocess.TimeoutExpired:
254-
logger.warning("Process didn't exit gracefully, force killing pid=%s", proc.pid)
255-
proc.kill()
256-
logger.debug(
257-
"_terminate_process: waiting for killed process to exit (timeout=5s)..."
258-
)
259-
proc.wait(timeout=5)
260-
logger.info("Process killed and reaped (pid=%s)", proc.pid)
276+
# Force kill entire process group
277+
logger.warning("Process group didn't exit gracefully, force killing pgid=%s", pgid)
278+
try:
279+
os.killpg(pgid, signal.SIGKILL)
280+
except ProcessLookupError:
281+
pass # Already dead
282+
logger.debug("_terminate_process: waiting for killed process group (timeout=5s)")
283+
try:
284+
proc.wait(timeout=5)
285+
except subprocess.TimeoutExpired:
286+
logger.error("Failed to kill process group pgid=%s", pgid)
287+
logger.info("Process group killed (pgid=%s)", pgid)
288+
261289
except Exception as exc:
262-
logger.warning("Error terminating process: %s", exc)
290+
logger.warning("Error terminating process group: %s", exc)
291+
# Fallback to simple process termination
292+
try:
293+
proc.terminate()
294+
proc.wait(timeout=5)
295+
except Exception:
296+
try:
297+
proc.kill()
298+
proc.wait(timeout=5)
299+
except Exception:
300+
pass
263301

264302
# Wait for GPU memory release if requested
265303
if wait_for_gpu and torch.cuda.is_available():
@@ -391,6 +429,7 @@ async def _start_server(self) -> None:
391429
stderr=stderr_target,
392430
text=False,
393431
env=popen_env,
432+
start_new_session=True, # Create process group for clean shutdown of all workers
394433
)
395434
logger.info(
396435
"Launched vLLM server: pid=%s host=%s port=%s",
@@ -611,14 +650,15 @@ async def _start_server(self) -> None:
611650
popen_env = os.environ.copy()
612651
if self._config.env:
613652
popen_env.update(self._config.env)
614-
logger.info("vLLM server using custom environment: %s", self._config.env)
653+
logger.info("SGLang server using custom environment: %s", self._config.env)
615654

616655
self._process = subprocess.Popen(
617656
cmd,
618657
stdout=stdout_target,
619658
stderr=stderr_target,
620659
text=False,
621660
env=popen_env,
661+
start_new_session=True, # Create process group for clean shutdown of all workers
622662
)
623663
logger.info(
624664
"Launched SGLang server: pid=%s host=%s port=%s",

0 commit comments

Comments
 (0)