Skip to content

Commit edc877e

Browse files
committed
feat: supprt function walk and cg
1 parent bd3d24e commit edc877e

File tree

7 files changed

+183
-21
lines changed

7 files changed

+183
-21
lines changed

scubatrace/file.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,18 @@ def function_by_line(self, line: int) -> Function | None:
326326
return func
327327
return None
328328

329+
def functions_by_name(self, name: str) -> list[Function]:
330+
"""
331+
The functions that have the specified name.
332+
333+
Args:
334+
name (str): The name of the function to check.
335+
336+
Returns:
337+
list[Function]: A list of functions that have the specified name.
338+
"""
339+
return [f for f in self.functions if f.name == name]
340+
329341
def statements_by_line(self, line: int) -> list[Statement]:
330342
"""
331343
The statements that are located on the specified line number.

scubatrace/function.py

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
3+
from collections import defaultdict, deque
44
from functools import cached_property
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Callable, Generator
66

77
import networkx as nx
88
from tree_sitter import Node
@@ -102,7 +102,9 @@ def first_statement(self) -> Statement | None:
102102
return self.statements[0]
103103

104104
def __str__(self) -> str:
105-
return self.signature
105+
return (
106+
f'"{self.name.replace("::", "--")} ({self.file.name}\\:{self.start_line})"'
107+
)
106108

107109
def set_joernid(self, joern_id: str):
108110
self.joern_id = joern_id
@@ -318,6 +320,106 @@ def callers(self) -> dict[Function, list[Statement]]:
318320
break
319321
return callers
320322

323+
def walk_backward(
324+
self,
325+
filter: Callable[[Statement], bool] | None = None,
326+
stop_by: Callable[[Statement], bool] | None = None,
327+
depth: int = -1,
328+
base: str = "call",
329+
) -> Generator[Function, None, None]:
330+
for caller in super().walk_backward(
331+
filter=filter,
332+
stop_by=stop_by,
333+
depth=depth,
334+
base=base,
335+
):
336+
assert isinstance(caller, Function)
337+
yield caller
338+
339+
def walk_forward(
340+
self,
341+
filter: Callable[[Statement], bool] | None = None,
342+
stop_by: Callable[[Statement], bool] | None = None,
343+
depth: int = -1,
344+
base: str = "call",
345+
) -> Generator[Function, None, None]:
346+
for callee in super().walk_forward(
347+
filter=filter,
348+
stop_by=stop_by,
349+
depth=depth,
350+
base=base,
351+
):
352+
assert isinstance(callee, Function)
353+
yield callee
354+
355+
def __build_callgraph(self, depth: int = -1) -> nx.MultiDiGraph:
356+
cg = nx.MultiDiGraph()
357+
cg.add_node(
358+
self,
359+
color="red",
360+
shape="box",
361+
style="rounded",
362+
)
363+
forward_depth = 2048 if depth == -1 else depth
364+
dq: deque[Function | FunctionDeclaration] = deque([self])
365+
visited: set[Function | FunctionDeclaration] = set([self])
366+
while len(dq) > 0 and forward_depth > 0:
367+
size = len(dq)
368+
for _ in range(size):
369+
caller = dq.popleft()
370+
if not isinstance(caller, Function):
371+
continue
372+
for callee, callsites in caller.callees.items():
373+
for callsite in callsites:
374+
cg.add_edge(
375+
caller,
376+
callee,
377+
key=callsite.signature,
378+
label=callsite.start_line,
379+
)
380+
if callee not in visited:
381+
visited.add(callee)
382+
dq.append(callee)
383+
forward_depth -= 1
384+
385+
backward_depth = 2048 if depth == -1 else depth
386+
dq = deque([self])
387+
visited = set([self])
388+
while len(dq) > 0 and backward_depth > 0:
389+
size = len(dq)
390+
for _ in range(size):
391+
callee = dq.popleft()
392+
if not isinstance(callee, Function):
393+
continue
394+
for caller, callsites in callee.callers.items():
395+
for callsite in callsites:
396+
cg.add_edge(
397+
caller,
398+
callee,
399+
key=callsite.signature,
400+
label=callsite.start_line,
401+
)
402+
if caller not in visited:
403+
visited.add(caller)
404+
dq.append(caller)
405+
backward_depth -= 1
406+
return cg
407+
408+
def export_callgraph(self, path: str, depth: int = -1) -> nx.MultiDiGraph:
409+
"""
410+
Exports the call graph of the function to a DOT file.
411+
412+
Args:
413+
path (str): The path to save the DOT file.
414+
depth (int): The depth of the call graph to export. -1 means no limit.
415+
416+
Returns:
417+
nx.MultiDiGraph: The call graph of the function.
418+
"""
419+
cg = self.__build_callgraph(depth)
420+
nx.nx_pydot.write_dot(cg, path)
421+
return cg
422+
321423
def slice_by_statements(
322424
self,
323425
statements: list[Statement],

scubatrace/statement.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,6 +514,11 @@ def walk_backward(
514514
nexts.extend(stats)
515515
case "control_dependent":
516516
nexts = cur_stat.pre_control_dependents
517+
case "call":
518+
from .function import Function
519+
520+
assert isinstance(cur_stat, Function)
521+
nexts = [caller for caller in cur_stat.callers.keys()]
517522
case _:
518523
nexts = cur_stat.pre_controls
519524
for pre in nexts:
@@ -539,7 +544,7 @@ def walk_forward(
539544
stop_by (Callable[[Statement], bool] | None): A function to stop the walking when it returns True.
540545
depth (int): The maximum depth to walk forward. Default is -1, which means no limit.
541546
base (str): The base type of the walk.
542-
Can be "control", "data_dependent", or "control_dependent".
547+
Can be "control", "data_dependent", "control_dependent", "call".
543548
544549
Yields:
545550
Statement: The statements that match the filter or all statements if no filter is provided.
@@ -564,11 +569,18 @@ def walk_forward(
564569
nexts.extend(stats)
565570
case "control_dependent":
566571
nexts = cur_stat.post_control_dependents
572+
case "call":
573+
from .function import Function
574+
575+
assert isinstance(cur_stat, Function)
576+
nexts = [caller for caller in cur_stat.callees.keys()]
567577
case _:
568578
nexts = cur_stat.post_controls
569579
for post in nexts:
570580
if post in visited:
571581
continue
582+
if not isinstance(post, Statement):
583+
continue
572584
visited.add(post)
573585
dq.appendleft(post)
574586
depth -= 1

tests/samples/c/main.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include "include/sub.h"
22
#include <stdio.h>
33

4+
int mul(int a, int b);
5+
46
int add(int a, int b)
57
{
6-
return a + sub(a, b);
8+
return a + sub(a, b) + mul(a, b);
79
}
810

911
int main(int argc, char** argv)

tests/test_file.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,32 @@ def test_file_functions(self):
3636
self.assertIsNotNone(func.name)
3737

3838
def test_file_function_by_line(self):
39-
function = self.file.function_by_line(5)
39+
function = self.file.function_by_line(8)
4040
self.assertIsNotNone(function)
4141
assert function is not None
4242
self.assertEqual(function.name, "add")
4343

44+
def test_file_function_by_name(self):
45+
function = self.file.functions_by_name("add")
46+
self.assertEqual(len(function), 1)
47+
function = self.file.functions_by_name("main")
48+
self.assertEqual(len(function), 1)
49+
4450
def test_file_statements(self):
4551
statements = self.file.statements
4652
self.assertGreater(len(statements), 0)
4753
for stmt in statements:
4854
self.assertIsNotNone(stmt.text)
4955

5056
def test_file_statement_by_line(self):
51-
statements = self.file.statements_by_line(14)
57+
statements = self.file.statements_by_line(16)
5258
self.assertGreater(len(statements), 0)
5359
self.assertEqual(statements[0].text, "int c = count + argc;")
5460

5561
self.assertEqual(len(self.file.statements_by_line(-1)), 0)
5662
self.assertGreater(len(self.file.statements_by_line(1)), 0)
5763

58-
self.assertEqual(self.file.statements_by_line(18)[0].text, "a -= 1;")
64+
self.assertEqual(self.file.statements_by_line(20)[0].text, "a -= 1;")
5965

6066
def test_file_identifiers(self):
6167
identifiers = self.file.identifiers
@@ -71,7 +77,7 @@ def test_file_variables(self):
7177

7278
def test_file_cfg(self):
7379
assert self.file is not None
74-
cfg = self.file.export_cfg_dot(f"{self.file.name}.dot")
80+
cfg = self.file.export_cfg_dot(f"{self.project_path}/{self.file.name}.dot")
7581
self.assertIsNotNone(cfg)
7682
self.assertGreater(len(cfg.nodes), 0)
7783
self.assertGreater(len(cfg.edges), 0)
@@ -84,7 +90,7 @@ def test_file_query(self):
8490
)@call
8591
"""
8692
query = self.file.query(query_str)
87-
target_lines = [6, 36]
93+
target_lines = [8, 38]
8894
self.assertEqual(len(query), len(target_lines))
8995
for stat in query:
9096
self.assertIn(stat.start_line, target_lines)

tests/test_function.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ def setUp(self):
1515
file = self.project.files.get("main.c")
1616
assert file is not None
1717
self.file = file
18-
function = self.file.function_by_line(11)
19-
assert function is not None
20-
self.function = function
18+
self.function = self.file.functions_by_name("main")[0]
2119

2220
def test_function_create(self):
2321
function = scubatrace.Function.create(self.function.node, self.function.parent)
@@ -31,7 +29,7 @@ def test_function_callees(self):
3129
self.assertIn("printf", [callee.name for callee in callees])
3230

3331
def test_function_callers(self):
34-
function = self.file.function_by_line(4)
32+
function = self.file.function_by_line(6)
3533
self.assertIsNotNone(function)
3634
assert function is not None
3735
callers = function.callers
@@ -43,7 +41,7 @@ def test_function_lines(self):
4341

4442
def test_function_parameter_lines(self):
4543
self.assertEqual(len(self.function.parameter_lines), 1)
46-
self.assertEqual(self.function.parameter_lines[0], 9)
44+
self.assertEqual(self.function.parameter_lines[0], 11)
4745

4846
def test_function_parameters(self):
4947
parameters = self.function.parameters
@@ -58,10 +56,40 @@ def test_function_variables(self):
5856
self.assertEqual(variables[len(variables) - 1].text, "count")
5957

6058
def test_function_export_cfg_dot(self):
61-
cfg = self.function.export_cfg_dot("cfg.dot", with_cdg=True, with_ddg=True)
59+
cfg = self.function.export_cfg_dot(
60+
f"{self.project_path}/{self.function.name}.dot",
61+
with_cdg=True,
62+
with_ddg=True,
63+
)
6264
self.assertIsNotNone(cfg)
6365

6466
def test_function_slicing_by_lines(self):
65-
stats = self.function.slice_by_lines([14])
66-
self.assertEqual(stats[0].start_line, 9)
67-
self.assertEqual(stats[len(stats) - 1].start_line, 38)
67+
stats = self.function.slice_by_lines([16])
68+
self.assertEqual(stats[0].start_line, 11)
69+
self.assertEqual(stats[len(stats) - 1].start_line, second=40)
70+
71+
def test_function_walk_backward(self):
72+
function = self.file.functions_by_name("add")[0]
73+
assert function is not None
74+
functions = list(function.walk_backward())
75+
self.assertEqual(len(functions), 3)
76+
functions_start_lines = sorted([f.start_line for f in functions])
77+
self.assertEqual(functions_start_lines, [6, 11, 51])
78+
79+
def test_function_walk_forward(self):
80+
functions = list(self.function.walk_forward())
81+
self.assertGreater(len(functions), 0)
82+
functions_start_lines = sorted([f.start_line for f in functions])
83+
self.assertIn(11, functions_start_lines)
84+
self.assertIn(6, functions_start_lines)
85+
self.assertIn(51, functions_start_lines)
86+
87+
def test_function_export_callgraph(self):
88+
function = self.file.functions_by_name("add")[0]
89+
callgraph = function.export_callgraph(
90+
f"{self.project_path}/{function.name}_callgraph.dot",
91+
depth=1,
92+
)
93+
self.assertIsNotNone(callgraph)
94+
self.assertGreater(len(callgraph.nodes), 0)
95+
self.assertGreater(len(callgraph.edges), 0)

tests/test_identifer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def setUp(self):
1515
self.file = self.project.files.get("main.c")
1616
assert self.file is not None
1717
self.assertGreater(len(self.file.statements), 0)
18-
self.statement = self.file.statements_by_line(14)[0]
18+
self.statement = self.file.statements_by_line(16)[0]
1919
self.assertGreater(len(self.statement.identifiers), 0)
2020
self.identifier = self.statement.identifiers[0]
2121

@@ -27,4 +27,4 @@ def test_identifier_post_data_dependents(self):
2727
dependents = self.identifier.post_data_dependents
2828
self.assertEqual(len(dependents), 5)
2929
dependents_lines = sorted([dep.start_line for dep in dependents])
30-
self.assertEqual(dependents_lines, [15, 15, 34, 36, 38])
30+
self.assertEqual(dependents_lines, [17, 17, 36, 38, 40])

0 commit comments

Comments
 (0)