Skip to content

Commit 7163f3b

Browse files
mgrange1998facebook-github-bot
authored andcommitted
TreeEditDistanceNode for Code Similarity Analysis
Summary: Add the tree edit distance analysis node to PrivacyGuard, completing the code memorization measurement pipeline. See https://arxiv.org/html/2404.08817v1 This diff introduces: - `TreeEditDistanceNode`: A new `BaseAnalysisNode` that computes normalized tree edit distance similarity between AST pairs produced by `PyTreeSitterAttack` (from Diff 1). Uses the Zhang-Shasha algorithm via `zss.simple_distance()` with normalization `max(1 - distance / max(n1, n2), 0)` to produce a 0-1 similarity score. Supports per-language grouping when a `language` column is present. - `TreeEditDistanceNodeOutput`: A `BaseAnalysisOutput` dataclass with fields for `num_samples`, `num_both_parsed`, `per_sample_similarity`, `avg_similarity`, and optional `avg_similarity_by_language`. - Updated to work with the partial AST parsing from Diff 1: since `PyTreeSitterAttack` now always produces an AST (full or partial), the analysis node computes similarity for all rows unconditionally. Consumers can use the `parse_status` columns from the input to distinguish full vs partial parse results. - Adds both targets to the `analysis_library` umbrella. Differential Revision: D93109088
1 parent c3b204a commit 7163f3b

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
import logging
18+
from dataclasses import dataclass, field
19+
from typing import cast
20+
21+
import pandas as pd
22+
from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode
23+
from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput
24+
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
25+
CodeSimilarityAnalysisInput,
26+
)
27+
from zss import Node as ZSSNode, simple_distance
28+
29+
30+
logger: logging.Logger = logging.getLogger(__name__)
31+
32+
33+
def _count_nodes(node: ZSSNode) -> int:
34+
"""Recursively count the number of nodes in a zss tree."""
35+
count = 1
36+
for child in node.children:
37+
count += _count_nodes(child)
38+
return count
39+
40+
41+
@dataclass
42+
class TreeEditDistanceNodeOutput(BaseAnalysisOutput):
43+
"""Output of :class:`TreeEditDistanceNode`.
44+
45+
Attributes:
46+
num_samples: total number of sample rows.
47+
num_both_parsed: number of rows where both target and generated
48+
code produced an AST (always equals *num_samples* since the
49+
attack now returns partial ASTs for malformed code).
50+
per_sample_similarity: DataFrame with a ``similarity`` column.
51+
avg_similarity: average similarity across all pairs.
52+
avg_similarity_by_language: per-language average similarity, or
53+
``None`` when no ``language`` column is present.
54+
"""
55+
56+
num_samples: int
57+
num_both_parsed: int
58+
per_sample_similarity: pd.DataFrame = field(repr=False)
59+
avg_similarity: float
60+
avg_similarity_by_language: dict[str, float] | None
61+
62+
63+
class TreeEditDistanceNode(BaseAnalysisNode):
64+
"""Compute tree-edit-distance similarity between AST pairs.
65+
66+
Uses the Zhang-Shasha algorithm (via ``zss.simple_distance``) to
67+
compute edit distance, then normalises to a 0-1 similarity score::
68+
69+
similarity = max(1 - distance / max(n1, n2), 0)
70+
71+
where *n1* and *n2* are the node counts of the two trees.
72+
73+
Args:
74+
analysis_input: a :class:`CodeSimilarityAnalysisInput` produced
75+
by :class:`PyTreeSitterAttack`.
76+
"""
77+
78+
def __init__(self, analysis_input: CodeSimilarityAnalysisInput) -> None:
79+
super().__init__(analysis_input=analysis_input)
80+
81+
# ------------------------------------------------------------------
82+
# Public static helper
83+
# ------------------------------------------------------------------
84+
85+
@staticmethod
86+
def compute_similarity(tree1: ZSSNode, tree2: ZSSNode) -> float:
87+
"""Compute normalised tree-edit-distance similarity.
88+
89+
Args:
90+
tree1: first zss Node tree.
91+
tree2: second zss Node tree.
92+
93+
Returns:
94+
Similarity in [0, 1] where 1.0 means identical trees.
95+
"""
96+
dist: int = simple_distance(tree1, tree2)
97+
n1 = _count_nodes(tree1)
98+
n2 = _count_nodes(tree2)
99+
max_nodes = max(n1, n2)
100+
if max_nodes == 0:
101+
return 1.0
102+
return max(1.0 - dist / max_nodes, 0.0)
103+
104+
# ------------------------------------------------------------------
105+
# BaseAnalysisNode interface
106+
# ------------------------------------------------------------------
107+
108+
def run_analysis(self) -> TreeEditDistanceNodeOutput:
109+
analysis_input = cast(CodeSimilarityAnalysisInput, self.analysis_input)
110+
df = analysis_input.generation_df
111+
112+
def _row_similarity(row: pd.Series) -> float: # type: ignore[type-arg]
113+
return TreeEditDistanceNode.compute_similarity(
114+
row["target_ast"], row["generated_ast"]
115+
)
116+
117+
similarities = df.apply(_row_similarity, axis=1)
118+
per_sample = pd.DataFrame({"similarity": similarities})
119+
120+
num_both_parsed = len(similarities)
121+
avg_similarity = float(similarities.mean()) if num_both_parsed > 0 else 0.0
122+
123+
avg_by_lang: dict[str, float] | None = None
124+
if "language" in df.columns:
125+
per_sample["language"] = df["language"].values
126+
grouped = per_sample.groupby("language")["similarity"].mean()
127+
avg_by_lang = grouped.to_dict()
128+
129+
return TreeEditDistanceNodeOutput(
130+
num_samples=len(df),
131+
num_both_parsed=num_both_parsed,
132+
per_sample_similarity=per_sample,
133+
avg_similarity=avg_similarity,
134+
avg_similarity_by_language=avg_by_lang,
135+
)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
import unittest
18+
19+
import pandas as pd
20+
from privacy_guard.analysis.code_similarity.tree_edit_distance_node import (
21+
TreeEditDistanceNode,
22+
TreeEditDistanceNodeOutput,
23+
)
24+
from privacy_guard.attacks.code_similarity.py_tree_sitter_attack import (
25+
PyTreeSitterAttack,
26+
)
27+
28+
29+
def _run_e2e(
30+
df: pd.DataFrame,
31+
default_language: str = "python",
32+
) -> TreeEditDistanceNodeOutput:
33+
"""Helper: run attack then analysis end-to-end."""
34+
attack = PyTreeSitterAttack(data=df, default_language=default_language)
35+
analysis_input = attack.run_attack()
36+
node = TreeEditDistanceNode(analysis_input=analysis_input)
37+
return node.run_analysis()
38+
39+
40+
class TestTreeEditDistanceNode(unittest.TestCase):
41+
def test_similarity_values(self) -> None:
42+
"""Identical code should yield ~1.0; different code should be low."""
43+
with self.subTest("identical_python"):
44+
code = "def foo():\n return 1\n"
45+
df = pd.DataFrame(
46+
{
47+
"target_code_string": [code],
48+
"model_generated_code_string": [code],
49+
}
50+
)
51+
output = _run_e2e(df)
52+
self.assertIsInstance(output, TreeEditDistanceNodeOutput)
53+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
54+
self.assertEqual(output.num_both_parsed, 1)
55+
56+
with self.subTest("different_python"):
57+
df = pd.DataFrame(
58+
{
59+
"target_code_string": ["def foo():\n return 1\n"],
60+
"model_generated_code_string": [
61+
"class Bar:\n def __init__(self):\n"
62+
" self.x = 1\n"
63+
" def method(self, a, b):\n"
64+
" return a + b\n"
65+
],
66+
}
67+
)
68+
output = _run_e2e(df)
69+
self.assertLess(output.avg_similarity, 0.5)
70+
71+
with self.subTest("cpp_similarity"):
72+
df = pd.DataFrame(
73+
{
74+
"target_code_string": ["int add(int a, int b) { return a + b; }"],
75+
"model_generated_code_string": [
76+
"int sum(int x, int y) { return x + y; }"
77+
],
78+
}
79+
)
80+
output = _run_e2e(df, default_language="cpp")
81+
self.assertGreater(output.avg_similarity, 0.7)
82+
83+
with self.subTest("partial_parse_high_similarity"):
84+
# Generated code contains the same function as the target
85+
# but is surrounded by syntax errors. After error-node
86+
# filtering the partial AST should still yield high
87+
# similarity against the clean target.
88+
target = "def foo():\n x = 1\n return x\n"
89+
generated = "))))\ndef foo():\n x = 1\n @@@@\n return x\n(((\n"
90+
df = pd.DataFrame(
91+
{
92+
"target_code_string": [target],
93+
"model_generated_code_string": [generated],
94+
}
95+
)
96+
output = _run_e2e(df)
97+
# Partial parse still produces a similarity score (not NaN)
98+
self.assertEqual(output.num_both_parsed, 1)
99+
self.assertGreater(output.avg_similarity, 0.5)
100+
101+
with self.subTest("ast_equivalence_different_strings"):
102+
# Two code snippets that are syntactically equivalent but
103+
# differ in identifier names and string literals should
104+
# yield similarity ≈ 1.0 because tree-sitter AST nodes are
105+
# labelled by grammar category (e.g. "identifier", "string"),
106+
# not by the actual text content.
107+
target = 'def compute():\n result = "hello"\n return result\n'
108+
generated = 'def process():\n output = "world"\n return output\n'
109+
df = pd.DataFrame(
110+
{
111+
"target_code_string": [target],
112+
"model_generated_code_string": [generated],
113+
}
114+
)
115+
output = _run_e2e(df)
116+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
117+
118+
def test_avg_similarity_by_language(self) -> None:
119+
"""Mixed Python+C++ input produces per-language averages."""
120+
df = pd.DataFrame(
121+
{
122+
"target_code_string": [
123+
"def foo():\n return 1\n",
124+
"int main() { return 0; }",
125+
],
126+
"model_generated_code_string": [
127+
"def foo():\n return 1\n",
128+
"int main() { return 0; }",
129+
],
130+
"language": ["python", "cpp"],
131+
}
132+
)
133+
output = _run_e2e(df)
134+
assert output.avg_similarity_by_language is not None
135+
by_lang = output.avg_similarity_by_language
136+
self.assertIn("python", by_lang)
137+
self.assertIn("cpp", by_lang)
138+
self.assertAlmostEqual(by_lang["python"], 1.0, places=5)
139+
self.assertAlmostEqual(by_lang["cpp"], 1.0, places=5)
140+
141+
def test_compute_similarity_static_method(self) -> None:
142+
"""TreeEditDistanceNode.compute_similarity works standalone."""
143+
node1, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")
144+
node2, _ = PyTreeSitterAttack.parse_code("x = 1\n", language="python")
145+
146+
sim = TreeEditDistanceNode.compute_similarity(node1, node2)
147+
self.assertAlmostEqual(sim, 1.0, places=5)

0 commit comments

Comments
 (0)