11# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4+ """Shared utilities for the data actions framework.
5+
6+ Provides ``ActionCtx`` (execution context with state and dependency injection),
7+ ``TransformsUtil`` (wrapper around the transforms_v2 engine), helper types
8+ (``MetadataColumns``, ``TransformsUpdate``), and subclass-discovery functions.
9+ """
10+
411from __future__ import annotations
512
613import inspect
3542
3643
def type_alias_fn(field_name: str) -> str:
    """Pydantic alias generator that maps ``type_`` to ``type`` for YAML compatibility.

    The Python objects use ``type_`` so the attribute does not conflict with
    the ``type()`` builtin, while the config YAML spells the key ``type``.

    Args:
        field_name: Name of the model attribute being aliased.

    Returns:
        ``"type"`` when ``field_name`` is ``"type_"``; otherwise ``field_name``
        unchanged.
    """
    if field_name == "type_":
        return "type"

    return field_name
4650
4751
4852class MetadataColumns (StrEnum ):
49- INDEX = "__gretel__idx" # used in validation to maintain a mapping to pre-transformed records
50- REJECT_REASON = (
51- "__gretel_reject_reason" # used in validation to attach model_metadata about why the row was rejected
52- )
53+ """Internal column names injected during validation phases."""
54+
55+ INDEX = "__nss__idx"
56+ """Temporary index for mapping back to pre-transformed records."""
57+
58+ REJECT_REASON = "__nss_reject_reason"
59+ """Reason a row was rejected during batch validation."""
5360
5461
5562def remove_metadata_columns_from_df (df : pd .DataFrame ):
63+ """Drop all ``MetadataColumns`` from the DataFrame in-place."""
5664 metadata_cols = [col .value for col in MetadataColumns ]
5765
5866 columns_to_drop = [col for col in metadata_cols if col in df .columns ]
@@ -63,6 +71,7 @@ def remove_metadata_columns_from_df(df: pd.DataFrame):
6371
6472
6573def remove_metadata_columns_from_records (records : list [dict ]) -> list [dict ]:
74+ """Return a copy of each record dict with ``MetadataColumns`` keys removed."""
6675 metadata_cols = [col .value for col in MetadataColumns ]
6776
6877 new_records : list [dict ] = []
@@ -73,20 +82,18 @@ def remove_metadata_columns_from_records(records: list[dict]) -> list[dict]:
7382
7483
class TransformsUpdate(BaseModel):
    """Typed wrapper for a single transforms_v2 update step.

    ``transforms_v2`` takes in untyped dicts; this model adds a little bit of
    structure for better validation.
    """

    name: str = Field(description="Target column name for the update.")
    value: str = Field(description="Jinja expression evaluated by the transforms_v2 engine.")
    position: Optional[int] = Field(
        default=None,
        description="Column insertion index when adding a new column.",
    )
8490
8591
8692class TransformsUtil :
87- """
88- Simple helper class to manage an instance of a TV2 `Environment` and some methods
89- to run `Step`s on input data.
93+ """Wrapper around a transforms_v2 ``Environment`` for executing column updates and drop conditions.
94+
95+ Args:
96+ seed: Random seed passed to the underlying ``Environment``.
9097 """
9198
9299 def __init__ (self , seed : Optional [int ] = None ) -> None :
@@ -148,15 +155,24 @@ def execute_drop_condition(self, batch: pd.DataFrame, conditions: list) -> pd.Da
148155
149156
150157class DataSource (BaseModel , ABC ):
158+ """Abstract base for pluggable data sources used by ``GenDataSource`` actions.
159+
160+ Subclasses implement ``generate_data`` to populate a column in an existing
161+ DataFrame. ``generate_records`` is a convenience wrapper that creates an
162+ empty DataFrame first.
163+ """
164+
151165 model_config = ConfigDict (alias_generator = type_alias_fn )
152166
153167 _ctx : ActionCtx = PrivateAttr ()
154168
155169 def with_ctx (self , ctx : ActionCtx ) -> Self :
170+ """Attach an ``ActionCtx`` and return self for chaining."""
156171 self ._ctx = ctx
157172 return self
158173
159174 def generate_records (self , num_records : int , col : str = "newcol" ) -> list [dict [Hashable , Any ]]:
175+ """Generate records as a list of dicts without an existing DataFrame."""
160176 df = pd .DataFrame (index = range (num_records ))
161177 return self .generate_data (df , col ).to_dict ("records" )
162178
@@ -191,16 +207,12 @@ def generate_data(self, df: pd.DataFrame, col: str = "newcol") -> pd.DataFrame:
191207
192208
def is_abstract(c: Any) -> bool:
    """Return True if the class has abstract methods or directly inherits ``ABC``.

    Two common ways of marking a class abstract are checked: it still has
    unimplemented ``@abstractmethod`` members (``inspect.isabstract``), or it
    lists ``ABC`` explicitly among its direct bases.

    Args:
        c: Object to inspect; normally a class.

    Returns:
        True when either check matches. Objects without ``__bases__``
        (i.e. non-classes) are reported as not abstract.
    """
    if inspect.isabstract(c):
        return True
    # ``c`` is typed ``Any``: fall back to an empty tuple so non-class inputs
    # yield False rather than raising AttributeError.
    return ABC in getattr(c, "__bases__", ())
200212
201213
202214def all_subclasses (klass : type [T ]) -> set [type [T ]]:
203- """Grab all of the recursive subclasses of `klass`."""
215+ """Recursively collect all subclasses of ``klass``."""
204216 subclasses : set [type [T ]] = set ()
205217 subclass_queue = [klass ]
206218 while subclass_queue :
@@ -213,23 +225,17 @@ def all_subclasses(klass: type[T]) -> set[type[T]]:
213225
214226
def concrete_subclasses(klass: type[T]) -> set[type[T]]:
    """Return all non-abstract recursive subclasses of ``klass``.

    Used by pydantic discriminated unions (e.g., ``ActionT``) to
    auto-discover instantiable action types for validation and schema
    generation. Some descendants of a very abstract parent may themselves be
    abstract, so those are filtered out.

    Args:
        klass: Parent class whose subclass tree is searched.

    Returns:
        The subclasses found by ``all_subclasses`` for which ``is_abstract``
        is False.
    """
    return {c for c in all_subclasses(klass) if not is_abstract(c)}
230235
231236
232237def guess_datetime_format (datetime_str : str ) -> Optional [str ]:
238+ """Infer a ``strftime``-compatible format string from a date string, or None."""
233239 # TODO: use `pandas.tseries.api.guess_datetime_format` in the future?
234240 format = parse_date (datetime_str )
235241 if format is None :
@@ -238,25 +244,17 @@ def guess_datetime_format(datetime_str: str) -> Optional[str]:
238244
239245
240246class ActionCtx (BaseModel ):
241- """
242- Context available during all action execution. This object
243- can be used for some state specific to the execution,
244- as well as dependency injection for external services in the future.
245- """
247+ """Execution context shared across all action invocations.
246248
247- seed : Optional [int ] = None
248- """
249- Seed used for all random generation tasks
249+ Provides a random seed, a state dictionary for cross-phase communication,
250+ and a lazily-initialized ``TransformsUtil`` for expression evaluation.
250251 """
251252
252- state : dict [str , str ] = {}
253- """
254- Used for tracking state across multiple action invocations.
255- This is important for actions which might have multiple functions
256- which need to remember information in latter invocations. For example,
257- a `postprocessing` function might benefit from information persisted
258- inside a `preprocessing` function.
259- """
253+ seed : Optional [int ] = Field (default = None , description = "Seed used for all random generation tasks." )
254+
255+ state : dict [str , str ] = Field (
256+ default = {}, description = "Per-action state persisted across phases (keyed by BaseAction.hash())."
257+ )
260258
261259 def __init__ (self , / , ** data : Any ) -> None :
262260 super ().__init__ (** data )
0 commit comments