Skip to content

Commit f31036e

Browse files
anilkramfacebook-github-bot
authored andcommitted
Adding implementation for edit_similarity_node. (#94)
Summary: Pull Request resolved: #94 Differential Revision: D87895834
1 parent 04145aa commit f31036e

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

privacy_guard/analysis/extraction/edit_similarity_node.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
# pyre-strict
1313

1414
from dataclasses import dataclass
15-
from typing import Optional
15+
from typing import cast, Optional
1616

1717
import pandas as pd
18+
import textdistance
1819
from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode
1920
from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput
2021
from privacy_guard.analysis.extraction.text_inclusion_analysis_input import (
2122
TextInclusionAnalysisInput,
2223
)
24+
from privacy_guard.analysis.extraction.text_inclusion_analysis_node import (
25+
_clean_text,
26+
_normalize_by_target_len,
27+
)
2328

2429
from tqdm import tqdm
2530

@@ -59,12 +64,47 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:
5964
self.prompt_key: str = analysis_input.prompt_key
6065
self.generation_key: str = analysis_input.generation_key
6166
self.target_key: str = analysis_input.target_key
67+
self.generation_df: pd.DataFrame = analysis_input.generation_df
6268
super().__init__(analysis_input=analysis_input)
6369

70+
def _compute_edit_similarity(
71+
self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
72+
) -> int:
73+
"""Compute edit similarity between target and generation text. Texts are cleaned first.
74+
Currently not supported for multi target mode.
75+
76+
Args:
77+
row (pd.Series): A row of a DataFrame containing the s1 and s2 columns.
78+
79+
Returns:
80+
int: Edit similarity between the two strings.
81+
"""
82+
s1 = _clean_text(row[s1_column or self.target_key])
83+
s2 = _clean_text(row[s2_column or self.generation_key])
84+
levenshtein = textdistance.levenshtein.similarity(s1, s2)
85+
return levenshtein
86+
6487
def run_analysis(self) -> EditSimilarityNodeOutput:
65-
return EditSimilarityNodeOutput(
66-
num_samples=0,
88+
analysis_input: TextInclusionAnalysisInput = cast(
89+
TextInclusionAnalysisInput, self.analysis_input
90+
)
91+
generation_df = analysis_input.generation_df
92+
93+
outputs = EditSimilarityNodeOutput(
94+
num_samples=len(generation_df),
6795
edit_similarity=None,
6896
edit_similarity_score=None,
69-
augmented_output_dataset=pd.DataFrame(),
97+
augmented_output_dataset=generation_df,
98+
)
99+
100+
generation_df["edit_similarity"] = generation_df.progress_apply(
101+
self._compute_edit_similarity, axis=1
70102
)
103+
generation_df["edit_similarity_score"] = _normalize_by_target_len(
104+
generation_df["edit_similarity"], generation_df["target"]
105+
)
106+
107+
outputs.edit_similarity = generation_df["edit_similarity"]
108+
outputs.edit_similarity_score = generation_df["edit_similarity_score"]
109+
110+
return outputs

privacy_guard/analysis/tests/test_edit_similarity_node.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,11 @@ def test_edit_similarity_node_output_creation(self) -> None:
7474
self.assertEqual(output.augmented_output_dataset.equals(generation_df), True)
7575

7676
def test_text_inclusion_edit_similarity(self) -> None:
77-
analysis_node = EditSimilarityNode(analysis_input=self.analysis_input)
78-
analysis_node.run_analysis()
77+
analysis_input = TextInclusionAnalysisInput(
78+
generation_df=pd.DataFrame(self.data)
79+
)
80+
analysis_node = EditSimilarityNode(analysis_input=analysis_input)
81+
results = analysis_node.compute_outputs()
82+
self.assertIn("edit_similarity", results)
83+
self.assertIn("edit_similarity_score", results)
84+
self.assertEqual(results["edit_similarity"].tolist(), [13, 3, 22, 16, 8, 16])

0 commit comments

Comments
 (0)