Skip to content

Commit 04e0eb6

Browse files
authored
Merge pull request #1 from trailofbits/fegge/miden-assembly
Add support for Miden assembly
2 parents fbbe460 + 8829277 commit 04e0eb6

File tree

14 files changed

+25942
-600
lines changed

14 files changed

+25942
-600
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ A language-specific parser walks the directory, parses each file into a tree-sit
5858
| Circom | `.circom` | templates, functions, signals, components |
5959
| Haskell | `.hs` | functions, data types, type classes, instances |
6060
| Erlang | `.erl` | functions, records, behaviours, modules |
61+
| Miden Assembly | `.masm` | procedures, entrypoints, constants, invocations |
6162

6263
```mermaid
6364
flowchart TD
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Miden assembly language parser for Trailmark."""
2+
3+
from trailmark.parsers.masm.parser import MasmParser
4+
5+
# MasmParser (imported above from the parser module) is the only name this
# package lists as public.
__all__ = ["MasmParser"]
Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
"""Miden assembly language parser using a vendored tree-sitter grammar."""
2+
3+
from __future__ import annotations
4+
5+
from pathlib import Path
6+
7+
from tree_sitter import Language, Node, Parser
8+
9+
from trailmark.models.edges import CodeEdge, EdgeConfidence, EdgeKind
10+
from trailmark.models.graph import CodeGraph
11+
from trailmark.models.nodes import (
12+
BranchInfo,
13+
CodeUnit,
14+
NodeKind,
15+
Parameter,
16+
)
17+
from trailmark.parsers._common import (
18+
add_contains_edge,
19+
add_module_node,
20+
compute_complexity,
21+
make_location,
22+
module_id_from_path,
23+
node_text,
24+
parse_directory,
25+
)
26+
27+
# Node types that are recorded as branches while walking a procedure body.
_BRANCH_NODE_TYPES = frozenset(("if", "while", "repeat"))

# File extensions handed to the directory walker for this language.
_EXTENSIONS = (".masm",)
36+
37+
38+
class MasmParser:
    """Parser that turns Miden assembly source files into a CodeGraph.

    A single tree-sitter parser is created in the constructor from the
    vendored masm grammar and reused for every file that is parsed.
    """

    def __init__(self) -> None:
        """Create the tree-sitter parser for the vendored masm grammar."""
        # Function-scope import: the vendored grammar module is imported
        # when a parser instance is constructed.
        from trailmark.tree_sitter_custom.masm import language as masm_language

        self._parser = Parser(Language(masm_language()))

    @property
    def language(self) -> str:
        """Language identifier reported by this parser."""
        return "masm"

    def parse_file(self, file_path: str) -> CodeGraph:
        """Parse one Miden assembly file and return the resulting graph.

        The file is read as raw bytes, parsed with tree-sitter, and the
        tree is walked from its root node to populate a new CodeGraph.
        """
        tree = self._parser.parse(Path(file_path).read_bytes())
        graph = CodeGraph(language="masm", root_path=file_path)
        _visit_module(
            tree.root_node,
            file_path,
            module_id_from_path(file_path),
            graph,
        )
        return graph

    def parse_directory(self, dir_path: str) -> CodeGraph:
        """Parse the .masm files found under dir_path into one graph."""
        # Inside a method body this name resolves to the module-level helper
        # imported from trailmark.parsers._common, not to this method.
        return parse_directory(self.parse_file, "masm", dir_path, _EXTENSIONS)
70+
71+
72+
def _visit_module(
    root: Node,
    file_path: str,
    module_id: str,
    graph: CodeGraph,
) -> None:
    """Register the module node, then dispatch each direct child of root."""
    add_module_node(root, file_path, module_id, graph)
    for top_level_node in root.children:
        _visit_top_level(top_level_node, file_path, module_id, graph)
82+
83+
84+
def _visit_top_level(
    child: Node,
    file_path: str,
    module_id: str,
    graph: CodeGraph,
) -> None:
    """Route one top-level node to the extractor for its node type.

    Node types other than procedure, entrypoint, import, reexport and
    constant are ignored.
    """
    kind = child.type
    # Imports and re-exports are both recorded as dependencies.
    if kind in ("import", "reexport"):
        _extract_import(child, graph)
        return
    if kind == "procedure":
        _extract_procedure(child, file_path, module_id, graph)
        return
    if kind == "entrypoint":
        _extract_entrypoint(child, file_path, module_id, graph)
        return
    if kind == "constant":
        _extract_constant(child, file_path, module_id, graph)
101+
102+
103+
def _extract_procedure(
    node: Node,
    file_path: str,
    module_id: str,
    graph: CodeGraph,
) -> None:
    """Add a FUNCTION node, a contains edge and call edges for a procedure.

    Nothing is recorded when the node has no "name" field.
    """
    name_node = node.child_by_field_name("name")
    if name_node is None:
        return

    proc_name = _ident_text(name_node)
    proc_id = f"{module_id}:{proc_name}"
    branches, calls = _collect_body(node.child_by_field_name("body"), file_path)

    graph.nodes[proc_id] = CodeUnit(
        id=proc_id,
        name=proc_name,
        kind=NodeKind.FUNCTION,
        location=make_location(node, file_path),
        # The @locals(N) annotation is surfaced as a synthetic parameter.
        parameters=tuple(_extract_num_locals(node)),
        cyclomatic_complexity=compute_complexity(branches),
        branches=tuple(branches),
        docstring=_prepend_visibility(
            _extract_visibility(node),
            _extract_docstring(node),
        ),
    )
    add_contains_edge(graph, module_id, proc_id)
    _add_call_edges(calls, proc_id, module_id, file_path, graph)
138+
139+
140+
def _extract_entrypoint(
    node: Node,
    file_path: str,
    module_id: str,
    graph: CodeGraph,
) -> None:
    """Record the begin...end entrypoint as a synthetic function "begin"."""
    entry_id = f"{module_id}:begin"
    branches, calls = _collect_body(node.child_by_field_name("body"), file_path)

    graph.nodes[entry_id] = CodeUnit(
        id=entry_id,
        name="begin",
        kind=NodeKind.FUNCTION,
        location=make_location(node, file_path),
        cyclomatic_complexity=compute_complexity(branches),
        branches=tuple(branches),
        docstring=_extract_docstring(node),
    )
    add_contains_edge(graph, module_id, entry_id)
    _add_call_edges(calls, entry_id, module_id, file_path, graph)
166+
167+
168+
def _extract_constant(
    node: Node,
    file_path: str,
    module_id: str,
    graph: CodeGraph,
) -> None:
    """Record a constant definition as a FUNCTION node in the graph.

    Constants are stored with kind FUNCTION and a cyclomatic complexity
    of 1 so that they show up in the code graph. Nothing is recorded when
    the node has no "name" field.
    """
    name_node = node.child_by_field_name("name")
    if name_node is None:
        return

    const_name = node_text(name_node).strip()
    const_id = f"{module_id}:{const_name}"

    graph.nodes[const_id] = CodeUnit(
        id=const_id,
        name=const_name,
        kind=NodeKind.FUNCTION,
        location=make_location(node, file_path),
        cyclomatic_complexity=1,
        docstring=_extract_docstring(node),
    )
    add_contains_edge(graph, module_id, const_id)
196+
197+
198+
def _extract_import(node: Node, graph: CodeGraph) -> None:
    """Record the root segment of an import path as a graph dependency.

    Nothing is recorded when the node has no "path" field, when the path
    has no non-empty segment, or when the dependency is already listed.
    """
    path_node = node.child_by_field_name("path")
    if path_node is None:
        return

    # Paths look like "std::math::u64" or "::foo::bar"; the first non-empty
    # segment is the dependency root.
    root_segment = next(
        (part for part in node_text(path_node).strip().split("::") if part),
        None,
    )
    if root_segment is None or root_segment in graph.dependencies:
        return
    graph.dependencies.append(root_segment)
210+
211+
212+
def _extract_visibility(node: Node) -> str:
    """Return the text of the procedure's "visibility" field.

    Falls back to "proc" when the field is absent.
    """
    vis_node = node.child_by_field_name("visibility")
    return "proc" if vis_node is None else node_text(vis_node).strip()
218+
219+
220+
def _extract_num_locals(node: Node) -> list[Parameter]:
    """Return the @locals(N) annotation as a synthetic num_locals parameter.

    The result is a one-element list whose parameter carries the decimal
    text as its default, or an empty list when no such annotation with a
    non-empty decimal argument is found.
    """
    for index, child in enumerate(node.children):
        if node.field_name_for_child(index) != "annotations":
            continue
        if child.type != "annotation":
            continue

        name_child = child.child_by_field_name("name")
        if name_child is None or node_text(name_child).strip() != "locals":
            continue

        value_child = child.child_by_field_name("value")
        if value_child is None:
            continue

        # The annotation arguments hold the decimal value.
        for arg in value_child.children:
            if arg.type != "decimal":
                continue
            text = node_text(arg).strip()
            if text:
                return [Parameter(name="num_locals", default=text)]
    return []
235+
236+
237+
def _extract_docstring(node: Node) -> str | None:
    """Return the cleaned text of the node's "docs" field, if any.

    Each line is stripped and a leading "#!" marker is removed. Returns
    None when the field is absent or nothing remains after cleaning.
    """
    docs_node = node.child_by_field_name("docs")
    if docs_node is None:
        return None

    raw = node_text(docs_node).strip()
    if not raw:
        return None

    def _strip_marker(line: str) -> str:
        """Strip whitespace and a leading #! marker from one doc line."""
        line = line.strip()
        return line[2:].strip() if line.startswith("#!") else line

    cleaned = "\n".join(_strip_marker(line) for line in raw.splitlines()).strip()
    return cleaned or None
254+
255+
256+
def _prepend_visibility(visibility: str, docstring: str | None) -> str | None:
    """Tag the docstring of an exported procedure with "[export]".

    Args:
        visibility: Visibility keyword text of the procedure.
        docstring: Cleaned doc comment of the procedure, or None.

    Returns:
        "[export] <docstring>" for an exported procedure with a docstring,
        "[export]" for an exported procedure without one, and the docstring
        unchanged (possibly None) for any other visibility.
    """
    # Both spellings of the export keyword count as exported: "pub" and
    # "export". Previously only "pub" was recognised, so a procedure whose
    # visibility text is "export" was never tagged.
    if visibility not in ("pub", "export"):
        return docstring
    prefix = "[export]"
    return f"{prefix} {docstring}" if docstring else prefix
264+
265+
266+
def _collect_body(
    body: Node | None,
    file_path: str,
) -> tuple[list[BranchInfo], list[tuple[str, Node]]]:
    """Gather the branches and invoke calls found in a body node.

    Returns a (branches, calls) pair; both lists are empty when body is
    None.
    """
    branches: list[BranchInfo] = []
    calls: list[tuple[str, Node]] = []
    if body is None:
        return branches, calls
    _walk_body(body, file_path, branches, calls)
    return branches, calls
276+
277+
278+
def _walk_body(
    node: Node,
    file_path: str,
    branches: list[BranchInfo],
    calls: list[tuple[str, Node]],
) -> None:
    """Traverse the descendants of node, appending branches and invokes.

    The traversal is iterative with an explicit stack. Children are pushed
    in reverse so that they are popped in their original order. Results
    are appended to the caller-supplied branches and calls lists.
    """
    pending: list[Node] = []
    pending.extend(reversed(node.children))
    while pending:
        current = pending.pop()
        node_type = current.type

        # A branch-typed node with no children is skipped.
        if node_type in _BRANCH_NODE_TYPES and current.child_count > 0:
            condition = _branch_condition(current)
            location = make_location(current, file_path)
            branches.append(BranchInfo(location=location, condition=condition))

        if node_type == "invoke":
            target = _invoke_target(current)
            if target:
                calls.append((target, current))

        pending.extend(reversed(current.children))
301+
302+
303+
def _branch_condition(node: Node) -> str:
    """Return a short label describing a control-flow node.

    "if" and "while" nodes map to "if.true" and "while.true". A "repeat"
    node maps to "repeat.<count>" when it has a "count" field and to
    "repeat" otherwise. Any other node type is returned as-is.
    """
    kind = node.type
    if kind == "repeat":
        count = node.child_by_field_name("count")
        if count is None:
            return "repeat"
        return f"repeat.{node_text(count).strip()}"
    fixed_labels = {"if": "if.true", "while": "while.true"}
    return fixed_labels.get(kind, kind)
315+
316+
317+
def _invoke_target(node: Node) -> str:
    """Return the stripped text of an invoke node's "path" field.

    Invoke nodes have the form exec.path, call.path, syscall.path or
    procref.path. An empty string is returned when the field is absent.
    """
    path_node = node.child_by_field_name("path")
    return node_text(path_node).strip() if path_node is not None else ""
326+
327+
328+
def _ident_text(node: Node) -> str:
    """Return the identifier text of node without surrounding quotes.

    The text is stripped of whitespace; when it both starts and ends with
    a double quote, the first and last characters are removed.
    """
    text = node_text(node).strip()
    is_quoted = text.startswith('"') and text.endswith('"')
    return text[1:-1] if is_quoted else text
335+
336+
337+
def _add_call_edges(
    calls: list[tuple[str, Node]],
    source_id: str,
    module_id: str,
    file_path: str,
    graph: CodeGraph,
) -> None:
    """Append one CALLS edge to the graph per collected invoke expression.

    Each call name is a path such as "foo::bar" or a bare "my_proc". The
    last non-empty segment is used as the target name within module_id.
    Bare names produce CERTAIN edges, multi-segment paths INFERRED edges,
    and names without any non-empty segment are skipped.
    """
    for call_name, call_node in calls:
        segments = [part for part in call_name.split("::") if part]
        if not segments:
            continue

        if len(segments) == 1:
            confidence = EdgeConfidence.CERTAIN
        else:
            confidence = EdgeConfidence.INFERRED

        edge = CodeEdge(
            source_id=source_id,
            target_id=f"{module_id}:{segments[-1]}",
            kind=EdgeKind.CALLS,
            confidence=confidence,
            location=make_location(call_node, file_path),
        )
        graph.edges.append(edge)

src/trailmark/query/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"circom": ("trailmark.parsers.circom", "CircomParser"),
3333
"haskell": ("trailmark.parsers.haskell", "HaskellParser"),
3434
"erlang": ("trailmark.parsers.erlang", "ErlangParser"),
35+
"masm": ("trailmark.parsers.masm", "MasmParser"),
3536
}
3637

3738
_SUPPORTED_LANGUAGES = frozenset(_PARSER_MAP.keys())

0 commit comments

Comments
 (0)