Commit 851165f

fix(holdout): setting test size to 0 for holdout (#275)
<!-- SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
<!-- SPDX-License-Identifier: Apache-2.0 -->

<!-- Thank you for contributing to Safe Synthesizer! -->

# Summary

Added an early return in `naive_train_test_split` when `test_size == 0`, bypassing sklearn's `train_test_split`, which rejects 0 as an invalid value. Also added a test for `naive_train_test_split`.

## Pre-Review Checklist

<!-- These checks should be completed before a PR is reviewed, -->
<!-- but you can submit a draft early to indicate that the issue is being worked on. -->

Ensure that the following pass:

- [x] `make format && make check` or via prek validation.
- [x] `make test` passes locally
- [ ] `make test-e2e` passes locally
- [ ] `make test-ci-container` passes locally (recommended)
- [ ] GPU CI status check passes -- comment `/sync` on this PR to trigger a run (auto-triggers on ready-for-review)

## Pre-Merge Checklist

<!-- These checks need to be completed before a PR is merged, -->
<!-- but as PRs often change significantly during review, -->
<!-- it's OK for them to be incomplete when review is first requested. -->

- [ ] New or updated tests for any fix or new behavior
- [ ] Updated documentation for new features and behaviors, including docstrings for API docs.

## Other Notes

<!-- Please add the issue number that should be closed when this PR is merged. -->

- Closes #172

Signed-off-by: Sean Yang <seayang@nvidia.com>
1 parent 1554f66 commit 851165f

2 files changed

Lines changed: 10 additions & 0 deletions

src/nemo_safe_synthesizer/holdout/holdout.py

Lines changed: 3 additions & 0 deletions
@@ -44,6 +44,9 @@ def naive_train_test_split(df, test_size, random_state=None) -> DataFrameOptiona
         Tuple of ``(train_df, test_df)``, or ``(train_df, None)`` if the
         split produces no test set.
     """
+    if test_size == 0:
+        return df.reset_index(drop=True), None
+
     train, test = train_test_split(df, test_size=test_size, random_state=random_state)
     if test is None:
         return train, None
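
The early return in this hunk can be illustrated without pandas or sklearn. The sketch below is not the project's code: `strict_split` is a hypothetical stand-in for a splitter that, like sklearn's `train_test_split` as described in the summary, rejects a test size of 0, and `split_with_zero_guard` is a hypothetical wrapper that applies the same guard before delegating. Plain lists take the place of DataFrames, so there is no `reset_index` step.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple, Union


def strict_split(
    rows: Sequence[int], test_size: Union[int, float], seed: Optional[int] = None
) -> Tuple[List[int], List[int]]:
    """Hypothetical splitter that treats a test size of 0 as invalid."""
    n = len(rows)
    if isinstance(test_size, float):
        # Fractions must lie strictly between 0 and 1.
        if not 0.0 < test_size < 1.0:
            raise ValueError(f"test_size={test_size} must be in (0, 1)")
        n_test = math.ceil(test_size * n)
    else:
        # Absolute counts must be positive and smaller than the input.
        if not 0 < test_size < n:
            raise ValueError(f"test_size={test_size} must be in (0, {n})")
        n_test = test_size
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    return shuffled[n_test:], shuffled[:n_test]


def split_with_zero_guard(
    rows: Sequence[int], test_size: Union[int, float], seed: Optional[int] = None
) -> Tuple[List[int], Optional[List[int]]]:
    """Return all rows as the training set when no holdout is requested."""
    if test_size == 0:
        # Skip the strict splitter entirely; there is no test set.
        return list(rows), None
    return strict_split(rows, test_size, seed)
```

Because the guard compares with `== 0`, it covers both the integer `0` and the float `0.0`, and the caller receives `None` for the test set rather than an exception.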

tests/holdout/test_holdout.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,7 @@
     HOLDOUT_TOO_SMALL_ERROR,
     INPUT_DATA_TOO_SMALL_ERROR,
     Holdout,
+    naive_train_test_split,
 )


@@ -88,6 +89,12 @@ def test_zero_max_holdout(df):
     assert test is None


+def test_naive_train_test_split_zero_int_test_size(df):
+    train, test = naive_train_test_split(df, test_size=0)
+    assert len(train) == 200
+    assert test is None
+
+
 def test_does_group_by_holdout(df):
     holdout = Holdout(SafeSynthesizerParameters.from_params(group_training_examples_by="big_cat"))
     train, test = holdout.train_test_split(df)
