Skip to content

Commit 9f38ece

Browse files
authored
docs(eval): add docstrings for evaluation code (#148)
Added/Updated docstrings for the evaluation/ path. Had to make some minor fixes to pass lint checking. Signed-off-by: Nina Xu <19981858+nina-xu@users.noreply.github.com>
1 parent e8d56d0 commit 9f38ece

25 files changed

Lines changed: 891 additions & 422 deletions

src/nemo_safe_synthesizer/evaluation/assets/text/multi_modal_tooltips.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
# ruff: noqa
55

6+
"""Tooltip text displayed in the multi-modal HTML evaluation report."""
7+
68
tooltips = {
79
"dataset_statistics_info": """
810
The dataset statistics provide a summary of the datasets. The table includes the number of rows and columns,

src/nemo_safe_synthesizer/evaluation/components/attribute_inference_protection.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,27 @@
4141

4242

4343
class AttributeInferenceProtection(Component):
44+
"""Attribute Inference Protection privacy metric.
45+
46+
Simulates an attribute inference attack: given quasi-identifier columns,
47+
can an adversary use synthetic nearest-neighbors to predict the remaining
48+
attributes of a training record? A higher score indicates better
49+
protection (lower prediction accuracy).
50+
51+
See Also:
52+
https://arxiv.org/abs/2501.03941 -- Synthetic Data Privacy Metrics.
53+
"""
54+
4455
name: str = Field(default="Attribute Inference Protection")
45-
col_accuracy_df: pd.DataFrame | None = Field(default=None)
56+
col_accuracy_df: pd.DataFrame | None = Field(
57+
default=None, description="Per-column prediction risk scores and grades."
58+
)
4659

4760
model_config = ConfigDict(arbitrary_types_allowed=True)
4861

4962
@cached_property
5063
def jinja_context(self) -> dict[str, str]:
64+
"""Template context with the attribute-inference bar chart figure."""
5165
d = super().jinja_context
5266
d["anchor_link"] = "#aia"
5367
if self.col_accuracy_df is not None and not self.col_accuracy_df.empty:
@@ -62,6 +76,7 @@ def jinja_context(self) -> dict[str, str]:
6276
def from_evaluation_dataset(
6377
evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
6478
) -> AttributeInferenceProtection:
79+
"""Run the attribute inference attack and return the protection score."""
6580
if not faiss_available:
6681
logger.info("FAISS is not available, skipping Attribute Inference Attack.")
6782
return AttributeInferenceProtection(score=EvaluationScore())
@@ -180,7 +195,7 @@ def _is_really_categorical(column: str) -> bool:
180195

181196
@staticmethod
182197
def _divide_tabular_text(df: pd.DataFrame, text_fields: list) -> tuple[pd.DataFrame, pd.DataFrame]:
183-
"""Takes a dataframe and divides it into two dataframes, one with the text fields and one with the tabular fields"""
198+
"""Split a dataframe into tabular-only and text-only subsets."""
184199
tabular_fields = []
185200
for col in df.columns:
186201
if col not in text_fields:
@@ -192,9 +207,7 @@ def _divide_tabular_text(df: pd.DataFrame, text_fields: list) -> tuple[pd.DataFr
192207

193208
@staticmethod
194209
def _embed_text(df: pd.DataFrame, embedder) -> pd.DataFrame:
195-
"""Takes a dataframe of text fields, finds the embeddings for each
196-
and then averages the embeddings into one embedding and returns a dataframe with just that
197-
"""
210+
"""Embed each text column and average into a single embedding per row."""
198211
embeddings = {}
199212
for col in df.columns:
200213
data = df[col].to_list()
@@ -267,7 +280,7 @@ def _get_synth_nn(
267280
if len(text_columns) == 0:
268281
# Create the faiss index on the synthetic data
269282
dim = df_synth_norm.shape[1]
270-
index = faiss.IndexFlatL2(dim) # ty: ignore[unresolved-attribute, possibly-unbound-attribute]
283+
index = faiss.IndexFlatL2(dim) # ty: ignore[possibly-unbound-attribute]
271284

272285
# This usage matches documentation. Specifying n= and x= parameters as
273286
# the type annotation for IndexFlatL2.add suggests seems unnecessary, possibly related
@@ -288,15 +301,15 @@ def _get_synth_nn(
288301
df_train_embeddings = AttributeInferenceProtection._embed_text(df_train_text, embedder)
289302
df_synth_embeddings = AttributeInferenceProtection._embed_text(df_synth_text, embedder)
290303
hits = util.semantic_search(
291-
np.array(list(df_train_embeddings["embedding"])),
292-
np.array(list(df_synth_embeddings["embedding"])),
304+
np.array(list(df_train_embeddings["embedding"])), # ty: ignore[invalid-argument-type]
305+
np.array(list(df_synth_embeddings["embedding"])), # ty: ignore[invalid-argument-type]
293306
top_k=k,
294307
)
295308
synth_rows = pd.DataFrame()
296309
for i in range(k):
297310
corpus_id = hits[0][i]["corpus_id"]
298311
synth_rows = pd.concat(
299-
[synth_rows, pd.DataFrame([df_synth.iloc[corpus_id]])],
312+
[synth_rows, pd.DataFrame([df_synth.iloc[int(corpus_id)]])],
300313
ignore_index=True,
301314
)
302315

@@ -310,8 +323,8 @@ def _get_synth_nn(
310323
df_synth_embeddings = AttributeInferenceProtection._embed_text(df_synth_text, embedder)
311324
search_synth_k = min(1000, len(df_synth_embeddings))
312325
hits = util.semantic_search(
313-
np.array(list(df_train_embeddings["embedding"])),
314-
np.array(list(df_synth_embeddings["embedding"])),
326+
np.array(list(df_train_embeddings["embedding"])), # ty: ignore[invalid-argument-type]
327+
np.array(list(df_synth_embeddings["embedding"])), # ty: ignore[invalid-argument-type]
315328
top_k=search_synth_k,
316329
)
317330
synth_NN = pd.DataFrame()
@@ -324,12 +337,12 @@ def _get_synth_nn(
324337
dist = 1 - sim
325338
text_dist[i] = dist
326339
corpus_ids.append(corpus_id)
327-
synth_NN = pd.concat([synth_NN, pd.DataFrame([df_synth_norm.iloc[corpus_id]])], ignore_index=True)
340+
synth_NN = pd.concat([synth_NN, pd.DataFrame([df_synth_norm.iloc[int(corpus_id)]])], ignore_index=True)
328341

329342
# Now get the tabular similarity for these 1000 NN
330343

331344
dim = synth_NN.shape[1]
332-
index = faiss.IndexFlatL2(dim) # ty: ignore[unresolved-attribute, possibly-unbound-attribute]
345+
index = faiss.IndexFlatL2(dim) # ty: ignore[possibly-unbound-attribute]
333346
index.add(np.float32(np.ascontiguousarray(np.array(synth_NN)))) # ty: ignore[missing-argument]
334347
dists, indexes = index.search(np.float32(np.ascontiguousarray(np.array(df_train_norm))), search_synth_k) # ty: ignore[missing-argument]
335348
# Scale the Euclidean distance to [0,1]
@@ -372,6 +385,20 @@ def _aia(
372385
df_synth: pd.DataFrame,
373386
quasi_identifier_count: int,
374387
) -> tuple[EvaluationScore, pd.DataFrame | None]:
388+
"""Core attribute inference attack implementation.
389+
390+
Iterates over random quasi-identifier subsets, finds nearest
391+
synthetic neighbors, and measures attribute prediction accuracy
392+
weighted by column entropy.
393+
394+
Args:
395+
df_train: Training dataframe.
396+
df_synth: Synthetic dataframe.
397+
quasi_identifier_count: Number of columns to use as quasi-identifiers.
398+
399+
Returns:
400+
Tuple of (overall protection score, per-column accuracy dataframe).
401+
"""
375402
ias = EvaluationScore(grade=PrivacyGrade.UNAVAILABLE)
376403
col_accuracy_df = None
377404
if quasi_identifier_count is None:
@@ -408,7 +435,7 @@ def _aia(
408435
nominal_columns = list(df_train.select_dtypes(include=["object", "category", "bool"]).columns)
409436
numeric_columns = [column for column in df_train.columns if column not in nominal_columns]
410437

411-
# Now seperate out the text columns from the nominal
438+
# Now separate out the text columns from the nominal
412439

413440
text_columns = []
414441
for col in nominal_columns:
@@ -531,7 +558,7 @@ def _aia(
531558
# Lat/lon values inspired this. Text must be dist .35 or less
532559
for column in predict_columns:
533560
synth_val = synth_values[column]
534-
train_val = train_row_all.iloc[0][column]
561+
train_val = train_row_all.iloc[0][column] # ty: ignore[invalid-argument-type]
535562

536563
if pd.isna(train_val):
537564
continue

src/nemo_safe_synthesizer/evaluation/components/column_distribution.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@
2323

2424

2525
class ColumnDistributionPlotRow(BaseModel):
26-
name1: str = Field()
27-
name2: str | None = Field()
28-
figure: str = Field()
26+
"""A pair of side-by-side column distribution plots for the HTML report."""
27+
28+
name1: str = Field(description="Name of the first column in the plot row.")
29+
name2: str | None = Field(description="Name of the second column in the plot row, if present.")
30+
figure: str = Field(description="Rendered HTML of the side-by-side distribution plot.")
2931

3032
@staticmethod
3133
def _get_figure_for_field(f: EvaluationField | None, reference: pd.Series, output) -> Figure | None:
@@ -88,21 +90,26 @@ def from_evaluation_dataset(evaluation_dataset: EvaluationDataset) -> list[dict[
8890

8991

9092
class ColumnDistribution(Component):
91-
"""
92-
This class wears a few hats, not ideal but saves some duplication:
93-
* Rendering of each EvaluationFields histogram
94-
* Rendering of Reference Columns table
95-
* Computation/rendering of Column Distribution Stability score
96-
* Field Distribution Stability functions are used for text metrics and (iirc) PCA as well
93+
"""Column Distribution Stability metric.
94+
95+
Computes per-column Jensen-Shannon divergence between reference and
96+
output distributions, averages across all tabular columns, and maps
97+
the result to a 0--10 score. Also carries data for the per-column
98+
histogram figures and the Reference Columns table in the HTML report.
9799
"""
98100

99101
name: str = Field(default="Column Distribution Stability")
100102
# Keep a copy to simplify rendering
101-
column_statistics: dict[str, ColumnStatistics] | None = Field(default=None)
102-
evaluation_fields: list[EvaluationField] = Field(default=list())
103+
column_statistics: dict[str, ColumnStatistics] | None = Field(
104+
default=None, description="Per-column PII entity and transform metadata."
105+
)
106+
evaluation_fields: list[EvaluationField] = Field(
107+
default=list(), description="Per-column evaluation metadata and distribution scores."
108+
)
103109

104110
@cached_property
105111
def jinja_context(self):
112+
"""Template context with evaluation fields and column statistics for the report."""
106113
d = super().jinja_context
107114
d["anchor_link"] = "#distribution-stability"
108115
if self.evaluation_fields:
@@ -117,6 +124,7 @@ def jinja_context(self):
117124
def from_evaluation_dataset(
118125
evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
119126
) -> ColumnDistribution:
127+
"""Compute column distribution stability from the evaluation dataset."""
120128
tabular_columns = set(evaluation_dataset.get_tabular_columns())
121129
tabular_fields = [f for f in evaluation_dataset.evaluation_fields if f.name in tabular_columns]
122130
if tabular_fields:

src/nemo_safe_synthesizer/evaluation/components/component.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,45 @@
1717

1818

1919
class Component(ABC, BaseModel):
20-
name: str = Field(
21-
description="Override this with the fancy display name of your component. It is used for json summaries and rendering scores."
20+
"""Abstract base for all evaluation components.
21+
22+
Each component computes one quality or privacy metric from an
23+
``EvaluationDataset`` and exposes a ``jinja_context`` property
24+
for HTML report rendering.
25+
26+
Subclasses should override ``from_evaluation_dataset`` to perform
27+
their metric-specific computation.
28+
"""
29+
30+
name: str = Field(description="Display name used in JSON summaries and the HTML report.")
31+
score: EvaluationScore = Field(
32+
default=EvaluationScore(), description="The computed EvaluationScore for this component."
2233
)
23-
score: EvaluationScore = Field(default=EvaluationScore())
2434

2535
@staticmethod
2636
def from_evaluation_dataset(
2737
evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
2838
) -> Component:
29-
return Component()
39+
"""Create a component from an ``EvaluationDataset``.
40+
41+
Subclasses override this to compute their specific metric.
42+
43+
Args:
44+
evaluation_dataset: Paired reference/output data.
45+
config: Optional pipeline configuration parameters.
46+
47+
Returns:
48+
A new component instance with computed scores.
49+
"""
50+
return Component() # ty: ignore[missing-argument]
3051

3152
def get_json(self) -> str:
53+
"""Serialize the component score to a JSON string."""
3254
return self.score.model_dump_json()
3355

3456
@cached_property
3557
def jinja_context(self) -> dict[str, Any]:
58+
"""Template context dict for Jinja2 rendering, keyed by name, score, and figure HTML."""
3659
# Dict values are typed as "Any" but err on the side of primitives (html strings, not plotly.Figure e.g.).
3760
# Prepping up front saves formatting logic inlined in templates.
3861
d = dict()
@@ -45,7 +68,7 @@ def jinja_context(self) -> dict[str, Any]:
4568

4669
@staticmethod
4770
def is_nonempty(dfs: None | pd.DataFrame | list[pd.DataFrame | None]) -> bool:
48-
"""Util for components that need to check dataframes before attempting to render (correlation and PCA)"""
71+
"""Return ``True`` if all provided DataFrames are non-``None`` and non-empty."""
4972
if dfs is None:
5073
return False
5174
if isinstance(dfs, pd.DataFrame):

src/nemo_safe_synthesizer/evaluation/components/composite_score.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,14 @@
1111

1212

1313
class CompositeScore(Component):
14+
"""A component whose score is the mean of its child component scores.
15+
16+
Used as the base for aggregate metrics like SQS and Data Privacy Score.
17+
"""
18+
1419
@cached_property
1520
def jinja_context(self):
21+
"""Template context with duplicate gauge figures for overview and detail sections."""
1622
d = super().jinja_context
1723
# This is some "plotly magic." The figure is a div with an id and an inlined script.
1824
# If you attempt to reuse the figure (we do), it won't render for the second one.
@@ -22,14 +28,15 @@ def jinja_context(self):
2228

2329
@staticmethod
2430
def from_components(components: list[Component] | Component, name: str) -> CompositeScore:
31+
"""Compute a composite score as the mean of child component scores."""
2532
if isinstance(components, Component):
26-
return CompositeScore(score=components.score)
33+
return CompositeScore(score=components.score, name=name)
2734
if (
2835
components is None
2936
or len(components) == 0
3037
or all([True for c in components if c.score is None or c.score.score is None])
3138
):
32-
return CompositeScore(score=EvaluationScore())
39+
return CompositeScore(score=EvaluationScore(), name=name)
3340

3441
# Take the mean
3542
total = 0.0

0 commit comments

Comments
 (0)