Skip to content

Commit db3493a

Browse files
committed
fix: fix defer control flow in go
1 parent 4d954f3 commit db3493a

File tree

4 files changed

+190
-10
lines changed

4 files changed

+190
-10
lines changed

scubatrace/function.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ def statements(self) -> list[Statement]:
9696
return []
9797
return BlockStatement.build_statements(self.body_node, self)
9898

99+
@cached_property
100+
def first_statement(self) -> Statement | None:
101+
if len(self.statements) == 0:
102+
return None
103+
return self.statements[0]
104+
99105
def __str__(self) -> str:
100106
return self.signature
101107

@@ -190,7 +196,6 @@ def name_node(self) -> Node:
190196
return node
191197

192198
@property
193-
@abstractmethod
194199
def name(self) -> str:
195200
"""
196201
The name of the function.
@@ -200,7 +205,14 @@ def name(self) -> str:
200205
return name_node.text.decode()
201206

202207
@cached_property
203-
@abstractmethod
208+
def exits(self) -> list[Statement]:
209+
"""
210+
The exit statements of the function, such as return statements.
211+
"""
212+
exits = self.statements_by_types(self.language.EXIT_STATEMENTS, recursive=True)
213+
return exits
214+
215+
@cached_property
204216
def accessible_functions(self) -> list[Function]:
205217
funcs = []
206218
for file in self.file.imports:
@@ -279,7 +291,6 @@ def callees(self) -> dict[Function | FunctionDeclaration, list[Statement]]:
279291
return callees
280292

281293
@cached_property
282-
@abstractmethod
283294
def callers(self) -> dict[Function, list[Statement]]:
284295
"""
285296
The functions that call this function and their corresponding call sites.
@@ -467,12 +478,12 @@ def export_cfg_dot(
467478
)
468479
graph.add_node("edge", fontname="SF Pro Rounded, system-ui", arrowhead="vee")
469480
graph.add_node(self.signature, label=self.dot_text, color="red")
470-
if len(self.statements) == 0:
481+
if self.first_statement is None:
471482
graph.add_node(
472483
self.signature, label="No statements found", color="red", shape="box"
473484
)
474485
else:
475-
graph.add_edge(self.signature, self.statements[0].signature, label="CFG")
486+
graph.add_edge(self.signature, self.first_statement.signature, label="CFG")
476487
self.__build_cfg_graph(graph, self.statements)
477488

478489
if with_cdg:

scubatrace/go/function.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,18 @@
11
from __future__ import annotations
22

3+
from functools import cached_property
4+
35
from ..function import Function
4-
from .statement import GoBlockStatement
6+
from .statement import GoBlockStatement, Statement
7+
58

9+
class GoFunction(Function, GoBlockStatement):
10+
@cached_property
11+
def first_statement(self) -> Statement | None:
12+
for statement in self.statements:
13+
if statement.node_type != "defer_statement":
14+
return statement
615

7-
class GoFunction(Function, GoBlockStatement): ...
16+
@cached_property
17+
def defer_statements(self) -> list[Statement]:
18+
return self.statements_by_type("defer_statement")

scubatrace/go/statement.py

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,135 @@
11
from __future__ import annotations
22

3-
from ..statement import BlockStatement, SimpleStatement
3+
from functools import cached_property
44

5+
from ..statement import BlockStatement, SimpleStatement, Statement
56

6-
class GoSimpleStatement(SimpleStatement): ...
77

8+
class GoSimpleStatement(SimpleStatement):
9+
@cached_property
10+
def post_controls(self) -> list[Statement]:
11+
if self.node_type in self.language.EXIT_STATEMENTS:
12+
return []
13+
exits_statements = self.function.exits if self.function is not None else []
814

9-
class GoBlockStatement(BlockStatement): ...
15+
last_defer_statement = []
16+
from .function import GoFunction
17+
18+
if self.function is not None and isinstance(self.function, GoFunction):
19+
last_defer_statement = (
20+
self.function.defer_statements[-1:]
21+
if len(self.function.defer_statements) > 0
22+
else []
23+
)
24+
25+
if self.node_type == "defer_statement":
26+
assert isinstance(self.function, GoFunction)
27+
defter_index = self.function.defer_statements.index(self)
28+
return (
29+
[self.function.defer_statements[defter_index - 1]]
30+
if defter_index > 0
31+
else exits_statements
32+
)
33+
34+
if self.node_type in self.language.CONTINUE_STATEMENTS:
35+
loop_stat = self.ancestor_by_types(self.language.LOOP_STATEMENTS)
36+
return [loop_stat] if loop_stat else []
37+
if self.node_type in self.language.BREAK_STATEMENTS:
38+
loop_stat = self.ancestor_by_types(
39+
self.language.LOOP_STATEMENTS + self.language.SWITCH_STATEMENTS
40+
)
41+
preorder_successor = loop_stat.preorder_successor if loop_stat else None
42+
return [preorder_successor] if preorder_successor else []
43+
if self.node_type in self.language.GOTO_STATEMENTS:
44+
function = self.function
45+
if function is not None:
46+
label_name_node = self.node.child_by_field_name("label")
47+
assert label_name_node is not None and label_name_node.text is not None
48+
label_name = label_name_node.text.decode()
49+
label_stat = function.query_oneshot(
50+
self.language.query_goto_label(label_name)
51+
)
52+
return [label_stat] if label_stat else []
53+
54+
if self.parent.node_type in self.language.LOOP_STATEMENTS:
55+
# while () {last_statement;}
56+
loop_stat = self.ancestor_by_types(self.language.LOOP_STATEMENTS)
57+
is_last_statement = self.next_sibling is None
58+
if is_last_statement:
59+
return [loop_stat] if loop_stat else []
60+
if self.parent.node_type in self.language.IF_STATEMENTS:
61+
# if () {last_statement;} else { ...}
62+
consequences = self.parent.statements_by_field_name("consequence")
63+
if self in consequences:
64+
is_last_consequences = consequences.index(self) == len(consequences) - 1
65+
if is_last_consequences:
66+
return (
67+
[self.right_uncle_ancestor] if self.right_uncle_ancestor else []
68+
)
69+
70+
preorder_successor = self.preorder_successor
71+
while (
72+
preorder_successor is not None
73+
and preorder_successor.node_type == "defer_statement"
74+
):
75+
preorder_successor = preorder_successor.preorder_successor
76+
if preorder_successor is None:
77+
return last_defer_statement
78+
if preorder_successor.node_type in self.language.EXIT_STATEMENTS:
79+
return last_defer_statement
80+
else:
81+
return [preorder_successor]
82+
83+
84+
class GoBlockStatement(BlockStatement):
85+
@cached_property
86+
def post_controls(self) -> list[Statement]:
87+
exits_statements = self.function.exits if self.function is not None else []
88+
last_defer_statement = []
89+
from .function import GoFunction
90+
91+
if self.function is not None and isinstance(self.function, GoFunction):
92+
last_defer_statement = (
93+
self.function.defer_statements[-1:]
94+
if len(self.function.defer_statements) > 0
95+
else []
96+
)
97+
98+
if self.node_type in self.language.IF_STATEMENTS:
99+
consequences = self.statements_by_field_name("consequence")
100+
alternatives = self.statements_by_field_name("alternative")
101+
nexts = []
102+
if len(consequences) > 0:
103+
nexts.append(consequences[0])
104+
if len(alternatives) > 0:
105+
nexts.append(alternatives[0])
106+
elif (
107+
self.preorder_successor is not None
108+
and self.preorder_successor not in exits_statements
109+
):
110+
nexts.append(self.preorder_successor)
111+
else:
112+
nexts.extend(last_defer_statement)
113+
return nexts
114+
if self.node_type in self.language.SWITCH_STATEMENTS:
115+
if len(self.statements) > 0:
116+
return [self.statements[0]]
117+
if self.parent.node_type in self.language.SWITCH_STATEMENTS:
118+
if self.text.strip().startswith("default:") and len(self.statements) > 0:
119+
return [self.statements[0]]
120+
if self.parent.node_type in self.language.LOOP_STATEMENTS:
121+
# while () {last_statement;}
122+
loop_stat = self.ancestor_by_types(self.language.LOOP_STATEMENTS)
123+
is_last_statement = self.next_sibling is None
124+
if is_last_statement:
125+
return [loop_stat] if loop_stat else []
126+
127+
nexts = [self.statements[0]] if len(self.statements) > 0 else []
128+
if self.preorder_successor is not None:
129+
nexts.append(self.preorder_successor)
130+
131+
if len(last_defer_statement) > 0:
132+
nexts = [stat for stat in nexts if stat not in exits_statements]
133+
if len(nexts) == 0:
134+
return last_defer_statement
135+
return nexts

tests/samples/go/main.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
)
7+
8+
func main() {
9+
defer fmt.Println("world")
10+
defer fmt.Println("hello")
11+
flag.Parse()
12+
defer fmt.Println("Flags parsed successfully")
13+
defer func() {
14+
if r := recover(); r != nil {
15+
fmt.Println("Recovered from panic:", r)
16+
}
17+
}()
18+
19+
var (
20+
testFlag1 = flag.String("testFlag1", "default1", "Description for testFlag1")
21+
testFlag2 = flag.String("testFlag2", "default2", "Description for testFlag2")
22+
)
23+
if *testFlag1 != "default1" || *testFlag2 != "default2" {
24+
panic("Flag values do not match expected defaults")
25+
}
26+
if *testFlag1 == "default1" && *testFlag2 == "default2" {
27+
println("Flag values match expected defaults")
28+
} else {
29+
panic("Flag values do not match expected defaults")
30+
}
31+
println("All tests passed successfully")
32+
}

0 commit comments

Comments
 (0)