Skip to content

Commit 334bd1c

Browse files
authored
docs: add/update docstrings for holdout and root files (#170)
<!-- SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. --> <!-- SPDX-License-Identifier: Apache-2.0 --> <!-- Thank you for contributing to Safe Synthesizer! --> # Summary <!-- Brief description of changes --> Added/audited docstrings for all the files in `src/nemo_safe_synthesizer/` root and in `holdout/`. Had agent check for style guide and correctness ## Pre-Review Checklist <!-- These checks should be completed before a PR is reviewed, --> <!-- but you can submit a draft early to indicate that the issue is being worked on. --> Ensure that the following pass: - [x] `make format && make check` or via prek validation. - [x] `make test` passes locally - [ ] `make test-e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [ ] New or updated tests for any fix or new behavior - [x] Updated documentation for new features and behaviors, including docstrings for API docs. ## Other Notes <!-- Please add the issue number that should be closed when this PR is merged. --> - Closes #<issue> Signed-off-by: Nina Xu <19981858+nina-xu@users.noreply.github.com>
1 parent 593c026 commit 334bd1c

7 files changed

Lines changed: 273 additions & 133 deletions

File tree

src/nemo_safe_synthesizer/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""Entry point for ``python -m nemo_safe_synthesizer``."""
5+
46
import os
57
import sys
68

src/nemo_safe_synthesizer/defaults.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""Collection of default settings for the Nemo Safe Synthesizer implementation."""
4+
"""Default constants for the Safe Synthesizer pipeline.
5+
6+
Includes artifact paths, logging parameters, training and generation
7+
defaults, prompt templates, and miscellaneous constants shared across
8+
modules.
9+
"""
510

611
from pathlib import Path
712

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,62 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""Custom error classes"""
4+
"""Error hierarchy for Safe Synthesizer.
5+
6+
All public exceptions inherit from ``SafeSynthesizerError``. Errors on
7+
the user side (bad data, bad config, generation failure) inherit from
8+
``UserError`` and a matching built-in (``ValueError``, ``RuntimeError``)
9+
so callers can catch either.
10+
11+
Classes:
12+
SafeSynthesizerError: Base for all known errors.
13+
UserError: Invalid usage (bad inputs, uninitialized state).
14+
InternalError: Library bug (equivalent to HTTP 5xx).
15+
DataError: Problems with training data (NaNs, unsupported types).
16+
ParameterError: Invalid config or parameter input.
17+
GenerationError: Sampling/generation failures.
18+
"""
519

620

721
class SafeSynthesizerError(Exception):
8-
"""Base class for all known errors that can be thrown from nemo_safe_synthesizer."""
22+
"""Base class for all known Safe Synthesizer errors."""
923

1024

1125
class UserError(SafeSynthesizerError):
12-
"""
13-
Base class for errors that are caused by invalid usage. This is usually caused
14-
by invalid input parameters, by calling methods on a class that is not initialized,
15-
etc.
16-
If you are receiving this error, please see documentation of the corresponding
17-
classes and check your inputs.
26+
"""Invalid usage -- bad input parameters, uninitialized state, etc.
27+
28+
If you receive this error, check the documentation for the corresponding
29+
class and verify your inputs.
1830
"""
1931

2032

2133
class InternalError(SafeSynthesizerError, RuntimeError):
22-
"""
23-
Error that indicate invalid internal state.
24-
25-
If you're using safe_synthesizer through documented interfaces, this usually
26-
indicates a bug in safe_synthesizer itself.
27-
If you're using not documented interfaces, this could indicate invalid usage.
34+
"""Invalid internal state indicating a bug in Safe Synthesizer.
2835
29-
This class of errors is equivalent to 5xx status codes in HTTP protocol.
36+
When using documented interfaces this usually indicates a library bug.
37+
When using undocumented interfaces it may indicate invalid usage.
38+
Equivalent to HTTP 5xx status codes.
3039
"""
3140

3241

3342
class DataError(UserError, ValueError):
34-
"""
35-
Represents problems with training data before work is actually attempted.
36-
For example: data contains values that are not supported by the model that is
37-
being used: infinity, too many NaNs, nested data, etc.
43+
"""Problems with training data before work is attempted.
44+
45+
Examples: data contains infinity, too many NaNs, nested structures,
46+
or types unsupported by the model.
3847
"""
3948

4049

4150
class ParameterError(UserError, ValueError):
42-
"""
43-
Represents errors with configurations or parameter input to user-facing methods.
44-
For example: config referencing column that is not present in the data.
51+
"""Invalid configuration or parameter input to user-facing methods.
52+
53+
Examples: config references a column not present in the data, invalid
54+
combination of parameters.
4555
"""
4656

4757

4858
class GenerationError(UserError, RuntimeError):
49-
"""
50-
Represents errors happening during sampling/generation.
51-
For example: rejection sampling fails, invalid record threshold met.
59+
"""Errors during sampling or generation.
60+
61+
Examples: rejection sampling fails, invalid record threshold met.
5262
"""

src/nemo_safe_synthesizer/holdout/holdout.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""Train/test splitting for evaluation holdout.
5+
6+
Provides two splitting strategies -- naive (random) and grouped (preserving
7+
group membership) -- and a ``Holdout`` class that selects the appropriate
8+
strategy based on pipeline configuration.
9+
"""
10+
411
import pandas as pd
512
from sklearn.model_selection import GroupShuffleSplit, train_test_split
613

@@ -23,6 +30,20 @@
2330

2431

2532
def naive_train_test_split(df, test_size, random_state=None) -> DataFrameOptionalTuple:
33+
"""Split a dataframe into train and test sets with a random shuffle.
34+
35+
Thin wrapper around ``sklearn.model_selection.train_test_split`` that
36+
resets the index on both resulting dataframes.
37+
38+
Args:
39+
df: Input dataframe to split.
40+
test_size: Number of rows (int) or fraction (float) to hold out.
41+
random_state: Seed for reproducibility.
42+
43+
Returns:
44+
Tuple of ``(train_df, test_df)``, or ``(train_df, None)`` if the
45+
split produces no test set.
46+
"""
2647
train, test = train_test_split(df, test_size=test_size, random_state=random_state)
2748
if test is None:
2849
return train, None
@@ -31,6 +52,26 @@ def naive_train_test_split(df, test_size, random_state=None) -> DataFrameOptiona
3152

3253

3354
def grouped_train_test_split(df, test_size, group_by, random_state=None) -> DataFrameOptionalTuple:
55+
"""Split a dataframe so that all rows sharing a group stay in the same fold.
56+
57+
Uses ``GroupShuffleSplit`` with 20 candidate splits and picks the one
58+
whose test-set size is closest to the requested ``test_size``. If
59+
``test_size`` exceeds the number of groups, equals 0, or equals 1, it
60+
falls back to ``DEFAULT_HOLDOUT``.
61+
62+
Args:
63+
df: Input dataframe to split.
64+
test_size: Desired number of test rows (int) or fraction (float).
65+
group_by: Column name whose values define the groups.
66+
random_state: Seed for reproducibility.
67+
68+
Returns:
69+
Tuple of ``(train_df, test_df)``, or ``(df, None)`` if no valid
70+
grouped split could be produced.
71+
72+
Raises:
73+
ValueError: If the ``group_by`` column contains missing values.
74+
"""
3475
# Do not continue the split process if the groupby column has missing values.
3576
if df[group_by].isna().any():
3677
msg = f"Group by column '{group_by}' has missing values. Please remove/replace them."
@@ -62,13 +103,49 @@ def grouped_train_test_split(df, test_size, group_by, random_state=None) -> Data
62103

63104

64105
class Holdout:
106+
"""Config-driven train/test splitter for the evaluation holdout set.
107+
108+
Reads holdout parameters from the pipeline configuration and delegates
109+
to either ``naive_train_test_split`` or ``grouped_train_test_split``
110+
depending on whether a ``group_training_examples_by`` column is set.
111+
112+
The holdout size is resolved as follows:
113+
114+
- If ``holdout < 1.0``, it is treated as a fraction of the input rows.
115+
- If ``holdout >= 1.0``, it is treated as an absolute row count.
116+
- The result is clamped to ``max_holdout`` and must be at least
117+
``MIN_HOLDOUT`` rows.
118+
119+
Args:
120+
config: Pipeline parameters providing ``holdout``, ``max_holdout``,
121+
``group_training_examples_by``, and ``random_state``.
122+
"""
123+
65124
def __init__(self, config: SafeSynthesizerParameters):
66125
self.holdout = config.get("holdout")
67126
self.max_holdout = config.get("max_holdout")
68127
self.group_by = config.get("group_training_examples_by")
69128
self.random_state = config.get("random_state")
70129

71130
def train_test_split(self, input_df: pd.DataFrame) -> DataFrameOptionalTuple:
131+
"""Split the input dataframe into training and holdout test sets.
132+
133+
Returns the full dataframe with no test set when holdout is disabled
134+
(``holdout == 0`` or ``max_holdout == 0``).
135+
136+
Args:
137+
input_df: The full input dataframe to split.
138+
139+
Returns:
140+
Tuple of ``(train_df, test_df)``, or ``(train_df, None)`` when
141+
holdout is disabled or the grouped split fails.
142+
143+
Raises:
144+
ValueError: If the input dataframe has fewer than
145+
``MIN_RECORDS_FOR_TEXT_AND_PRIVACY_METRICS`` rows, if the
146+
computed holdout is smaller than ``MIN_HOLDOUT``, or if
147+
the ``group_by`` column contains missing values.
148+
"""
72149
if self.holdout == 0 or self.max_holdout == 0:
73150
return input_df, None
74151

0 commit comments

Comments
 (0)