Skip to content

Commit 877e11c

Browse files
authored
fix: make vLLM backend teardown idempotent and remove keep_llm_state (#91)
Closes #88 - Merge `_clear_llm_state()` into idempotent `teardown()` with `_torn_down` guard - Remove `keep_llm_state` from `generate()` and base class - Add try/finally in SDK and CLI paths - Add shutdown lifecycle unit tests Made with [Cursor](https://cursor.com) --------- Signed-off-by: aagonzales <aagonzales@nvidia.com>
1 parent 95f38b0 commit 877e11c

7 files changed

Lines changed: 196 additions & 46 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ safe-synthesizer = "nemo_safe_synthesizer.cli.cli:cli"
326326
# Below is a list of excluded directories from ty typechecks.
327327
exclude = [
328328
".uv_cache", # Cache Dir in CI
329+
"./docs/**/*.ipynb",
329330
"./src/nemo_safe_synthesizer/artifacts/",
330331
"./src/nemo_safe_synthesizer/pii_replacer/",
331332
"./tests/",

src/nemo_safe_synthesizer/cli/run.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,18 @@ def run(
219219
from ..sdk.library_builder import SafeSynthesizer
220220

221221
ss: SafeSynthesizer = SafeSynthesizer(config=config, workdir=workdir).with_data_source(df)
222-
ss.run()
223-
ss.save_results(output_file=settings.output_file or workdir.output_file)
224-
ss.results.summary.log_summary(run_logger)
225-
ss.results.summary.timing.log_timing(run_logger)
226-
ss.results.summary.log_wandb()
222+
# ss.run() calls train + generate + evaluate. The generate step has its own try/finally,
223+
# but train or evaluate failures leave the generator loaded; this guard ensures teardown
224+
# on all exit paths of the full pipeline.
225+
try:
226+
ss.run()
227+
ss.save_results(output_file=settings.output_file or workdir.output_file)
228+
ss.results.summary.log_summary(run_logger)
229+
ss.results.summary.timing.log_timing(run_logger)
230+
ss.results.summary.log_wandb()
231+
finally:
232+
if hasattr(ss, "generator") and ss.generator is not None:
233+
ss.generator.teardown()
227234

228235

229236
@run.command("train")
@@ -359,9 +366,18 @@ def run_generate(
359366
if df is not None:
360367
ss = ss.with_data_source(df)
361368

362-
ss = ss.load_from_save_path().process_data().generate().evaluate().save_results(output_file=final_output_file)
363-
ss.generator.teardown()
364-
ss.results.summary.log_summary(run_logger)
365-
ss.results.summary.timing.log_timing(run_logger)
366-
run_logger.info(f"Generation complete. Results saved to: {final_output_file}")
367-
ss.results.summary.log_wandb()
369+
try:
370+
ss = (
371+
ss.load_from_save_path()
372+
.process_data()
373+
.generate()
374+
.evaluate()
375+
.save_results(output_file=final_output_file)
376+
)
377+
ss.results.summary.log_summary(run_logger)
378+
ss.results.summary.timing.log_timing(run_logger)
379+
run_logger.info(f"Generation complete. Results saved to: {final_output_file}")
380+
ss.results.summary.log_wandb()
381+
finally:
382+
if hasattr(ss, "generator") and ss.generator is not None:
383+
ss.generator.teardown()

src/nemo_safe_synthesizer/generation/backend.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def prepare_params(self, **kwargs) -> None:
102102
@abc.abstractmethod
103103
def generate(
104104
self,
105-
keep_llm_state: bool = True,
106105
data_actions_fn: utils.DataActionsFn | None = None,
107106
) -> GenerateJobResults:
108107
"""Run the batch generation loop and return aggregated results.
@@ -116,9 +115,6 @@ def generate(
116115
batch.
117116
118117
Args:
119-
keep_llm_state: If ``True``, keep the model in GPU memory
120-
after generation for potential reuse. If ``False``,
121-
GPU resources are freed immediately on completion.
122118
data_actions_fn: Optional post-processing / validation
123119
function applied to each batch of generated records.
124120
Typically reverses training-time preprocessing and

src/nemo_safe_synthesizer/generation/timeseries_backend.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,6 @@ def _retain_single_valid_response(self, batch: Batch) -> list[dict]:
856856

857857
def generate(
858858
self,
859-
keep_llm_state: bool = True,
860859
data_actions_fn: utils.DataActionsFn | None = None,
861860
) -> GenerateJobResults:
862861
"""Generate time-series tabular data using Nemo Safe Synthesizer.
@@ -872,7 +871,6 @@ def generate(
872871
seen during training (from model_metadata.initial_prefill).
873872
874873
Args:
875-
keep_llm_state: If True, keep the model in memory after generation.
876874
data_actions_fn: Optional function that takes a DataFrame and returns a modified DataFrame.
877875
878876
Returns:
@@ -924,9 +922,6 @@ def generate(
924922
batches.job_complete()
925923
batches.log_status()
926924

927-
if not keep_llm_state:
928-
self._clear_llm_state()
929-
930925
generation_time_sec = time.monotonic() - generation_start
931926
self.elapsed_time = generation_time_sec
932927
self.gen_results = GenerateJobResults.from_batches(

src/nemo_safe_synthesizer/generation/vllm_backend.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,13 @@ class VllmBackend(GeneratorBackend):
118118
avoid leaking sensitive data).
119119
"""
120120

121-
def __init__(self, config: SafeSynthesizerParameters, model_metadata: ModelMetadata, workdir: Workdir, **kwargs):
121+
def __init__(
122+
self,
123+
config: SafeSynthesizerParameters,
124+
model_metadata: ModelMetadata,
125+
workdir: Workdir,
126+
**kwargs,
127+
):
122128
self.model_metadata = model_metadata
123129
self.config = config
124130
self.remote = False
@@ -140,34 +146,43 @@ def __init__(self, config: SafeSynthesizerParameters, model_metadata: ModelMetad
140146
self.processor = create_processor(self.schema, self.model_metadata, self.config)
141147
adapter_path = self.workdir.adapter_path if self.workdir.adapter_path else self.model_metadata.adapter_path
142148
self.lora_req = LoRARequest("lora", 1, str(adapter_path)) if adapter_path else None
149+
self._torn_down = False
143150

144151
def teardown(self) -> None:
145-
"""Release GPU memory and clean up distributed resources."""
146-
self._clear_llm_state()
152+
"""Release GPU memory and distributed resources. Idempotent -- safe to call multiple times."""
153+
if self._torn_down:
154+
return
155+
self._torn_down = True
156+
157+
try:
158+
cleanup_dist_env_and_memory()
159+
except Exception:
160+
logger.debug("cleanup_dist_env_and_memory failed during teardown", exc_info=True)
147161

148-
def _clear_llm_state(self) -> None:
149-
"""Delete LLM state to free up GPU memory."""
150-
cleanup_dist_env_and_memory()
151-
# destroy_model_parallel()
152162
self.llm = None
153-
logger.debug("Cleaned up LLM")
154-
cleanup_memory()
155-
logger.debug("Cleaned up memory")
163+
self._gen_method = None
164+
self.gen_method = None
165+
166+
try:
167+
cleanup_memory()
168+
except Exception:
169+
logger.debug("cleanup_memory failed during teardown", exc_info=True)
156170

157171
def __del__(self) -> None:
158-
"""Clean up resources on garbage collection to prevent shutdown warnings."""
172+
"""Clean up resources on garbage collection."""
159173
try:
160-
self._clear_llm_state()
174+
self.teardown()
161175
except Exception:
162-
# Suppress errors during garbage collection to avoid masking other exceptions
163176
pass
164177

165178
def initialize(self, **kwargs) -> None:
166179
"""Initialize and load the model into memory."""
167-
# vLLM 0.11.x uses an environment variable for attention backend selection.
168-
# When vLLM is upgraded to 0.12+, migrate to the attention_backend constructor arg.
169-
if self.config.generation.attention_backend not in [None, "auto"]:
170-
os.environ["VLLM_ATTENTION_BACKEND"] = self.config.generation.attention_backend
180+
self._torn_down = False
181+
182+
# vLLM 0.12+ accepts attention_config as a constructor arg (replaces the
183+
# VLLM_ATTENTION_BACKEND env var used in 0.11.x).
184+
attn_backend = self.config.generation.attention_backend
185+
attention_config = {"backend": attn_backend} if attn_backend not in (None, "auto") else None
171186

172187
max_vram = get_max_vram()
173188
# note this only works for single GPU setups
@@ -194,6 +209,7 @@ def initialize(self, **kwargs) -> None:
194209
max_lora_rank=self.config.training.lora_r,
195210
structured_outputs_config=structured_outputs_config,
196211
enforce_eager=enforce_eager,
212+
attention_config=attention_config,
197213
)
198214

199215
def _build_structured_output_params(self) -> StructuredOutputsParams | None:
@@ -455,7 +471,6 @@ def _log_batch_timing_and_progress(
455471

456472
def generate(
457473
self,
458-
keep_llm_state: bool = True,
459474
data_actions_fn: utils.DataActionsFn | None = None,
460475
) -> GenerateJobResults:
461476
"""Generate synthetic tabular data in batches until the target count is reached.
@@ -465,9 +480,6 @@ def generate(
465480
a stopping condition fires.
466481
467482
Args:
468-
keep_llm_state: If ``True``, keep the model in GPU memory after
469-
generation for potential reuse. The model is still freed
470-
on garbage collection.
471483
data_actions_fn: Optional post-processing / validation function
472484
applied to each batch of generated records.
473485
@@ -529,9 +541,6 @@ def generate(
529541
batches.job_complete()
530542
batches.log_status()
531543

532-
if not keep_llm_state:
533-
self._clear_llm_state()
534-
535544
max_num_records = (
536545
self.config.generation.num_records
537546
if self.config.data.group_training_examples_by is None and batches.status == GenerationStatus.COMPLETE

src/nemo_safe_synthesizer/sdk/library_builder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,11 @@ def generate(self) -> SafeSynthesizer:
383383
config=self._nss_config, model_metadata=self._llm_metadata, workdir=self._workdir
384384
)
385385

386-
self.generator.initialize()
387-
self.generator.generate(keep_llm_state=False)
386+
try:
387+
self.generator.initialize()
388+
self.generator.generate()
389+
finally:
390+
self.generator.teardown()
388391
self._generated = True
389392
return self
390393

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for the VllmBackend teardown lifecycle."""
5+
6+
from unittest.mock import MagicMock, patch
7+
8+
import pytest
9+
10+
from nemo_safe_synthesizer.generation import vllm_backend as vllm_backend_mod
11+
from nemo_safe_synthesizer.generation.vllm_backend import VllmBackend
12+
13+
14+
@pytest.fixture
15+
def _mock_vllm_cleanup():
16+
"""Patch vLLM distributed cleanup so tests run without a GPU."""
17+
with (
18+
patch.object(vllm_backend_mod, "cleanup_dist_env_and_memory") as mock_dist,
19+
patch.object(vllm_backend_mod, "cleanup_memory") as mock_mem,
20+
):
21+
yield mock_dist, mock_mem
22+
23+
24+
@pytest.fixture
25+
def backend(_mock_vllm_cleanup, fixture_session_cache_dir):
26+
"""Create a VllmBackend with mocked dependencies."""
27+
mock_metadata = MagicMock()
28+
mock_metadata.adapter_path = None
29+
mock_metadata.instruction = "Generate"
30+
mock_metadata.prompt_config = MagicMock()
31+
mock_metadata.prompt_config.template = "{instruction} {schema}"
32+
33+
mock_config = MagicMock()
34+
# Pin branching fields so create_processor() selects TabularDataProcessor deterministically.
35+
mock_config.time_series.is_timeseries = False
36+
mock_config.data.group_training_examples_by = None
37+
# Pin to a valid literal so StructuredOutputsConfig Pydantic validation passes in initialize().
38+
mock_config.generation.structured_generation_backend = "xgrammar"
39+
mock_config.generation.attention_backend = None
40+
41+
mock_workdir = MagicMock()
42+
mock_workdir.schema_file = fixture_session_cache_dir / "schema.json"
43+
mock_workdir.schema_file.parent.mkdir(parents=True, exist_ok=True)
44+
mock_workdir.schema_file.write_text('{"properties": {"col_a": {"type": "string"}}}')
45+
mock_workdir.adapter_path = None
46+
47+
return VllmBackend(config=mock_config, model_metadata=mock_metadata, workdir=mock_workdir)
48+
49+
50+
class TestTeardownIdempotency:
51+
def test_first_teardown_runs_cleanup(self, backend, _mock_vllm_cleanup):
52+
mock_dist, mock_mem = _mock_vllm_cleanup
53+
backend.llm = MagicMock()
54+
55+
backend.teardown()
56+
57+
mock_dist.assert_called_once()
58+
mock_mem.assert_called_once()
59+
assert backend.llm is None
60+
assert backend._torn_down is True
61+
62+
def test_second_teardown_is_noop(self, backend, _mock_vllm_cleanup):
63+
mock_dist, mock_mem = _mock_vllm_cleanup
64+
65+
backend.teardown()
66+
mock_dist.reset_mock()
67+
mock_mem.reset_mock()
68+
69+
backend.teardown()
70+
71+
mock_dist.assert_not_called()
72+
mock_mem.assert_not_called()
73+
74+
def test_initialize_resets_guard(self, backend, _mock_vllm_cleanup):
75+
backend.teardown()
76+
assert backend._torn_down is True
77+
78+
with patch.object(vllm_backend_mod, "vLLM"):
79+
backend.initialize()
80+
81+
assert backend._torn_down is False
82+
83+
84+
class TestTeardownResilience:
85+
def test_cleanup_memory_runs_even_if_dist_cleanup_fails(self, backend, _mock_vllm_cleanup):
86+
mock_dist, mock_mem = _mock_vllm_cleanup
87+
mock_dist.side_effect = RuntimeError("distributed cleanup failed")
88+
89+
backend.teardown()
90+
91+
mock_dist.assert_called_once()
92+
mock_mem.assert_called_once()
93+
assert backend.llm is None
94+
95+
def test_llm_cleared_even_if_dist_cleanup_fails(self, backend, _mock_vllm_cleanup):
96+
mock_dist, _ = _mock_vllm_cleanup
97+
mock_dist.side_effect = RuntimeError("boom")
98+
backend.llm = MagicMock()
99+
100+
backend.teardown()
101+
102+
assert backend.llm is None
103+
104+
105+
class TestDunderDel:
106+
def test_del_calls_teardown(self, backend, _mock_vllm_cleanup):
107+
mock_dist, _ = _mock_vllm_cleanup
108+
109+
# Reset to isolate only this explicit __del__ call.
110+
mock_dist.reset_mock()
111+
backend.__del__()
112+
113+
mock_dist.assert_called_once()
114+
assert backend._torn_down is True
115+
116+
def test_del_suppresses_exceptions(self, backend, _mock_vllm_cleanup):
117+
mock_dist, _ = _mock_vllm_cleanup
118+
mock_dist.side_effect = RuntimeError("boom")
119+
120+
backend.__del__()
121+
122+
def test_del_after_teardown_is_noop(self, backend, _mock_vllm_cleanup):
123+
mock_dist, _ = _mock_vllm_cleanup
124+
125+
backend.teardown()
126+
mock_dist.reset_mock()
127+
128+
backend.__del__()
129+
130+
mock_dist.assert_not_called()

0 commit comments

Comments
 (0)