Skip to content

Commit 8e665ca

Browse files
authored
fix(sdk): unify behavior between cli and sdk to auto save results (#272)
<!-- SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. --> <!-- SPDX-License-Identifier: Apache-2.0 --> <!-- Thank you for contributing to Safe Synthesizer! --> # Summary <!-- Brief description of changes --> This PR fixes row 20 from the [bug bash feedback](https://docs.google.com/spreadsheets/d/1fmFYl89hUCNMROR3B_hnq4oqaqM8tAQJJc7p2P17ep4/edit?gid=0#gid=0). Previously, cli calls `.save_results()` automatically, while for sdk `.save_results()` need to be explicitly called. With this change, we call `.save_results()` in `.run()`, which applies to both cli and sdk. ## Pre-Review Checklist <!-- These checks should be completed before a PR is reviewed, --> <!-- but you can submit a draft early to indicate that the issue is being worked on. --> Ensure that the following pass: - [x] `make format && make check` or via prek validation. - [x] `make test` passes locally - [x] `make test-e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) - [ ] GPU CI status check passes -- comment `/sync` on this PR to trigger a run (auto-triggers on ready-for-review) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [x] New or updated tests for any fix or new behavior - [x] Updated documentation for new features and behaviors, including docstrings for API docs. ## Testing - CLI - default output path: `safe-synthesizer run --config /root/configs/quick-tinyllama-unsloth.yaml --data-source /root/datasets/clinc_oos.csv`. Everything's there: <img width="298" height="295" alt="image" src="https://github.com/user-attachments/assets/b02547b1-a98b-4cff-b2be-dc7ced8bda66" /> - output file override: `safe-synthesizer run --config /root/configs/quick-tinyllama-unsloth.yaml --data-source /root/datasets/clinc_oos.csv --output-file /root/Safe-Synthesizer/safe-synthesizer-artifacts/output-path-override/synth.csv`: everything else is still in the default path, except for the generated csv <img width="300" height="402" alt="image" src="https://github.com/user-attachments/assets/1a710b64-dab0-4592-9c4d-1e96eed2a85d" /> - SDK - ran the 101 notebook: ``` from nemo_safe_synthesizer.sdk.library_builder import SafeSynthesizer builder = SafeSynthesizer().with_data_source(df).with_replace_pii(enable=False).with_train(num_input_records_to_sample=1000).resolve() builder.run() results = builder.results ``` everything's there <img width="293" height="331" alt="image" src="https://github.com/user-attachments/assets/6b1b1b4b-a00a-4366-9282-da6ad2e0230e" /> --------- Signed-off-by: nina-xu <19981858+nina-xu@users.noreply.github.com>
1 parent 969f0dc commit 8e665ca

7 files changed

Lines changed: 100 additions & 45 deletions

File tree

STYLE_GUIDE.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,10 @@ Standard context managers (`with open(...)`, `with lock:`) are fine when they fi
357357

358358
```python
359359
try:
360-
ss.run()
361-
ss.save_results(workdir)
360+
nss.run() # saves results automatically
362361
finally:
363-
if hasattr(ss, "generator") and ss.generator is not None:
364-
ss.generator.teardown()
362+
if hasattr(nss, "generator") and nss.generator is not None:
363+
nss.generator.teardown()
365364
```
366365

367366
For backends with expensive resources, use the `_torn_down` guard pattern:

docs/user-guide/running.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -960,10 +960,11 @@ the CLI Commands section for all options.
960960
Path("./safe-synthesizer-artifacts/myconfig---mydata/2026-01-15T12:00:00")
961961
)
962962
synthesizer = SafeSynthesizer(config, workdir=workdir)
963-
synthesizer.process_data()
964963
synthesizer.load_from_save_path()
964+
synthesizer.process_data()
965965
synthesizer.generate()
966966
synthesizer.evaluate()
967+
synthesizer.save_results()
967968
```
968969

969970
### Stepwise execution (SDK)
@@ -1031,11 +1032,15 @@ Key outputs:
10311032

10321033
### SDK Results Access
10331034

1035+
`run()` automatically saves `synthetic_data.csv` and `evaluation_report.html`
1036+
to the artifacts directory unless an `output_file` override is provided.
1037+
For stepwise execution, call `save_results()` explicitly after `evaluate()`.
1038+
10341039
```python
10351040
results = synthesizer.results
10361041
df = results.synthetic_data
10371042
summary = results.summary
1038-
synthesizer.save_results()
1043+
# synthesizer.save_results() # only needed for stepwise execution; run() saves automatically
10391044
```
10401045

10411046
### Cleaning Up

src/nemo_safe_synthesizer/cli/run.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -218,19 +218,18 @@ def run(
218218
with traced_user("SafeSynthesizer"):
219219
from ..sdk.library_builder import SafeSynthesizer
220220

221-
ss: SafeSynthesizer = SafeSynthesizer(config=config, workdir=workdir).with_data_source(df)
222-
# ss.run() calls train + generate + evaluate. The generate step has its own try/finally,
223-
# but train or evaluate failures leave the generator loaded; this guard ensures teardown
224-
# on all exit paths of the full pipeline.
221+
nss: SafeSynthesizer = SafeSynthesizer(config=config, workdir=workdir).with_data_source(df)
222+
# nss.run() calls train + generate + evaluate + save_results. The generate step has its
223+
# own try/finally, but train or evaluate failures leave the generator loaded; this guard
224+
# ensures teardown on all exit paths of the full pipeline.
225225
try:
226-
ss.run()
227-
ss.save_results(output_file=settings.output_file or workdir.output_file)
228-
ss.results.summary.log_summary(run_logger)
229-
ss.results.summary.timing.log_timing(run_logger)
230-
ss.results.summary.log_wandb()
226+
nss.run(output_file=settings.output_file)
227+
nss.results.summary.log_summary(run_logger)
228+
nss.results.summary.timing.log_timing(run_logger)
229+
nss.results.summary.log_wandb()
231230
finally:
232-
if hasattr(ss, "generator") and ss.generator is not None:
233-
ss.generator.teardown()
231+
if hasattr(nss, "generator") and nss.generator is not None:
232+
nss.generator.teardown()
234233

235234

236235
@run.command("train")
@@ -359,25 +358,25 @@ def run_generate(
359358

360359
final_output_file = settings.output_file or workdir.output_file
361360
with traced_user("SafeSynthesizer"):
362-
ss = SafeSynthesizer(config, workdir=workdir)
361+
nss = SafeSynthesizer(config, workdir=workdir)
363362

364363
# Only set data source if provided via --data-source
365364
# Otherwise, load_from_save_path() will load from cached files
366365
if df is not None:
367-
ss = ss.with_data_source(df)
366+
nss = nss.with_data_source(df)
368367

369368
try:
370-
ss = (
371-
ss.load_from_save_path()
369+
nss = (
370+
nss.load_from_save_path()
372371
.process_data()
373372
.generate()
374373
.evaluate()
375374
.save_results(output_file=final_output_file)
376375
)
377-
ss.results.summary.log_summary(run_logger)
378-
ss.results.summary.timing.log_timing(run_logger)
376+
nss.results.summary.log_summary(run_logger)
377+
nss.results.summary.timing.log_timing(run_logger)
379378
run_logger.info(f"Generation complete. Results saved to: {final_output_file}")
380-
ss.results.summary.log_wandb()
379+
nss.results.summary.log_wandb()
381380
finally:
382-
if hasattr(ss, "generator") and ss.generator is not None:
383-
ss.generator.teardown()
381+
if hasattr(nss, "generator") and nss.generator is not None:
382+
nss.generator.teardown()

src/nemo_safe_synthesizer/sdk/AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Precedence: `kwargs` override `values`; `values` override model defaults. Each `
4545
- generate(): Chooses `TimeseriesBackend` or `VllmBackend`, initializes, generates.
4646
- evaluate(): Builds `Evaluator`, compiles `results` via `make_nss_results`.
4747

48-
`run()` calls `process_data().train().generate().evaluate()`.
48+
`run()` calls `process_data().train().generate().evaluate()` then `save_results()`. Stepwise callers must invoke `save_results()` themselves.
4949

5050
## Gotchas
5151

src/nemo_safe_synthesizer/sdk/library_builder.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class SafeSynthesizer(ConfigBuilder):
9191
9292
builder = SafeSynthesizer().with_data_source(df)
9393
builder.process_data().train().generate().evaluate()
94+
builder.save_results()
9495
results = builder.results
9596
9697
Args:
@@ -439,10 +440,16 @@ def evaluate(self) -> SafeSynthesizer:
439440
)
440441
return self
441442

442-
def run(self) -> None:
443-
"""Run the full pipeline: ``process_data`` -> ``train`` -> ``generate`` -> ``evaluate``.
443+
def run(self, output_file: Path | str | None = None) -> None:
444+
"""Run the full pipeline and save results.
444445
445-
For step-by-step control, call the individual methods instead.
446+
Executes ``process_data`` -> ``train`` -> ``generate`` ->
447+
``evaluate`` -> ``save_results``. For step-by-step control,
448+
call the individual methods instead.
449+
450+
Args:
451+
output_file: Explicit output path for the synthetic data CSV.
452+
Falls back to ``workdir.output_file`` when ``None``.
446453
447454
Raises:
448455
RuntimeError: If called after ``load_from_save_path()``.
@@ -460,11 +467,15 @@ def run(self) -> None:
460467
assert isinstance(self._data_source, pd.DataFrame)
461468

462469
self.process_data().train().generate().evaluate()
470+
self.save_results(output_file=output_file)
463471

464472
@traced("SafeSynthesizer.save_results", category=LogCategory.RUNTIME, level="INFO")
465473
def save_results(self, output_file: Path | str | None = None) -> None:
466474
"""Save synthetic data CSV and evaluation report HTML to the workdir.
467475
476+
Called automatically by ``run()``. Call explicitly after
477+
stepwise execution (``process_data().train().generate().evaluate()``).
478+
468479
Args:
469480
output_file: Explicit output path for the CSV. Falls back
470481
to ``workdir.output_file`` when ``None``.
@@ -473,7 +484,6 @@ def save_results(self, output_file: Path | str | None = None) -> None:
473484
assert self.results is not None
474485
assert isinstance(self.results.synthetic_data, pd.DataFrame)
475486

476-
# Determine output file path for synthetic data
477487
match output_file:
478488
case Path() as p:
479489
output_file = p
@@ -482,12 +492,10 @@ def save_results(self, output_file: Path | str | None = None) -> None:
482492
case _:
483493
output_file = self._workdir.output_file
484494

485-
# Save synthetic data CSV
486495
output_file.parent.mkdir(parents=True, exist_ok=True)
487496
self.results.synthetic_data.to_csv(str(output_file), index=False)
488497
logger.info(f"Saved synthetic data to {output_file}")
489498

490-
# Save evaluation report HTML if available
491499
if self.results.evaluation_report_html:
492500
report_path = self._workdir.evaluation_report
493501
report_path.parent.mkdir(parents=True, exist_ok=True)

tests/cli/test_run.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_run_uses_custom_output_file(
131131
fixture_session_cache_dir: Path,
132132
patched_run_dependencies: dict,
133133
):
134-
"""Verify that --output-file overrides default workdir output."""
134+
"""Verify that --output-file is forwarded to run()."""
135135
custom_output = tmp_path / "custom_output.csv"
136136

137137
result = cli_runner.invoke(
@@ -147,22 +147,18 @@ def test_run_uses_custom_output_file(
147147
catch_exceptions=False,
148148
)
149149

150-
# Verify save_results was called with the custom output file
151150
assert result.exit_code == 0
152151
mock_ss = patched_run_dependencies["safe_synthesizer"]
153-
mock_ss.save_results.assert_called_once()
154-
actual_output_path = mock_ss.save_results.call_args.kwargs.get("output_file")
155-
assert str(actual_output_path) == str(custom_output)
152+
mock_ss.run.assert_called_once_with(output_file=str(custom_output))
156153

157-
def test_run_uses_workdir_output_when_no_override(
154+
def test_run_without_output_file_passes_none(
158155
self,
159156
cli_runner: CliRunner,
160157
dummy_csv: Path,
161158
fixture_session_cache_dir: Path,
162-
mock_workdir: MagicMock,
163159
patched_run_dependencies: dict,
164160
):
165-
"""Verify that workdir.output_file is used when --output-file is not provided."""
161+
"""Without --output-file, run() is called with output_file=None."""
166162
result = cli_runner.invoke(
167163
run,
168164
[
@@ -174,12 +170,10 @@ def test_run_uses_workdir_output_when_no_override(
174170
catch_exceptions=False,
175171
)
176172

177-
# Verify save_results was called with the workdir's default output file
178173
assert result.exit_code == 0
179174
mock_ss = patched_run_dependencies["safe_synthesizer"]
180-
mock_ss.save_results.assert_called_once()
181-
actual_output_path = mock_ss.save_results.call_args.kwargs.get("output_file")
182-
assert str(actual_output_path) == str(mock_workdir.output_file)
175+
# Default output path is used if no --output-file is provided
176+
mock_ss.run.assert_called_once_with(output_file=None)
183177

184178

185179
class TestPathOptions:

tests/sdk/test_builder.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from pathlib import Path
5+
from unittest.mock import MagicMock
6+
47
import pandas as pd
58
import pytest
69

@@ -12,6 +15,7 @@
1215
from nemo_safe_synthesizer.sdk.library_builder import SafeSynthesizer
1316

1417
_SMALL_DF = pd.DataFrame({"a": [1, 2, 3]})
18+
_REPORT_HTML = "<html><body>report</body></html>"
1519

1620
PATCH_PREFIX = "nemo_safe_synthesizer.sdk.builder"
1721

@@ -355,3 +359,49 @@ def test_with_replace_pii_reenable_after_disable():
355359
._nss_config
356360
)
357361
assert config.replace_pii is not None
362+
363+
364+
def _builder_with_mock_results(tmp_path: Path) -> SafeSynthesizer:
365+
"""Create a SafeSynthesizer with mocked results for save_results testing."""
366+
nss = SafeSynthesizer(save_path=tmp_path / "artifacts")
367+
nss.results = MagicMock()
368+
nss.results.synthetic_data = _SMALL_DF
369+
nss.results.evaluation_report_html = _REPORT_HTML
370+
return nss
371+
372+
373+
class TestSaveResults:
374+
"""Verify save_results persists CSV and HTML to the expected paths."""
375+
376+
def test_saves_to_default_workdir(self, tmp_path: Path):
377+
nss = _builder_with_mock_results(tmp_path)
378+
379+
nss.save_results()
380+
381+
csv_path = nss._workdir.output_file
382+
report_path = nss._workdir.evaluation_report
383+
assert csv_path.exists()
384+
assert report_path.exists()
385+
assert pd.read_csv(csv_path).equals(_SMALL_DF)
386+
assert report_path.read_text() == _REPORT_HTML
387+
388+
def test_output_file_override_writes_csv_to_custom_path(self, tmp_path: Path):
389+
nss = _builder_with_mock_results(tmp_path)
390+
custom_csv = tmp_path / "custom" / "output.csv"
391+
392+
nss.save_results(output_file=custom_csv)
393+
394+
assert custom_csv.exists()
395+
assert pd.read_csv(custom_csv).equals(_SMALL_DF)
396+
# Report still goes to the workdir regardless of output_file
397+
assert nss._workdir.evaluation_report.exists()
398+
assert nss._workdir.evaluation_report.read_text() == _REPORT_HTML
399+
400+
def test_skips_report_when_html_is_none(self, tmp_path: Path):
401+
nss = _builder_with_mock_results(tmp_path)
402+
nss.results.evaluation_report_html = None
403+
404+
nss.save_results()
405+
406+
assert nss._workdir.output_file.exists()
407+
assert not nss._workdir.evaluation_report.exists()

0 commit comments

Comments
 (0)