Skip to content

Commit a408ea7

Browse files
anilkram authored and facebook-github-bot committed
Adding implementation for edit_similarity_node. (#94)
Summary: Pull Request resolved: #94 Differential Revision: D87895834
1 parent 04145aa commit a408ea7

File tree

2 files changed

+65
-6
lines changed

2 files changed

+65
-6
lines changed

privacy_guard/analysis/extraction/edit_similarity_node.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,19 @@
1212
# pyre-strict
1313

1414
from dataclasses import dataclass
15-
from typing import Optional
15+
from typing import cast, Optional
1616

1717
import pandas as pd
18+
import textdistance
1819
from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode
1920
from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput
2021
from privacy_guard.analysis.extraction.text_inclusion_analysis_input import (
2122
TextInclusionAnalysisInput,
2223
)
24+
from privacy_guard.analysis.extraction.text_inclusion_analysis_node import (
25+
_clean_text,
26+
_normalize_by_target_len,
27+
)
2328

2429
from tqdm import tqdm
2530

@@ -59,12 +64,47 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:
5964
self.prompt_key: str = analysis_input.prompt_key
6065
self.generation_key: str = analysis_input.generation_key
6166
self.target_key: str = analysis_input.target_key
67+
self.generation_df: pd.DataFrame = analysis_input.generation_df
6268
super().__init__(analysis_input=analysis_input)
6369

70+
def _compute_edit_similarity(
    self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
) -> int:
    """Return the Levenshtein similarity of two text cells taken from ``row``.

    Each cell is passed through ``_clean_text`` before the comparison.
    Currently not supported for multi target mode.

    Args:
        row: A DataFrame row containing both text columns.
        s1_column: Column holding the first string. When falsy (``None``
            or an empty string), ``self.target_key`` is used instead.
        s2_column: Column holding the second string. When falsy,
            ``self.generation_key`` is used instead.

    Returns:
        The result of ``textdistance.levenshtein.similarity`` applied to
        the two cleaned strings.
    """
    first_column = s1_column or self.target_key
    cleaned_first = _clean_text(row[first_column])

    second_column = s2_column or self.generation_key
    cleaned_second = _clean_text(row[second_column])

    return textdistance.levenshtein.similarity(cleaned_first, cleaned_second)
86+
6487
def run_analysis(self) -> EditSimilarityNodeOutput:
    """Compute per-row edit similarity and package it as the node output.

    Two columns are written onto the input's ``generation_df`` itself:
    ``"edit_similarity"`` (row-wise ``_compute_edit_similarity``) and
    ``"edit_similarity_score"`` (the result of ``_normalize_by_target_len``
    on that column and the target column). The same DataFrame object is
    returned as ``augmented_output_dataset``, so the caller's frame is
    augmented in place.

    Returns:
        An ``EditSimilarityNodeOutput`` whose ``num_samples`` is the row
        count and whose similarity fields are the two new columns.
    """
    node_input: TextInclusionAnalysisInput = cast(
        TextInclusionAnalysisInput, self.analysis_input
    )
    frame = node_input.generation_df
    sample_count = len(frame)

    # axis=1 hands each row (as a Series) to the per-row scorer.
    frame["edit_similarity"] = frame.progress_apply(
        self._compute_edit_similarity, axis=1
    )
    frame["edit_similarity_score"] = _normalize_by_target_len(
        frame["edit_similarity"], frame[self.target_key]
    )

    return EditSimilarityNodeOutput(
        num_samples=sample_count,
        edit_similarity=frame["edit_similarity"],
        edit_similarity_score=frame["edit_similarity_score"],
        augmented_output_dataset=frame,
    )

privacy_guard/analysis/tests/test_edit_similarity_node.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
# pyre-strict
13+
114
import unittest
215

316
import pandas as pd
@@ -74,5 +87,11 @@ def test_edit_similarity_node_output_creation(self) -> None:
7487
self.assertEqual(output.augmented_output_dataset.equals(generation_df), True)
7588

7689
def test_text_inclusion_edit_similarity(self) -> None:
    """Check that the node's outputs expose both similarity entries.

    Builds the node from ``self.data`` and verifies that the result of
    ``compute_outputs()`` contains ``"edit_similarity"`` and
    ``"edit_similarity_score"``, and that the raw similarities match the
    expected per-row values.
    """
    frame = pd.DataFrame(self.data)
    node = EditSimilarityNode(
        analysis_input=TextInclusionAnalysisInput(generation_df=frame)
    )

    results = node.compute_outputs()

    self.assertIn("edit_similarity", results)
    self.assertIn("edit_similarity_score", results)

    expected_similarity = [13, 3, 22, 16, 8, 16]
    self.assertEqual(results["edit_similarity"].tolist(), expected_similarity)

0 commit comments

Comments
 (0)