Skip to content

Commit 386004f

Browse files
docs: add pydoc for data_processing (#208)
Ensure that the following pass: - [x] `make format && make check` or via prek validation. - [ ] `make test` passes locally - [ ] `make test-e2e` passes locally - [ ] `make test-ci-container` passes locally (recommended) ## Pre-Merge Checklist <!-- These checks need to be completed before a PR is merged, --> <!-- but as PRs often change significantly during review, --> <!-- it's OK for them to be incomplete when review is first requested. --> - [ ] New or updated tests for any fix or new behavior - [ ] Updated documentation for new features and behaviors, including docstrings for API docs. ## Other Notes <!-- Please add the issue number that should be closed when this PR is merged. --> - Closes #<issue> --------- Signed-off-by: Sean Yang <seayang@nvidia.com> Signed-off-by: seayang-nv <seayang@nvidia.com> Co-authored-by: Kendrick Boyd <kendrickb@nvidia.com>
1 parent fded14a commit 386004f

12 files changed

Lines changed: 672 additions & 342 deletions

File tree

src/nemo_safe_synthesizer/data_processing/actions/data_actions.py

Lines changed: 108 additions & 121 deletions
Large diffs are not rendered by default.

src/nemo_safe_synthesizer/data_processing/actions/dates.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""Date string parsing, formatting, and inference utilities.
5+
6+
Supports ISO8601 timezone offsets (via ``strftime_extra`` / ``strptime_extra``),
7+
permutation-based format inference (``parse_date``, ``infer_from_series``),
8+
and date randomization for PII replacement (``randomize``).
9+
"""
10+
411
import itertools
512
import re
613
from collections import Counter
@@ -126,15 +133,16 @@ def strptime_extra(date_string: str, fmt: str) -> datetime:
126133

127134

128135
def date_component_permutations() -> list[tuple[str, str, str, str, str]]:
129-
"""Returns a list of string formats by component type. Each permutation is
130-
indexed by y, m, d, hms, tz and can be passed into component formatter from
131-
``date_component_orders``.
136+
"""Return the Cartesian product of per-component format strings.
137+
138+
Each tuple is indexed by (year, month, day, hms, tz) and can be
139+
passed into a formatter from ``date_component_orders``.
132140
"""
133141
return list(itertools.product(*component_formats.values())) # type:ignore
134142

135143

136144
def gen_date_str_fmt_permutations() -> set[str]:
137-
"""Returns a list of unique date string format permutations"""
145+
"""Return the set of all unique date format permutations."""
138146
return {order(*str_fmt) for str_fmt in date_component_permutations() for order in date_component_orders}
139147

140148

@@ -284,6 +292,7 @@ def tokenize_date_str(input: str) -> TokenizedStr:
284292

285293

286294
def maybe_match(date, format) -> Optional[datetime]:
295+
"""Attempt to parse ``date`` with ``format``, returning None on failure."""
287296
try:
288297
return strptime_extra(date, format)
289298
except ValueError:
@@ -294,13 +303,15 @@ def parse_date(
294303
input_date: str,
295304
date_str_fmts: list[str] | set[str] = date_str_fmt_permutations,
296305
) -> Optional[ParsedDate]:
306+
"""Parse a date string and return the first matching ``ParsedDate``, or None."""
297307
return next(parse_date_multiple(input_date, date_str_fmts), None)
298308

299309

300310
def parse_date_multiple(
301311
input_date: str,
302312
date_str_fmts: list[str] | set[str] = date_str_fmt_permutations,
303313
) -> Iterator[ParsedDate]:
314+
"""Yield all valid ``ParsedDate`` interpretations of ``input_date`` across known formats."""
304315
tokenized_date = tokenize_date_str(input_date)
305316

306317
for str_fmt in date_str_fmts:
@@ -335,28 +346,31 @@ def randomize(date: str, days: int) -> Optional[str]:
335346

336347

337348
def d_str_to_fmt_multiple(input_date: str) -> Iterator[str]:
338-
"""Infers all likely date format from a date string."""
349+
"""Yield all plausible ``strftime`` format strings for a date string."""
339350
for parsed_date in parse_date_multiple(input_date):
340351
yield parsed_date.fmt_str
341352

342353

343354
def maybe_d_str_to_fmt_multiple(input_date: str) -> Iterator[str]:
344-
"""Infers all likely date format from a date string or nothing."""
355+
"""Like ``d_str_to_fmt_multiple`` but silently yields nothing on ``ValueError``."""
345356
try:
346357
yield from d_str_to_fmt_multiple(input_date)
347358
except ValueError:
348359
pass
349360

350361

351362
def d_str_to_fmt(input_date: str) -> Optional[str]:
352-
"""Infers a date format from a date string."""
363+
"""Infer the most likely ``strftime`` format string for a date string, or None."""
353364
return next(d_str_to_fmt_multiple(input_date), None)
354365

355366

356367
def infer_from_series(date_series: Iterable[str]) -> Optional[str]:
357-
"""An inference on a single date string isn't always perfect. Sometimes we mix
358-
up format likes %m and %d. ``infer_from_series`` will evaluate a series of dates
359-
and return the best date format for the series.
368+
"""Infer the best ``strftime`` format for a series of date strings.
369+
370+
Evaluates each date against all known format permutations and returns
371+
the most frequently matched format. This is more reliable than
372+
single-string inference, which can confuse ambiguous components like
373+
``%m`` and ``%d``.
360374
"""
361375
fmt_occurrences = Counter()
362376
for date in date_series:
@@ -371,6 +385,21 @@ def fit_and_transform_dates(
371385
df: pd.DataFrame,
372386
inplace: bool = False,
373387
) -> tuple[dict[str, dict[str, str]], pd.DataFrame]:
388+
"""Detect date columns, convert them to elapsed seconds, and record the transformation.
389+
390+
For each object-typed column, samples values to infer a date format. If
391+
successful, converts the column to seconds elapsed since the column minimum
392+
and records the format and min date for later reversal.
393+
394+
Args:
395+
df: Input DataFrame.
396+
inplace: If True, mutate ``df`` directly instead of copying.
397+
398+
Returns:
399+
A tuple of (date_min_dict, result_df). ``date_min_dict`` maps column
400+
names to ``{"format": ..., "min": ...}`` dicts needed by
401+
``transform_dates`` for reversal.
402+
"""
374403
date_min_dict = {}
375404
object_cols = [col for col, col_type in df.dtypes.iteritems() if col_type == "object"]
376405
result_df = df.copy() if not inplace else df
@@ -396,6 +425,16 @@ def fit_and_transform_dates(
396425

397426

398427
def transform_dates(dates: dict[str, dict[str, str]], df: pd.DataFrame) -> pd.DataFrame:
428+
"""Apply a previously fitted date-to-seconds transformation to a DataFrame.
429+
430+
Args:
431+
dates: Mapping from column names to ``{"format": ..., "min": ...}``
432+
dicts as returned by ``fit_and_transform_dates``.
433+
df: DataFrame to transform.
434+
435+
Returns:
436+
A copy of ``df`` with date columns converted to elapsed seconds.
437+
"""
399438
result_df = df.copy()
400439
for col, details in dates.items():
401440
_dates = pd.to_datetime(result_df[col], format=details["format"], errors="coerce")

src/nemo_safe_synthesizer/data_processing/actions/distributions.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""Statistical distribution models for sampling numeric and datetime values.
5+
6+
Provides ``Distribution`` (float-valued) and ``DatetimeDistribution``
7+
hierarchies, each with Gaussian and Uniform concrete implementations.
8+
Pydantic discriminated unions (``DistributionT``, ``DatetimeDistributionT``)
9+
allow YAML/JSON configs to select the distribution type via ``distribution_type``.
10+
"""
11+
412
from __future__ import annotations
513

614
from abc import ABC, abstractmethod
@@ -13,30 +21,24 @@
1321

1422

1523
class Distribution(BaseModel, ABC):
16-
"""
17-
Abstract base class representing a distribution.
18-
Child classes should specify whichever arguments are needed
19-
to properly parametrize their distribution.
24+
"""Abstract base for float-valued distributions.
25+
26+
Subclasses specify the parameters needed to define their distribution
27+
and implement ``sample`` to draw values.
2028
"""
2129

2230
@abstractmethod
2331
def sample(self, num_records: int) -> list[Any]: ...
2432

2533

2634
class DatetimeDistribution(BaseModel, ABC):
27-
"""
28-
This class is separate from the `Distribution` ABC above
29-
because datetimes need slightly different handling than floats.
30-
Providing this separate class hierarchy also makes it easier
31-
in pydantic to specify what datatypes we expect in the distribution
32-
parameters (float vs datetime), as well as dt-specific arguments.
33-
34-
In practice, this means creating a "copy" `DatetimeDistribution`
35-
for each regular `Distribution` where it makes sense. We could probably
36-
automate some of this with generics, but IMO that'd just make it confusing
37-
to read. We're still able to reuse the original `Distribution` class most
38-
of the time in `DatetimeDistribution`, making the only business logic
39-
really be about how we want to translate dates --> floats.
35+
"""Abstract base for datetime-valued distributions.
36+
37+
Separate from ``Distribution`` because datetime parameters (``datetime``,
38+
``timedelta``) differ from floats, and pydantic validation benefits from
39+
distinct type hierarchies. Subclasses implement ``sample_datetimes`` to
40+
produce raw datetime samples; universal post-processing (rounding via
41+
``precision``, formatting via ``format``) is applied by ``sample``.
4042
"""
4143

4244
precision: Optional[timedelta] = None

src/nemo_safe_synthesizer/data_processing/actions/utils.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
"""Shared utilities for the data actions framework.
5+
6+
Provides ``ActionCtx`` (execution context with state and dependency injection),
7+
``TransformsUtil`` (wrapper around the transforms_v2 engine), helper types
8+
(``MetadataColumns``, ``TransformsUpdate``), and subclass-discovery functions.
9+
"""
10+
411
from __future__ import annotations
512

613
import inspect
@@ -35,24 +42,25 @@
3542

3643

3744
def type_alias_fn(field_name: str) -> str:
38-
"""
39-
This alias fn allows `type_` to be parsed as `type` from config yaml. We use `type_`
40-
in the actual python objects so it doesn't conflict with the python builtin `type()`.
41-
"""
45+
"""Pydantic alias generator that maps ``type_`` to ``type`` for YAML compatibility."""
4246
if field_name == "type_":
4347
return "type"
4448

4549
return field_name
4650

4751

4852
class MetadataColumns(StrEnum):
49-
INDEX = "__gretel__idx" # used in validation to maintain a mapping to pre-transformed records
50-
REJECT_REASON = (
51-
"__gretel_reject_reason" # used in validation to attach model_metadata about why the row was rejected
52-
)
53+
"""Internal column names injected during validation phases."""
54+
55+
INDEX = "__nss__idx"
56+
"""Temporary index for mapping back to pre-transformed records."""
57+
58+
REJECT_REASON = "__nss_reject_reason"
59+
"""Reason a row was rejected during batch validation."""
5360

5461

5562
def remove_metadata_columns_from_df(df: pd.DataFrame):
63+
"""Drop all ``MetadataColumns`` from the DataFrame in-place."""
5664
metadata_cols = [col.value for col in MetadataColumns]
5765

5866
columns_to_drop = [col for col in metadata_cols if col in df.columns]
@@ -63,6 +71,7 @@ def remove_metadata_columns_from_df(df: pd.DataFrame):
6371

6472

6573
def remove_metadata_columns_from_records(records: list[dict]) -> list[dict]:
74+
"""Return a copy of each record dict with ``MetadataColumns`` keys removed."""
6675
metadata_cols = [col.value for col in MetadataColumns]
6776

6877
new_records: list[dict] = []
@@ -73,20 +82,18 @@ def remove_metadata_columns_from_records(records: list[dict]) -> list[dict]:
7382

7483

7584
class TransformsUpdate(BaseModel):
76-
"""
77-
`transforms_v2` takes in untyped `dicts`, but this model adds a little
78-
bit of structure for better validation.
79-
"""
85+
"""Typed wrapper for a single transforms_v2 update step."""
8086

81-
name: str
82-
value: str
83-
position: Optional[int] = None
87+
name: str = Field(description="Target column name for the update.")
88+
value: str = Field(description="Jinja expression evaluated by the transforms_v2 engine.")
89+
position: Optional[int] = Field(default=None, description="Column insertion index when adding a new column.")
8490

8591

8692
class TransformsUtil:
87-
"""
88-
Simple helper class to manage an instance of a TV2 `Environment` and some methods
89-
to run `Step`s on input data.
93+
"""Wrapper around a transforms_v2 ``Environment`` for executing column updates and drop conditions.
94+
95+
Args:
96+
seed: Random seed passed to the underlying ``Environment``.
9097
"""
9198

9299
def __init__(self, seed: Optional[int] = None) -> None:
@@ -148,15 +155,24 @@ def execute_drop_condition(self, batch: pd.DataFrame, conditions: list) -> pd.Da
148155

149156

150157
class DataSource(BaseModel, ABC):
158+
"""Abstract base for pluggable data sources used by ``GenDataSource`` actions.
159+
160+
Subclasses implement ``generate_data`` to populate a column in an existing
161+
DataFrame. ``generate_records`` is a convenience wrapper that creates an
162+
empty DataFrame first.
163+
"""
164+
151165
model_config = ConfigDict(alias_generator=type_alias_fn)
152166

153167
_ctx: ActionCtx = PrivateAttr()
154168

155169
def with_ctx(self, ctx: ActionCtx) -> Self:
170+
"""Attach an ``ActionCtx`` and return self for chaining."""
156171
self._ctx = ctx
157172
return self
158173

159174
def generate_records(self, num_records: int, col: str = "newcol") -> list[dict[Hashable, Any]]:
175+
"""Generate records as a list of dicts without an existing DataFrame."""
160176
df = pd.DataFrame(index=range(num_records))
161177
return self.generate_data(df, col).to_dict("records")
162178

@@ -191,16 +207,12 @@ def generate_data(self, df: pd.DataFrame, col: str = "newcol") -> pd.DataFrame:
191207

192208

193209
def is_abstract(c: Any) -> bool:
194-
"""
195-
This checks the two common ways that classes indicate themselves
196-
as abstract; they either have `@abstractmethod`s, or they explicitly
197-
inherit from `ABC` (or the metaclass). This checks both of these.
198-
"""
210+
"""Return True if the class has abstract methods or directly inherits ``ABC``."""
199211
return inspect.isabstract(c) or ABC in c.__bases__
200212

201213

202214
def all_subclasses(klass: type[T]) -> set[type[T]]:
203-
"""Grab all of the recursive subclasses of `klass`."""
215+
"""Recursively collect all subclasses of ``klass``."""
204216
subclasses: set[type[T]] = set()
205217
subclass_queue = [klass]
206218
while subclass_queue:
@@ -213,23 +225,17 @@ def all_subclasses(klass: type[T]) -> set[type[T]]:
213225

214226

215227
def concrete_subclasses(klass: type[T]) -> set[type[T]]:
216-
"""
217-
Find all the subclasses of `klass`, then filter out the abstract
218-
subclasses.
219-
220-
This is useful for passing in a very abstract parent class
221-
like `BaseAction`, and finding all of the potential children
222-
of that `klass`. Some of these children themselves might be abstract,
223-
so we should filter those out.
228+
"""Return all non-abstract recursive subclasses of ``klass``.
224229
225-
This function is likely used to feed information to `pydantic` about
226-
which potential concrete classes exist for purposes of validation and
227-
schema generation.
230+
Used by pydantic discriminated unions (e.g., ``ActionT``) to
231+
auto-discover instantiable action types for validation and schema
232+
generation.
228233
"""
229234
return set(c for c in all_subclasses(klass) if not is_abstract(c))
230235

231236

232237
def guess_datetime_format(datetime_str: str) -> Optional[str]:
238+
"""Infer a ``strftime``-compatible format string from a date string, or None."""
233239
# TODO: use `pandas.tseries.api.guess_datetime_format` in the future?
234240
format = parse_date(datetime_str)
235241
if format is None:
@@ -238,25 +244,17 @@ def guess_datetime_format(datetime_str: str) -> Optional[str]:
238244

239245

240246
class ActionCtx(BaseModel):
241-
"""
242-
Context available during all action execution. This object
243-
can be used for some state specific to the execution,
244-
as well as dependency injection for external services in the future.
245-
"""
247+
"""Execution context shared across all action invocations.
246248
247-
seed: Optional[int] = None
248-
"""
249-
Seed used for all random generation tasks
249+
Provides a random seed, a state dictionary for cross-phase communication,
250+
and a lazily-initialized ``TransformsUtil`` for expression evaluation.
250251
"""
251252

252-
state: dict[str, str] = {}
253-
"""
254-
Used for tracking state across multiple action invocations.
255-
This is important for actions which might have multiple functions
256-
which need to remember information in latter invocations. For example,
257-
a `postprocessing` function might benefit from information persisted
258-
inside a `preprocessing` function.
259-
"""
253+
seed: Optional[int] = Field(default=None, description="Seed used for all random generation tasks.")
254+
255+
state: dict[str, str] = Field(
256+
default={}, description="Per-action state persisted across phases (keyed by BaseAction.hash())."
257+
)
260258

261259
def __init__(self, /, **data: Any) -> None:
262260
super().__init__(**data)

0 commit comments

Comments
 (0)