|
12 | 12 | # pyre-strict |
13 | 13 |
|
14 | 14 | from dataclasses import dataclass |
15 | | -from typing import Optional |
| 15 | +from typing import cast, Optional |
16 | 16 |
|
17 | 17 | import pandas as pd |
| 18 | +import textdistance |
18 | 19 | from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode |
19 | 20 | from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput |
20 | 21 | from privacy_guard.analysis.extraction.text_inclusion_analysis_input import ( |
21 | 22 | TextInclusionAnalysisInput, |
22 | 23 | ) |
| 24 | +from privacy_guard.analysis.extraction.text_inclusion_analysis_node import ( |
| 25 | + _clean_text, |
| 26 | + _normalize_by_target_len, |
| 27 | +) |
23 | 28 |
|
24 | 29 | from tqdm import tqdm |
25 | 30 |
|
@@ -59,12 +64,47 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None: |
59 | 64 | self.prompt_key: str = analysis_input.prompt_key |
60 | 65 | self.generation_key: str = analysis_input.generation_key |
61 | 66 | self.target_key: str = analysis_input.target_key |
| 67 | + self.generation_df: pd.DataFrame = analysis_input.generation_df |
62 | 68 | super().__init__(analysis_input=analysis_input) |
63 | 69 |
|
def _compute_edit_similarity(
    self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
) -> int:
    """Return the Levenshtein similarity of two cleaned text fields of a row.

    Both fields are passed through ``_clean_text`` before they are compared.
    Currently not supported for multi target mode.

    Args:
        row (pd.Series): A row of a DataFrame containing the s1 and s2 columns.
        s1_column (str | None): Column holding the first string. When omitted
            (or otherwise falsy), ``self.target_key`` is used.
        s2_column (str | None): Column holding the second string. When omitted
            (or otherwise falsy), ``self.generation_key`` is used.

    Returns:
        int: The value of ``textdistance.levenshtein.similarity`` for the
        two cleaned strings.
    """
    first_column = s1_column or self.target_key
    cleaned_first = _clean_text(row[first_column])

    second_column = s2_column or self.generation_key
    cleaned_second = _clean_text(row[second_column])

    return textdistance.levenshtein.similarity(cleaned_first, cleaned_second)
| 86 | + |
def run_analysis(self) -> EditSimilarityNodeOutput:
    """Compute edit similarity for every row of the generation DataFrame.

    Two columns are written onto the input's ``generation_df`` in place:

    * ``edit_similarity`` - the per-row result of
      ``self._compute_edit_similarity``.
    * ``edit_similarity_score`` - ``edit_similarity`` passed through
      ``_normalize_by_target_len`` together with the target column.

    Returns:
        EditSimilarityNodeOutput: ``num_samples`` is the number of rows,
        ``edit_similarity`` and ``edit_similarity_score`` are the two new
        columns, and ``augmented_output_dataset`` is the same DataFrame
        object that was augmented.
    """
    analysis_input: TextInclusionAnalysisInput = cast(
        TextInclusionAnalysisInput, self.analysis_input
    )
    generation_df = analysis_input.generation_df

    # progress_apply comes from tqdm's pandas integration.
    # NOTE(review): presumably registered elsewhere in this module via
    # tqdm.pandas() - confirm.
    generation_df["edit_similarity"] = generation_df.progress_apply(
        self._compute_edit_similarity, axis=1
    )

    # Normalise against the configured target column (self.target_key), the
    # same column _compute_edit_similarity reads, rather than a hard-coded
    # "target" name that breaks for any non-default target key.
    # NOTE(review): similarity is computed on cleaned text while the raw
    # target column is handed to the normaliser - confirm this is intended.
    generation_df["edit_similarity_score"] = _normalize_by_target_len(
        generation_df["edit_similarity"], generation_df[self.target_key]
    )

    return EditSimilarityNodeOutput(
        num_samples=len(generation_df),
        edit_similarity=generation_df["edit_similarity"],
        edit_similarity_score=generation_df["edit_similarity_score"],
        augmented_output_dataset=generation_df,
    )
0 commit comments