Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions privacy_guard/analysis/extraction/text_inclusion_analysis_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ def __init__(
target_key: str = "target",
generation_key: str = "output_text",
disable_exact_match: bool = False,
disable_lcs: bool = False,
disable_longest_common_substring: bool = False,
disable_similarity: bool = False,
lcs_bound_config: LCSBoundConfig | None = None,
disable_word_level_longest_common_subsequence: bool = False,
disable_char_level_longest_common_subsequence: bool = True,
) -> None:
columns = generation_df.columns.tolist()
assert (
Expand All @@ -60,10 +62,17 @@ def __init__(
self.generation_key = generation_key

self.disable_exact_match = disable_exact_match
self.disable_lcs = disable_lcs
self.disable_longest_common_substring = disable_longest_common_substring
self.disable_similarity = disable_similarity
self.lcs_bound_config = lcs_bound_config

self.disable_word_level_longest_common_subsequence = (
disable_word_level_longest_common_subsequence
)
self.disable_char_level_longest_common_subsequence = (
disable_char_level_longest_common_subsequence
)

super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame())

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class TextInclusionAnalysisNodeOutput(BaseAnalysisOutput):
edit_similarity_score: Optional[pd.Series]
filtered_true_positive_list: list[str] | None
augmented_output_dataset: pd.DataFrame
char_level_longest_common_subsequence: Optional[pd.Series]
word_level_longest_common_subsequence: Optional[pd.Series]


def _clean_text(text: str) -> str:
Expand Down Expand Up @@ -219,6 +221,32 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:

super().__init__(analysis_input=analysis_input)

def _compute_word_level_longest_common_subsequence_helper(
    self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
) -> int:
    """Compute the word-level longest common subsequence between target and
    generation text for a single DataFrame row. Both texts are cleaned first
    via ``_clean_text``.

    Args:
        row: A row of the generation DataFrame.
        s1_column: Column holding the first string; falls back to
            ``self.target_key`` when None.
        s2_column: Column holding the second string; falls back to
            ``self.generation_key`` when None.

    Returns:
        int: Length (in words) of the longest common subsequence of the
        two cleaned strings.
    """
    # `or` fallback: an explicit column name overrides the configured keys.
    s1 = _clean_text(row[s1_column or self.target_key])
    s2 = _clean_text(row[s2_column or self.generation_key])
    return _word_level_longest_common_subsequence_helper(s1, s2)

def _compute_char_level_longest_common_subsequence_helper(
    self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
) -> int:
    """Compute the character-level longest common subsequence between target
    and generation text for a single DataFrame row. Both texts are cleaned
    first via ``_clean_text``.

    Args:
        row: A row of the generation DataFrame.
        s1_column: Column holding the first string; falls back to
            ``self.target_key`` when None.
        s2_column: Column holding the second string; falls back to
            ``self.generation_key`` when None.

    Returns:
        int: Length (in characters) of the longest common subsequence of
        the two cleaned strings.
    """
    # `or` fallback: an explicit column name overrides the configured keys.
    s1 = _clean_text(row[s1_column or self.target_key])
    s2 = _clean_text(row[s2_column or self.generation_key])
    return _char_level_longest_common_subsequence_helper(s1, s2)

def _compute_edit_similarity(
self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
) -> int:
Expand Down Expand Up @@ -389,9 +417,11 @@ def run_analysis(self) -> TextInclusionAnalysisNodeOutput:
edit_similarity_score=None,
filtered_true_positive_list=None,
augmented_output_dataset=generation_df,
word_level_longest_common_subsequence=None,
char_level_longest_common_subsequence=None,
)

if not analysis_input.disable_lcs:
if not analysis_input.disable_longest_common_substring:
# Longest common substring

lcs_result = generation_df.progress_apply(
Expand Down Expand Up @@ -425,4 +455,26 @@ def run_analysis(self) -> TextInclusionAnalysisNodeOutput:
outputs.edit_similarity = generation_df["edit_similarity"]
outputs.edit_similarity_score = generation_df["edit_similarity_score"]

if not analysis_input.disable_word_level_longest_common_subsequence:
generation_df["word_level_longest_common_subsequence"] = (
generation_df.progress_apply(
self._compute_word_level_longest_common_subsequence_helper, axis=1
)
)

outputs.word_level_longest_common_subsequence = generation_df[
"word_level_longest_common_subsequence"
]

if not analysis_input.disable_char_level_longest_common_subsequence:
generation_df["char_level_longest_common_subsequence"] = (
generation_df.progress_apply(
self._compute_char_level_longest_common_subsequence_helper, axis=1
)
)

outputs.char_level_longest_common_subsequence = generation_df[
"char_level_longest_common_subsequence"
]

return outputs
41 changes: 40 additions & 1 deletion privacy_guard/analysis/tests/test_text_inclusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_output_types(self) -> None:

def test_text_inclusion_no_lcs(self) -> None:
analysis_input = TextInclusionAnalysisInput(
generation_df=pd.DataFrame(self.data), disable_lcs=True
generation_df=pd.DataFrame(self.data), disable_longest_common_substring=True
)
analysis_node = TextInclusionAnalysisNode(analysis_input=analysis_input)

Expand All @@ -199,6 +199,9 @@ def test_text_inclusion_no_lcs(self) -> None:
self.assertIsNotNone(results["edit_similarity"], None)
self.assertIsNotNone(results["edit_similarity_score"], None)

self.assertIsNone(results["char_level_longest_common_subsequence"])
self.assertIsNotNone(results["word_level_longest_common_subsequence"])

def test_text_inclusion_no_similarity(self) -> None:
analysis_input = TextInclusionAnalysisInput(
generation_df=pd.DataFrame(self.data), disable_similarity=True
Expand All @@ -219,6 +222,40 @@ def test_text_inclusion_no_similarity(self) -> None:
self.assertEqual(results["edit_similarity"], None)
self.assertEqual(results["edit_similarity_score"], None)

self.assertIsNone(results["char_level_longest_common_subsequence"])
self.assertIsNotNone(results["word_level_longest_common_subsequence"])

def test_text_inclusion_with_char_level_longest_common_subsequence(self) -> None:
    """When both subsequence metrics are enabled, both are populated and the
    char-level LCS is never shorter than the word-level LCS per row."""
    analysis_input = TextInclusionAnalysisInput(
        generation_df=pd.DataFrame(self.data),
        disable_char_level_longest_common_subsequence=False,
        disable_word_level_longest_common_subsequence=False,
    )
    analysis_node = TextInclusionAnalysisNode(analysis_input=analysis_input)

    results = analysis_node.compute_outputs()

    # Presence of the core output keys.
    for key in ("exact_match", "inclusion_score", "longest_common_substring",
                "decision_targets_lcs"):
        self.assertIn(key, results)

    # Every enabled metric must actually be computed.
    populated_keys = (
        "longest_common_substring",
        "decision_targets_lcs",
        "edit_similarity",
        "edit_similarity_score",
        "char_level_longest_common_subsequence",
        "word_level_longest_common_subsequence",
    )
    for key in populated_keys:
        self.assertIsNotNone(results[key])

    # Each shared word contributes at least one shared character, so the
    # char-level count dominates the word-level count row by row.
    char_series = results["char_level_longest_common_subsequence"]
    word_series = results["word_level_longest_common_subsequence"]
    for char_lcs, word_lcs in zip(char_series, word_series):
        self.assertGreaterEqual(char_lcs, word_lcs)

def test_text_inclusion_augmented_output(self) -> None:
analysis_input = TextInclusionAnalysisInput(
generation_df=pd.DataFrame(self.data)
Expand Down Expand Up @@ -255,6 +292,8 @@ def test_multi_target(self) -> None:
target_key="targets",
disable_exact_match=True,
disable_similarity=True,
disable_word_level_longest_common_subsequence=True,
disable_char_level_longest_common_subsequence=True,
)
multi_analysis_node = TextInclusionAnalysisNode(
analysis_input=multi_analysis_input
Expand Down
Loading