Skip to content

Commit d5c4d6e

Browse files
mgrange1998meta-codesync[bot]
authored andcommitted
PyTreeSitterAttack for Code Edit Distance (#106)
Summary: Pull Request resolved: #106 Add code similarity analysis infrastructure to PrivacyGuard for measuring code memorization via AST structural comparison. See https://arxiv.org/html/2404.08817v1 This diff introduces: - `PyTreeSitterAttack`: A new attack node that parses target and model-generated code into Abstract Syntax Trees (ASTs) using tree-sitter, then converts them into zss (Zhang-Shasha) Node trees for downstream tree edit distance analysis. Supports Python and C++ via a language registry with explicit imports. - **Partial AST support**: Instead of rejecting malformed code entirely, `parse_code` now leverages tree-sitter's error recovery to produce partial ASTs by filtering out ERROR and MISSING nodes. This allows downstream similarity analysis to still detect code memorization even when model-generated code contains syntax errors. Each record is tagged with a `parse_status` of `"success"` or `"partial"` so downstream consumers can distinguish clean parses from filtered ones. - `CodeSimilarityAnalysisInput`: A new `BaseAnalysisInput` subclass that stores the generation DataFrame with AST columns (`target_ast`, `generated_ast`, `target_parse_status`, `generated_parse_status`), following the existing `TextInclusionAnalysisInput` pattern. - Pins tree-sitter to v0.25.0 in PACKAGE files for the newer Language/Parser API. Reviewed By: anilkram Differential Revision: D93109033 fbshipit-source-id: 52d716b8ddd53fc07097bb4e34a4823b38c2e80c
1 parent b33048c commit d5c4d6e

File tree

4 files changed

+432
-0
lines changed

4 files changed

+432
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
import pandas as pd
18+
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput
19+
20+
21+
class CodeSimilarityAnalysisInput(BaseAnalysisInput):
22+
"""
23+
Analysis input for code similarity analysis.
24+
25+
Stores a generation DataFrame containing target and model-generated code strings
26+
along with their parsed ASTs.
27+
28+
Required columns:
29+
- target_code_string: the original target code
30+
- model_generated_code_string: the model's generated code
31+
- target_ast: parsed AST (zss Node) for the target code
32+
- generated_ast: parsed AST (zss Node) for the generated code
33+
- target_parse_status: "success" or "partial" (error nodes filtered)
34+
- generated_parse_status: "success" or "partial" (error nodes filtered)
35+
36+
Args:
37+
generation_df: DataFrame containing code strings and parsed ASTs
38+
"""
39+
40+
REQUIRED_COLUMNS: list[str] = [
41+
"target_code_string",
42+
"model_generated_code_string",
43+
"target_ast",
44+
"generated_ast",
45+
"target_parse_status",
46+
"generated_parse_status",
47+
]
48+
49+
def __init__(self, generation_df: pd.DataFrame) -> None:
50+
missing = set(self.REQUIRED_COLUMNS) - set(generation_df.columns)
51+
if missing:
52+
raise ValueError(f"Missing required columns in generation_df: {missing}")
53+
54+
super().__init__(df_train_user=generation_df, df_test_user=pd.DataFrame())
55+
56+
@property
57+
def generation_df(self) -> pd.DataFrame:
58+
"""Property accessor for the generation DataFrame."""
59+
return self._df_train_user
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# pyre-strict
16+
17+
import logging
18+
from types import ModuleType
19+
from typing import Any
20+
21+
import pandas as pd
22+
import tree_sitter_cpp # @manual=fbsource//third-party/pypi/tree-sitter-cpp:tree-sitter-cpp
23+
import tree_sitter_python # @manual=fbsource//third-party/pypi/tree-sitter-python:tree-sitter-python
24+
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
25+
CodeSimilarityAnalysisInput,
26+
)
27+
from privacy_guard.attacks.base_attack import BaseAttack
28+
from tree_sitter import ( # @manual=fbsource//third-party/pypi/tree-sitter:tree-sitter
29+
Language,
30+
Parser,
31+
)
32+
from zss import Node as ZSSNode
33+
34+
35+
logger: logging.Logger = logging.getLogger(__name__)
36+
37+
# Maps user-facing language strings to tree-sitter language modules.
38+
_LANGUAGE_REGISTRY: dict[str, ModuleType] = {
39+
"python": tree_sitter_python,
40+
"py": tree_sitter_python,
41+
"c++": tree_sitter_cpp,
42+
"cpp": tree_sitter_cpp,
43+
}
44+
45+
46+
def _get_parser(language: str) -> Parser:
47+
"""Create a tree-sitter Parser for the given language.
48+
49+
Args:
50+
language: a key in _LANGUAGE_REGISTRY (e.g. "python", "cpp")
51+
52+
Returns:
53+
A configured tree-sitter Parser instance.
54+
55+
Raises:
56+
ValueError: if the language is not supported.
57+
"""
58+
lang_key = language.lower()
59+
ts_module = _LANGUAGE_REGISTRY.get(lang_key)
60+
if ts_module is None:
61+
raise ValueError(
62+
f"Unsupported language '{language}'. "
63+
f"Supported: {sorted(_LANGUAGE_REGISTRY.keys())}"
64+
)
65+
66+
ts_language = Language(ts_module.language()) # type: ignore[attr-defined]
67+
parser = Parser(ts_language)
68+
return parser
69+
70+
71+
class PyTreeSitterAttack(BaseAttack):
72+
"""Parse target and generated code into ASTs using tree-sitter.
73+
74+
Expects a DataFrame with ``target_code_string`` and
75+
``model_generated_code_string`` columns. Produces a
76+
:class:`CodeSimilarityAnalysisInput` with additional AST columns
77+
ready for downstream similarity analysis.
78+
79+
Args:
80+
data: DataFrame with code string columns.
81+
default_language: default language for parsing (e.g. "python", "cpp").
82+
Rows may override this via a ``language`` column.
83+
"""
84+
85+
REQUIRED_COLUMNS: list[str] = [
86+
"target_code_string",
87+
"model_generated_code_string",
88+
]
89+
90+
def __init__(
91+
self,
92+
data: pd.DataFrame,
93+
default_language: str = "python",
94+
) -> None:
95+
missing = set(self.REQUIRED_COLUMNS) - set(data.columns)
96+
if missing:
97+
raise ValueError(f"Missing required columns: {missing}")
98+
99+
self._data: pd.DataFrame = data.copy()
100+
self._default_language: str = default_language
101+
102+
# ------------------------------------------------------------------
103+
# Public static helpers
104+
# ------------------------------------------------------------------
105+
106+
@staticmethod
107+
def _ts_node_to_zss_node(ts_node: Any, filter_errors: bool = False) -> ZSSNode:
108+
"""Recursively convert a tree-sitter Node into a zss Node.
109+
110+
Each zss node is labelled with the tree-sitter node's ``type``
111+
string (e.g. ``"function_definition"``, ``"identifier"``).
112+
113+
Args:
114+
ts_node: tree-sitter Node to convert.
115+
filter_errors: when True, skip children that are ERROR or
116+
MISSING nodes (tree-sitter error-recovery artefacts).
117+
"""
118+
zss_node = ZSSNode(ts_node.type)
119+
for child in ts_node.children:
120+
if filter_errors and (child.is_error or child.is_missing):
121+
continue
122+
zss_node.addkid(
123+
PyTreeSitterAttack._ts_node_to_zss_node(child, filter_errors)
124+
)
125+
return zss_node
126+
127+
@staticmethod
128+
def parse_code(code: str, language: str = "python") -> tuple[ZSSNode, str]:
129+
"""Parse a single code snippet and return a zss Node tree.
130+
131+
Tree-sitter always produces a parse tree, even for malformed
132+
code. When syntax errors are present the parser inserts ERROR
133+
and MISSING nodes. This method filters those nodes out and
134+
returns the valid portion of the AST so that downstream
135+
similarity analysis can still operate on partially-correct code.
136+
137+
Args:
138+
code: source code string.
139+
language: language identifier (see ``_LANGUAGE_REGISTRY``).
140+
141+
Returns:
142+
Tuple of ``(root_node, parse_status)`` where *root_node* is
143+
the root :class:`zss.Node` and *parse_status* is
144+
``"success"`` when the code parsed without errors or
145+
``"partial"`` when error/missing nodes were filtered out.
146+
"""
147+
parser = _get_parser(language)
148+
tree = parser.parse(code.encode("utf-8"))
149+
if not tree.root_node.has_error:
150+
return (
151+
PyTreeSitterAttack._ts_node_to_zss_node(tree.root_node),
152+
"success",
153+
)
154+
return (
155+
PyTreeSitterAttack._ts_node_to_zss_node(tree.root_node, filter_errors=True),
156+
"partial",
157+
)
158+
159+
# ------------------------------------------------------------------
160+
# BaseAttack interface
161+
# ------------------------------------------------------------------
162+
163+
def run_attack(self) -> CodeSimilarityAnalysisInput:
164+
"""Parse every row's code strings into ASTs.
165+
166+
Adds the following columns to the DataFrame:
167+
- ``target_ast``: zss Node (always present)
168+
- ``generated_ast``: zss Node (always present)
169+
- ``target_parse_status``: ``"success"`` or ``"partial"``
170+
- ``generated_parse_status``: ``"success"`` or ``"partial"``
171+
172+
Returns:
173+
A :class:`CodeSimilarityAnalysisInput` wrapping the
174+
augmented DataFrame.
175+
"""
176+
df = self._data
177+
178+
has_language_col = "language" in df.columns
179+
180+
target_asts: list[ZSSNode] = []
181+
generated_asts: list[ZSSNode] = []
182+
target_parse_statuses: list[str] = []
183+
generated_parse_statuses: list[str] = []
184+
185+
for _idx, row in df.iterrows():
186+
lang = str(row["language"]) if has_language_col else self._default_language
187+
188+
t_ast, t_status = self.parse_code(str(row["target_code_string"]), lang)
189+
target_asts.append(t_ast)
190+
target_parse_statuses.append(t_status)
191+
192+
g_ast, g_status = self.parse_code(
193+
str(row["model_generated_code_string"]), lang
194+
)
195+
generated_asts.append(g_ast)
196+
generated_parse_statuses.append(g_status)
197+
198+
df["target_ast"] = target_asts
199+
df["generated_ast"] = generated_asts
200+
df["target_parse_status"] = target_parse_statuses
201+
df["generated_parse_status"] = generated_parse_statuses
202+
203+
return CodeSimilarityAnalysisInput(generation_df=df)

0 commit comments

Comments
 (0)