Skip to content

Commit 1dcc960

Browse files
mgrange1998meta-codesync[bot]
authored andcommitted
Text Inclusion Analysis, include matched word level longest common subsequence (#92)
Summary: Pull Request resolved: #92. This diff adds support to the text inclusion analysis node for returning the matched text from the word-level longest common subsequence. This allows us to inspect the matched text and to compare the length of the matched text against the target and generated text. Reviewed By: lucamelis. Differential Revision: D89493658. fbshipit-source-id: fbd310375054ddb116c8297e1557579780e9fc53
1 parent 08c319a commit 1dcc960

File tree

2 files changed

+320
-11
lines changed

2 files changed

+320
-11
lines changed

privacy_guard/analysis/extraction/text_inclusion_analysis_node.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,65 @@ class TextInclusionAnalysisNodeOutput(BaseAnalysisOutput):
5353
None # Include for future reference
5454
)
5555

56+
def format_single_word_level_lcs_result(
    self,
    num_matched_words: int,
    matched_string: str,
    augmented_row: Dict[str, Any],
    analysis_input: TextInclusionAnalysisInput,
) -> Dict[str, Any]:
    """Build one display row describing a word level LCS match.

    Args:
        num_matched_words: Number of words in the match.
        matched_string: Text of the matched words.
        augmented_row: One row of the augmented output dataset as a dict;
            it is indexed with the prompt, generation and target keys of
            ``analysis_input``.
        analysis_input: Supplies the ``prompt_key``, ``generation_key`` and
            ``target_key`` names used to read ``augmented_row``.

    Returns:
        A dict holding the matched word count, the character length of the
        matched text, the matched text itself, the percentage of the cleaned
        target covered by the match ("N/A" when the cleaned target is
        empty), and the prompt, target and generation texts.
    """
    prompt_text = augmented_row[analysis_input.prompt_key]
    generated_text = augmented_row[analysis_input.generation_key]
    target_text = augmented_row[analysis_input.target_key]

    # The method here should set remove_consecutive_whitespace based on analysis input
    cleaned_target = _clean_text_remove_consecutive_whitespace(text=target_text)
    cleaned_target_length = len(cleaned_target)
    match_char_length = len(matched_string)

    # Guard against division by zero for targets that are empty once cleaned.
    if cleaned_target_length == 0:
        percent_extracted = "N/A"
    else:
        percent_extracted = 100 * match_char_length / cleaned_target_length

    return {
        "Count of matched words": num_matched_words,
        "Length of matched words": match_char_length,
        "Matched consecutive sequence": matched_string,
        "% target extracted": percent_extracted,
        analysis_input.prompt_key: prompt_text,
        analysis_input.target_key: target_text,
        analysis_input.generation_key: generated_text,
    }
84+
85+
def word_level_lcs_result_formatted(self) -> pd.DataFrame:
    """Return an interpretable dataframe of the word level LCS results.

    Each word level LCS entry is paired positionally with one row of the
    augmented output dataset and formatted through
    ``format_single_word_level_lcs_result``.

    Returns:
        A dataframe with one row per paired (LCS entry, dataset row).

    Raises:
        ValueError: If there are no word level LCS results, or if no
            analysis input is available to identify the column keys.
    """
    if self.word_level_longest_common_subsequence is None:
        raise ValueError("No lcs results to display.")
    if self.analysis_input is None:
        raise ValueError("No analysis input, can't id keys for formatting")

    word_level_longest_common_subsequence_list = list(
        self.word_level_longest_common_subsequence
    )

    # orient="records" yields one dict per row in positional order.
    # Transposing and calling to_dict() keys rows by index label instead,
    # which silently drops rows when the index contains duplicate labels.
    augmented_rows = self.augmented_output_dataset.to_dict(orient="records")

    displays: List[Dict[str, Any]] = []

    for word_level_tuple, augmented_row in zip(
        word_level_longest_common_subsequence_list,
        augmented_rows,
    ):
        # Each entry is (matched word count, matched text).
        displays.append(
            self.format_single_word_level_lcs_result(
                num_matched_words=word_level_tuple[0],
                matched_string=word_level_tuple[1],
                augmented_row=augmented_row,
                analysis_input=self.analysis_input,  # pyre-ignore
            )
        )

    return pd.DataFrame(displays)
114+
56115
def format_single_lcs_result(
57116
self,
58117
lcs_dict: Dict[str, Any],
@@ -154,7 +213,7 @@ def _clean_text_remove_consecutive_whitespace(text: str) -> str:
154213

155214
def _word_level_longest_common_subsequence_helper(
156215
s1: str, s2: str, autojunk: bool = True
157-
) -> int:
216+
) -> Tuple[int, str]:
158217
"""
159218
Implementation of the longest common subsequence at word level.
160219
@@ -171,10 +230,13 @@ def _word_level_longest_common_subsequence_helper(
171230

172231
# Initialize the length of matched words count
173232
matched_words_count = 0
233+
matched_words = []
174234
for block in matching_blocks:
175235
if block.size > 0:
176236
matched_words_count += block.size
177-
return matched_words_count
237+
matched_words.extend(s1_list[block.a : block.a + block.size])
238+
reconstructed_match = " ".join(matched_words)
239+
return (matched_words_count, reconstructed_match)
178240

179241

180242
def _char_level_longest_common_subsequence_helper(
@@ -324,7 +386,7 @@ def __init__(self, analysis_input: TextInclusionAnalysisInput) -> None:
324386

325387
def _compute_word_level_longest_common_subsequence_helper(
326388
self, row: pd.Series, s1_column: str | None = None, s2_column: str | None = None
327-
) -> int:
389+
) -> Tuple[int, str]:
328390
"""Compute char level longest common subsequence between target and generation text.
329391
Text are cleaned first.
330392

privacy_guard/analysis/tests/test_text_inclusion.py

Lines changed: 255 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def test_text_inclusion_with_char_level_longest_common_subsequence(self) -> None
262262
results["char_level_longest_common_subsequence"],
263263
results["word_level_longest_common_subsequence"],
264264
):
265-
self.assertGreaterEqual(char_lcs, word_lcs)
265+
self.assertGreaterEqual(char_lcs, word_lcs[0])
266266

267267
def test_text_inclusion_augmented_output(self) -> None:
268268
analysis_input = TextInclusionAnalysisInput(
@@ -463,16 +463,26 @@ def test_word_level_longest_common_susequence_match(self) -> None:
463463
+ ("t" * 130)
464464
)
465465

466-
self.assertEqual(_word_level_longest_common_subsequence_helper(s1=s1, s2=s2), 2)
467-
self.assertEqual(_word_level_longest_common_subsequence_helper(s1=s1, s2=s1), 5)
466+
self.assertEqual(
467+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2)[0], 2
468+
)
469+
self.assertEqual(
470+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s1)[0], 5
471+
)
468472

469473
s1 = "a b a"
470474
s2 = "c a b a d"
471475
s3 = "a d b a"
472476

473-
self.assertEqual(_word_level_longest_common_subsequence_helper(s1=s1, s2=s2), 3)
474-
self.assertEqual(_word_level_longest_common_subsequence_helper(s1=s2, s2=s3), 3)
475-
self.assertEqual(_word_level_longest_common_subsequence_helper(s1=s1, s2=s3), 3)
477+
self.assertEqual(
478+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2), (3, "a b a")
479+
)
480+
self.assertEqual(
481+
_word_level_longest_common_subsequence_helper(s1=s2, s2=s3), (3, "a b a")
482+
)
483+
self.assertEqual(
484+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s3), (3, "a b a")
485+
)
476486

477487
def test_char_level_longest_common_susequence_match(self) -> None:
478488
s1 = ("w" * 5) + ("t" * 16) + ("b" * 5) + ("t" * 15)
@@ -517,11 +527,15 @@ def test_longest_common_susequence_match_autojunk(self) -> None:
517527
s2 = ("x " * 50) + ("t " * 160) + ("c " * 150) + ("t " * 200) + "end2"
518528

519529
self.assertEqual(
520-
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2, autojunk=False),
530+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2, autojunk=False)[
531+
0
532+
],
521533
260,
522534
)
523535
self.assertEqual(
524-
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2, autojunk=True),
536+
_word_level_longest_common_subsequence_helper(s1=s1, s2=s2, autojunk=True)[
537+
0
538+
],
525539
0,
526540
)
527541

@@ -608,3 +622,236 @@ def test_analysis_with_remove_consecutive_whitespace(self) -> None:
608622
results_basic["edit_similarity_score"].iloc[0],
609623
results_cleaned["edit_similarity_score"].iloc[0],
610624
)
625+
626+
def test_format_single_word_level_lcs_result(self) -> None:
    """Check the keys and values of a single formatted word level LCS row."""
    outputs = self.analysis_node.run_analysis()
    self.assertIsInstance(outputs, TextInclusionAnalysisNodeOutput)

    # Use the last row of the augmented dataset as the row to format.
    last_row = outputs.augmented_output_dataset.iloc[-1].to_dict()
    matched_text = "dolorem ipsum quia"

    # Invoke the formatter directly with a hand-picked match.
    formatted = outputs.format_single_word_level_lcs_result(
        num_matched_words=3,
        matched_string=matched_text,
        augmented_row=last_row,
        analysis_input=self.analysis_input,
    )

    # Every expected key must be present in the result dictionary.
    for expected_key in (
        "Count of matched words",
        "Length of matched words",
        "Matched consecutive sequence",
        "% target extracted",
        "prompt",
        "output_text",
        "target",
    ):
        self.assertIn(expected_key, formatted.keys())

    # The match statistics must mirror the arguments that were passed in.
    self.assertEqual(formatted["Count of matched words"], 3)
    self.assertEqual(formatted["Length of matched words"], len(matched_text))
    self.assertEqual(formatted["Matched consecutive sequence"], matched_text)
655+
656+
def test_format_single_word_level_lcs_result_empty_target(self) -> None:
    """An empty target must report "N/A" for the extracted percentage."""
    outputs = self.analysis_node.run_analysis()

    # Hand-built row whose target is the empty string.
    row_with_empty_target = {
        "prompt": "test prompt",
        "target": "",
        "output_text": "test output",
    }

    formatted = outputs.format_single_word_level_lcs_result(
        num_matched_words=0,
        matched_string="",
        augmented_row=row_with_empty_target,
        analysis_input=self.analysis_input,
    )

    # With nothing in the target there is no percentage to compute.
    self.assertEqual(formatted["% target extracted"], "N/A")
676+
677+
def test_word_level_lcs_result_formatted(self) -> None:
    """Check the type, columns and row count of the formatted dataframe."""
    outputs = self.analysis_node.run_analysis()
    self.assertIsInstance(outputs, TextInclusionAnalysisNodeOutput)

    # The word level LCS must have been computed by the analysis.
    self.assertIsNotNone(outputs.word_level_longest_common_subsequence)

    formatted_frame = outputs.word_level_lcs_result_formatted()

    # The formatter must hand back a dataframe.
    self.assertIsInstance(formatted_frame, pd.DataFrame)

    # All display and text columns must be present.
    for expected_column in (
        "Count of matched words",
        "Length of matched words",
        "Matched consecutive sequence",
        "% target extracted",
        "prompt",
        "target",
        "output_text",
    ):
        self.assertIn(expected_column, formatted_frame.columns)

    # One formatted row is expected per input sample.
    self.assertEqual(len(formatted_frame), len(self.data["prompt"]))
702+
703+
def test_word_level_lcs_result_formatted_no_lcs_results(self) -> None:
    """Formatting must raise a ValueError when no word level LCS exists."""
    # Output object built by hand with the word level LCS left unset.
    output_without_lcs = TextInclusionAnalysisNodeOutput(
        num_samples=0,
        exact_match=pd.Series(),
        inclusion_score=pd.Series(),
        longest_common_substring=None,
        longest_common_substring_false_pos=None,
        decision_targets_lcs=None,
        decision_targets_lcs_len=None,
        edit_similarity=None,
        edit_similarity_score=None,
        filtered_true_positive_list=None,
        augmented_output_dataset=pd.DataFrame(),
        word_level_longest_common_subsequence=None,
        char_level_longest_common_subsequence=None,
        analysis_input=None,
    )

    with self.assertRaisesRegex(ValueError, "No lcs results to display"):
        output_without_lcs.word_level_lcs_result_formatted()
723+
724+
def test_word_level_lcs_result_formatted_no_analysis_input(self) -> None:
    """Formatting must raise a ValueError when the analysis input is unset."""
    # LCS results are present here, but the analysis input is not.
    output_without_input = TextInclusionAnalysisNodeOutput(
        num_samples=0,
        exact_match=pd.Series(),
        inclusion_score=pd.Series(),
        longest_common_substring=None,
        longest_common_substring_false_pos=None,
        decision_targets_lcs=None,
        decision_targets_lcs_len=None,
        edit_similarity=None,
        edit_similarity_score=None,
        filtered_true_positive_list=None,
        augmented_output_dataset=pd.DataFrame(),
        word_level_longest_common_subsequence=pd.Series([(1, "test")]),
        char_level_longest_common_subsequence=None,
        analysis_input=None,
    )

    with self.assertRaisesRegex(ValueError, "No analysis input"):
        output_without_input.word_level_lcs_result_formatted()
744+
745+
def test_word_level_lcs_result_formatted_german(self) -> None:
    """Word level LCS formatting on German text with non-consecutive matches."""
    # Target and generation are identical apart from the filler word
    # (HIER vs DORT), so the match is split into several separate blocks.
    target_text = "Der kleine Hund läuft HIER durch den großen Wald HIER und findet HIER einen roten Ball HIER unter dem alten Baum HIER neben dem kleinen Bach"
    generated_text = "Der kleine Hund läuft DORT durch den großen Wald DORT und findet DORT einen roten Ball DORT unter dem alten Baum DORT neben dem kleinen Bach"
    generation_frame = pd.DataFrame(
        {
            "prompt": [
                "Erzähle mir eine Geschichte über einen Hund im Wald",
            ],
            "target": [target_text],
            "output_text": [generated_text],
        }
    )

    node = TextInclusionAnalysisNode(
        analysis_input=TextInclusionAnalysisInput(generation_df=generation_frame)
    )
    outputs = node.run_analysis()

    # The word level LCS must have been computed by the analysis.
    self.assertIsNotNone(outputs.word_level_longest_common_subsequence)

    formatted_frame = outputs.word_level_lcs_result_formatted()

    # Exactly one formatted row is expected for the single sample.
    self.assertIsInstance(formatted_frame, pd.DataFrame)
    self.assertEqual(len(formatted_frame), 1)

    formatted_row = formatted_frame.iloc[0]

    # The target has 26 words; the 5 "HIER" fillers do not match "DORT".
    # The remaining 21 words match in six non-consecutive blocks:
    #   "der kleine hund läuft"   (4)
    #   "durch den großen wald"   (4)
    #   "und findet"              (2)
    #   "einen roten ball"        (3)
    #   "unter dem alten baum"    (4)
    #   "neben dem kleinen bach"  (4)
    self.assertEqual(formatted_row["Count of matched words"], 21)

    # All words except HIER, in cleaned form (lowercase, no punctuation).
    expected_match = (
        "der kleine hund läuft durch den großen wald und findet "
        "einen roten ball unter dem alten baum neben dem kleinen bach"
    )
    self.assertEqual(formatted_row["Matched consecutive sequence"], expected_match)
801+
802+
def test_word_level_lcs_result_formatted_spanish(self) -> None:
    """Word level LCS formatting on Spanish text with non-consecutive matches."""
    # Target and generation are identical apart from the filler word
    # (AQUI vs ALLI), so the match is split into several separate blocks.
    target_text = "El pequeño perro corre AQUI por el gran bosque AQUI y encuentra AQUI una pelota roja AQUI bajo el viejo árbol AQUI junto al pequeño río"
    generated_text = "El pequeño perro corre ALLI por el gran bosque ALLI y encuentra ALLI una pelota roja ALLI bajo el viejo árbol ALLI junto al pequeño río"
    generation_frame = pd.DataFrame(
        {
            "prompt": [
                "Cuéntame una historia sobre un perro en el bosque",
            ],
            "target": [target_text],
            "output_text": [generated_text],
        }
    )

    node = TextInclusionAnalysisNode(
        analysis_input=TextInclusionAnalysisInput(generation_df=generation_frame)
    )
    outputs = node.run_analysis()

    # The word level LCS must have been computed by the analysis.
    self.assertIsNotNone(outputs.word_level_longest_common_subsequence)

    formatted_frame = outputs.word_level_lcs_result_formatted()

    # Exactly one formatted row is expected for the single sample.
    self.assertIsInstance(formatted_frame, pd.DataFrame)
    self.assertEqual(len(formatted_frame), 1)

    formatted_row = formatted_frame.iloc[0]

    # The target has 26 words; the 5 "AQUI" fillers do not match "ALLI".
    # The remaining 21 words match in six non-consecutive blocks:
    #   "el pequeño perro corre"  (4)
    #   "por el gran bosque"      (4)
    #   "y encuentra"             (2)
    #   "una pelota roja"         (3)
    #   "bajo el viejo árbol"     (4)
    #   "junto al pequeño río"    (4)
    self.assertEqual(formatted_row["Count of matched words"], 21)

    # All words except AQUI, in cleaned form (lowercase, no punctuation).
    expected_match = (
        "el pequeño perro corre por el gran bosque y encuentra "
        "una pelota roja bajo el viejo árbol junto al pequeño río"
    )
    self.assertEqual(formatted_row["Matched consecutive sequence"], expected_match)

0 commit comments

Comments
 (0)