Skip to content

Commit 564bad1

Browse files
Matthieu Meeusmeta-codesync[bot]
authored andcommitted
Expanding tree edit similarity to different languages (#128)
Summary: Pull Request resolved: #128 Extend tree-edit similarity (`PyTreeSitterAttack`) to support more coding languages, in particular the ones already covered by CodeBLEU (Python, C, C++, Java, Rust, JavaScript, Go, Ruby, PHP, C#). Previously only Python and C++ were supported. This diff introduces: - Unified and extended grammar backend in `py_tree_sitter_attack.py`: Replaces the standalone tree-sitter-python and tree-sitter-cpp packages with the codebleu package's bundled `my-languages.so`. This provides a single grammar library covering all 10 languages, is consistent with the `CodeBleuAttack` module and simplifies `_get_parser()` to a 3-line function. Verified that the codebleu-bundled Python grammar produces identical trees (zero edit distance) to the previous implementation. - No changes to the analysis layer: `TreeEditDistanceNode` is already language-agnostic, it operates purely on zss Node trees. Reviewed By: mgrange1998 Differential Revision: D102700637 fbshipit-source-id: 3708fd9a784522e512c0ad87d6b3a1677540d1e3
1 parent e4c4436 commit 564bad1

3 files changed

Lines changed: 508 additions & 44 deletions

File tree

privacy_guard/analysis/tests/test_tree_edit_distance_node.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,28 +115,108 @@ def test_similarity_values(self) -> None:
115115
output = _run_e2e(df)
116116
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
117117

118+
def test_java_similarity(self) -> None:
119+
"""Identical Java code yields ~1.0; structurally similar code is high."""
120+
with self.subTest("identical"):
121+
code = "class Foo { int add(int a, int b) { return a + b; } }"
122+
df = pd.DataFrame(
123+
{
124+
"target_code_string": [code],
125+
"model_generated_code_string": [code],
126+
}
127+
)
128+
output = _run_e2e(df, default_language="java")
129+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
130+
131+
with self.subTest("similar"):
132+
df = pd.DataFrame(
133+
{
134+
"target_code_string": [
135+
"class Foo { int add(int a, int b) { return a + b; } }"
136+
],
137+
"model_generated_code_string": [
138+
"class Bar { int sum(int x, int y) { return x + y; } }"
139+
],
140+
}
141+
)
142+
output = _run_e2e(df, default_language="java")
143+
self.assertGreater(output.avg_similarity, 0.7)
144+
145+
def test_c_similarity(self) -> None:
146+
"""Identical C code yields ~1.0."""
147+
code = "int add(int a, int b) { return a + b; }"
148+
df = pd.DataFrame(
149+
{
150+
"target_code_string": [code],
151+
"model_generated_code_string": [code],
152+
}
153+
)
154+
output = _run_e2e(df, default_language="c")
155+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
156+
157+
def test_rust_similarity(self) -> None:
158+
"""Identical Rust code yields ~1.0."""
159+
code = "fn add(a: i32, b: i32) -> i32 { a + b }"
160+
df = pd.DataFrame(
161+
{
162+
"target_code_string": [code],
163+
"model_generated_code_string": [code],
164+
}
165+
)
166+
output = _run_e2e(df, default_language="rust")
167+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
168+
169+
def test_ruby_similarity(self) -> None:
170+
"""Identical Ruby code yields ~1.0."""
171+
code = "def add(a, b)\n a + b\nend"
172+
df = pd.DataFrame(
173+
{
174+
"target_code_string": [code],
175+
"model_generated_code_string": [code],
176+
}
177+
)
178+
output = _run_e2e(df, default_language="ruby")
179+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
180+
181+
def test_c_sharp_similarity(self) -> None:
182+
"""Identical C# code yields ~1.0."""
183+
code = "class Foo { int Add(int a, int b) { return a + b; } }"
184+
df = pd.DataFrame(
185+
{
186+
"target_code_string": [code],
187+
"model_generated_code_string": [code],
188+
}
189+
)
190+
output = _run_e2e(df, default_language="c_sharp")
191+
self.assertAlmostEqual(output.avg_similarity, 1.0, places=5)
192+
118193
def test_avg_similarity_by_language(self) -> None:
119-
"""Mixed Python+C++ input produces per-language averages."""
194+
"""Mixed multi-language input produces per-language averages."""
120195
df = pd.DataFrame(
121196
{
122197
"target_code_string": [
123198
"def foo():\n return 1\n",
124199
"int main() { return 0; }",
200+
"class Foo { int add(int a, int b) { return a + b; } }",
201+
"fn add(a: i32, b: i32) -> i32 { a + b }",
202+
"int add(int a, int b) { return a + b; }",
125203
],
126204
"model_generated_code_string": [
127205
"def foo():\n return 1\n",
128206
"int main() { return 0; }",
207+
"class Foo { int add(int a, int b) { return a + b; } }",
208+
"fn add(a: i32, b: i32) -> i32 { a + b }",
209+
"int add(int a, int b) { return a + b; }",
129210
],
130-
"language": ["python", "cpp"],
211+
"language": ["python", "cpp", "java", "rust", "c"],
131212
}
132213
)
133214
output = _run_e2e(df)
134215
assert output.avg_similarity_by_language is not None
135216
by_lang = output.avg_similarity_by_language
136-
self.assertIn("python", by_lang)
137-
self.assertIn("cpp", by_lang)
138-
self.assertAlmostEqual(by_lang["python"], 1.0, places=5)
139-
self.assertAlmostEqual(by_lang["cpp"], 1.0, places=5)
217+
for lang in ["python", "cpp", "java", "rust", "c"]:
218+
self.assertIn(lang, by_lang)
219+
self.assertAlmostEqual(by_lang[lang], 1.0, places=5)
140220

141221
def test_compute_similarity_static_method(self) -> None:
142222
"""TreeEditDistanceNode.compute_similarity works standalone."""

privacy_guard/attacks/code_similarity/py_tree_sitter_attack.py

Lines changed: 23 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414

1515
# pyre-strict
1616

17-
import ctypes
17+
import importlib.resources
1818
import logging
19-
from types import ModuleType
2019
from typing import Any
2120

2221
import pandas as pd
23-
import tree_sitter_cpp # @manual=fbsource//third-party/pypi/tree-sitter-cpp:tree-sitter-cpp
24-
import tree_sitter_python # @manual=fbsource//third-party/pypi/tree-sitter-python:tree-sitter-python
22+
from codebleu.codebleu import ( # @manual=fbsource//third-party/pypi/codebleu:codebleu
23+
AVAILABLE_LANGS,
24+
)
2525
from privacy_guard.analysis.code_similarity.code_similarity_analysis_input import (
2626
CodeSimilarityAnalysisInput,
2727
)
@@ -38,57 +38,42 @@
3838

3939
logger: logging.Logger = logging.getLogger(__name__)
4040

41-
# Maps user-facing language strings to tree-sitter language modules.
42-
_LANGUAGE_REGISTRY: dict[str, ModuleType] = {
43-
"python": tree_sitter_python,
44-
"py": tree_sitter_python,
45-
"c++": tree_sitter_cpp,
46-
"cpp": tree_sitter_cpp,
41+
# Aliases that map to canonical codebleu language names.
42+
_LANGUAGE_ALIASES: dict[str, str] = {
43+
"py": "python",
44+
"c++": "cpp",
45+
"js": "javascript",
4746
}
4847

49-
50-
def _language_from_capsule(ts_module: ModuleType) -> Language:
51-
"""Create a tree-sitter Language from a language module's capsule.
52-
53-
tree-sitter 0.20.4 expects ``Language(library_path, name)`` but the
54-
modern language packages (tree-sitter-python, tree-sitter-cpp) expose
55-
a ``language()`` function returning a PyCapsule. We extract the raw
56-
pointer from the capsule and construct a Language-compatible object.
57-
"""
58-
capsule = ts_module.language() # type: ignore[attr-defined]
59-
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p
60-
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
61-
language_id: ctypes.c_void_p = ctypes.pythonapi.PyCapsule_GetPointer(
62-
capsule, b"tree_sitter.Language"
63-
)
64-
lang = Language.__new__(Language)
65-
lang.language_id = language_id # type: ignore[attr-defined]
66-
return lang
48+
SUPPORTED_LANGS: list[str] = AVAILABLE_LANGS
6749

6850

6951
def _get_parser(language: str) -> Parser: # pyre-ignore[11]
7052
"""Create a tree-sitter Parser for the given language.
7153
54+
Uses the grammar bundled in the ``codebleu`` package (``my-languages.so``)
55+
which supports: java, javascript, c_sharp, php, c, cpp, python, go, ruby, rust.
56+
7257
Args:
73-
language: a key in _LANGUAGE_REGISTRY (e.g. "python", "cpp")
58+
language: a language name from ``codebleu.AVAILABLE_LANGS``,
59+
or an alias (e.g. "py", "c++", "js").
7460
7561
Returns:
7662
A configured tree-sitter Parser instance.
7763
7864
Raises:
7965
ValueError: if the language is not supported.
8066
"""
81-
lang_key = language.lower()
82-
ts_module = _LANGUAGE_REGISTRY.get(lang_key)
83-
if ts_module is None:
67+
lang_key = _LANGUAGE_ALIASES.get(language.lower(), language.lower())
68+
if lang_key not in AVAILABLE_LANGS:
8469
raise ValueError(
85-
f"Unsupported language '{language}'. "
86-
f"Supported: {sorted(_LANGUAGE_REGISTRY.keys())}"
70+
f"Unsupported language '{language}'. Supported: {sorted(AVAILABLE_LANGS)}"
8771
)
88-
89-
ts_language = _language_from_capsule(ts_module)
90-
parser = Parser() # pyre-ignore[16]
91-
# pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser`
72+
ts_language = Language(
73+
importlib.resources.files("codebleu") / "my-languages.so", lang_key
74+
)
75+
# pyre-ignore[16]: Module `tree_sitter` has no attribute `Parser`.
76+
parser = Parser()
9277
parser.set_language(ts_language)
9378
return parser
9479

0 commit comments

Comments
 (0)