Skip to content

Commit 27877c6

Browse files
nina-xuNina Xu
andauthored
perf(config): change default max_sequences_per_example to 10 when dp is off (#78)
<!-- Thank you for contributing to Safe Synthesizer! --> # Summary Change the default max_sequences_per_example to 10 when dp is off. This is from our experiment findings that this gives good performance and decent runtime. ## Pre-Review Checklist <!-- These checks should be completed before a PR is reviewed, --> <!-- but you can submit a draft early to indicate that the issue is being worked on. --> Ensure that the following pass: - [x] `make format && make lint` or via prek validation. - [x] `make test` passes locally - [ ] `make e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [x] New or updated tests for any fix or new behavior - [x] Updated documentation for new features and behaviors, including docstrings for API docs. ## Other Notes <!-- Please add the issue number that should be closed when this PR is merged. --> - Closes #73 --------- Signed-off-by: Nina Xu <ninaxu@cs-oci-ord-vscode-01.cm.cluster> Co-authored-by: Nina Xu <ninaxu@cs-oci-ord-vscode-01.cm.cluster>
1 parent ab128f1 commit 27877c6

4 files changed

Lines changed: 47 additions & 46 deletions

File tree

src/nemo_safe_synthesizer/config/autoconfig.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import math
1212
import re
1313
from dataclasses import dataclass
14-
from typing import TYPE_CHECKING, Any, Callable, Literal
14+
from typing import TYPE_CHECKING, Any, Callable
1515

1616
import pandas as pd
1717
from pydantic import GetCoreSchemaHandler
@@ -228,24 +228,39 @@ def _determine_delta(self) -> dict[str, float]:
228228
)
229229
return {"delta": d}
230230

231-
def _determine_max_sequences_per_example(self) -> dict[str, Literal["auto"] | int | None]:
231+
def _determine_max_sequences_per_example(self) -> dict[str, int | None]:
232232
"""
233233
Determine max_sequences_per_example if set to auto.
234234
235235
Returns:
236-
Dict with max_sequences_per_example if auto-determined, empty dict otherwise.
236+
Dict with max_sequences_per_example resolved to a concrete value:
237+
1 if DP is enabled, 10 if auto with DP disabled, or the
238+
explicit value (int) if manually specified, or None if not specified.
237239
"""
238240
if self._dp_enabled is True:
239-
logger.info(
240-
"Parameter `max_sequences_per_example` was automatically set "
241-
"to 1 based on the use of differential privacy."
242-
)
241+
if self._config.data.max_sequences_per_example in [None, AUTO_STR, 1]:
242+
logger.info(
243+
"Parameter `max_sequences_per_example` was automatically set "
244+
"to 1 based on the use of differential privacy."
245+
)
246+
else:
247+
logger.info(
248+
"Parameter `max_sequences_per_example` does not allow the value of "
249+
"{self._config.data.max_sequences_per_example} when DP is enabled. Setting to 1 instead."
250+
)
243251
return {"max_sequences_per_example": 1}
244252
elif self._config.data.max_sequences_per_example != AUTO_STR:
253+
if self._config.data.max_sequences_per_example is None:
254+
logger.info(
255+
"Parameter `max_sequences_per_example` is not specified, so each example will fill up the context window."
256+
)
245257
return {"max_sequences_per_example": self._config.data.max_sequences_per_example}
246258

247259
else:
248-
return {"max_sequences_per_example": None}
260+
logger.info(
261+
"Parameter `max_sequences_per_example` was automatically set to 10 for best performance/efficiency."
262+
)
263+
return {"max_sequences_per_example": 10}
249264

250265
def _build_updated_params(
251266
self,
@@ -303,7 +318,7 @@ def resolve(self) -> SafeSynthesizerParameters:
303318
class AutoParamsValidator:
304319
value_func: Callable[[Any], bool]
305320

306-
def validate(self, value):
321+
def validate(self, value, _info):
307322
if isinstance(value, str) and value == "auto":
308323
return value
309324
elif self.value_func(value):

src/nemo_safe_synthesizer/config/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ class DataParameters(Parameters):
7676
OptionalAutoInt,
7777
Field(
7878
description=(
79-
"If specified, adds at most this number of sequences per example; "
80-
"otherwise, fills up context. Supports 'auto' where a value of 1 is "
81-
"chosen if differential privacy is enabled, and None otherwise. "
82-
"Required for DP to limit contribution of each example."
79+
"If specified, adds at most this number of sequences per example. "
80+
"Supports 'auto' where a value of 1 is chosen if differential privacy is "
81+
"enabled, and 10 otherwise. If not specified or set to 'auto', fills up "
82+
"context. Required for DP to limit contribution of each example."
8383
),
8484
),
8585
] = AUTO_STR

src/nemo_safe_synthesizer/config/parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def check_dp_compatibility(
8888

8989
if not dp_params.dp_enabled:
9090
if data.max_sequences_per_example is not None and data.max_sequences_per_example == AUTO_STR:
91-
logger.debug("setting max_sequences_per_example to None because DP is disabled")
92-
data.max_sequences_per_example = None
91+
logger.debug("setting max_sequences_per_example to the default of 10 because DP is disabled")
92+
data.max_sequences_per_example = 10
9393
return dp_params
9494

9595
match data.max_sequences_per_example:

tests/config/test_autoconfig.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def get_config(self) -> SafeSynthesizerParameters:
109109
num_input_records_to_sample=None, # Will be auto-resolved
110110
delta=None, # Not used (DP disabled)
111111
dp_enabled=False,
112-
max_seq=None, # "auto" with no DP -> None
112+
max_seq=10, # "auto" with no DP -> 10
113113
),
114114
)
115115

@@ -134,27 +134,6 @@ def get_config(self) -> SafeSynthesizerParameters:
134134
),
135135
)
136136

137-
AUTO_WITH_DP_NULL_MAX_SEQ = AutoConfigTestCase(
138-
name="auto_with_dp_null_max_seq",
139-
config=SafeSynthesizerParameters(
140-
training=TrainingHyperparams(
141-
rope_scaling_factor="auto",
142-
num_input_records_to_sample="auto",
143-
use_unsloth="auto",
144-
),
145-
data=DataParameters(max_sequences_per_example=None), # Explicit None
146-
privacy=DifferentialPrivacyHyperparams(dp_enabled=True, delta="auto"),
147-
),
148-
expected=Expected(
149-
use_unsloth=False, # "auto" resolves to False when DP enabled
150-
rope_scaling_factor=None, # Will be auto-resolved to an int
151-
num_input_records_to_sample=None, # Will be auto-resolved
152-
delta=None, # Will be auto-resolved based on data size
153-
dp_enabled=True,
154-
max_seq=1, # DP enabled -> always 1 (even with None input)
155-
),
156-
)
157-
158137
EXPLICIT = AutoConfigTestCase(
159138
name="explicit",
160139
config=SafeSynthesizerParameters(
@@ -205,7 +184,6 @@ def get_config(self) -> SafeSynthesizerParameters:
205184
ALL_TEST_CASES: list[AutoConfigTestCase] = [
206185
AUTO_NO_DP,
207186
AUTO_WITH_DP,
208-
AUTO_WITH_DP_NULL_MAX_SEQ,
209187
EXPLICIT,
210188
DP_WITH_UNSLOTH_TRUE,
211189
]
@@ -321,15 +299,23 @@ def test_determine_delta(self, data_size, test_case, config, expected):
321299
else:
322300
assert result == {}
323301

324-
def test_determine_max_sequences_per_example(self, sample_data, config, expected):
325-
"""Max sequences should be 1 for DP, None for non-DP auto, or explicit value."""
326-
resolver = AutoConfigResolver(sample_data, config)
327-
328-
# Verify validation already resolved the config value
329-
assert config.data.max_sequences_per_example == expected.max_seq
330-
302+
@pytest.mark.parametrize(
303+
"max_seq_input, expected_max_seq",
304+
[
305+
pytest.param("auto", [1, 10], id="auto_max_seq"), # dp_enabled=True -> 1, dp_enabled=False -> 10
306+
pytest.param(5, [1, 5], id="explicit_max_seq"), # dp_enabled=True -> 1, dp_enabled=False -> 5
307+
pytest.param(None, [1, None], id="none_max_seq"), # dp_enabled=True -> 1, dp_enabled=False -> None
308+
],
309+
)
310+
def test_determine_max_sequences_per_example(self, sample_data, config, max_seq_input, expected_max_seq):
311+
"""Max sequences should be 1 for DP regardless of input; for non-DP, auto -> 10, explicit -> explicit value, None -> None."""
312+
config_copy = config.model_copy(deep=True)
313+
config_copy.data.max_sequences_per_example = max_seq_input
314+
resolver = AutoConfigResolver(sample_data, config_copy)
331315
result = resolver._determine_max_sequences_per_example()
332-
assert result == {"max_sequences_per_example": expected.max_seq}
316+
317+
expected_index = 0 if config_copy.privacy.dp_enabled else 1
318+
assert result == {"max_sequences_per_example": expected_max_seq[expected_index]}
333319

334320
def test_resolve(self, sample_data, config, expected):
335321
"""Full resolution should produce valid SafeSynthesizerParameters.

0 commit comments

Comments
 (0)