Skip to content

Commit 89b8126

Browse files
authored
Merge pull request #23 from smythi93/dev
enhance DefUseFactory with thread-safe scope management and improve v…
2 parents (77b1f37 + 3076bcc), commit 89b8126

File tree

3 files changed: +150 additions, -78 deletions.

src/sflkit/analysis/factory.py

Lines changed: 108 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def get_analysis(
8585
if key not in self.objects:
8686
self.objects[key] = Line(event)
8787
return [self.objects[key]]
88-
return None
88+
return []
8989

9090

9191
class BranchFactory(AnalysisFactory):
@@ -117,7 +117,7 @@ def get_analysis(
117117
return [self.objects[key], self.objects[else_key]]
118118
return [self.objects[key]]
119119

120-
return None
120+
return []
121121

122122

123123
class FunctionFactory(AnalysisFactory):
@@ -130,7 +130,7 @@ def get_analysis(
130130
if key not in self.objects:
131131
self.objects[key] = Function(event)
132132
return [self.objects[key]]
133-
return None
133+
return []
134134

135135

136136
class LoopFactory(AnalysisFactory):
@@ -170,80 +170,106 @@ def get_analysis(
170170
elif event.event_type == EventType.LOOP_END:
171171
return self.objects[key][:]
172172
return list()
173-
return None
173+
return []
174174

175175

176176
class DefUseFactory(AnalysisFactory):
177+
class DefScope:
178+
def __init__(self, parent: "DefUseFactory.DefScope" = None):
179+
self.parent: "DefUseFactory.DefScope" = parent
180+
self.def_events: dict[tuple[str, int], DefEvent] = dict()
181+
182+
def enter(self):
183+
return DefUseFactory.DefScope(self)
184+
185+
def exit(self):
186+
return self.parent or self
187+
188+
def add(self, var_name: str, var_id: int, def_event: DefEvent):
189+
self.def_events[(var_name, var_id)] = def_event
190+
191+
def event(self, var_name: str, var_id: int) -> DefEvent:
192+
current = self
193+
while current is not None:
194+
if (var_name, var_id) in current.def_events:
195+
return current.def_events[(var_name, var_id)]
196+
current = current.parent
197+
return None
198+
177199
def __init__(self):
178200
super().__init__()
179-
self.id_to_def: dict[EventFile, dict[tuple[str, int, int, int], DefEvent]] = (
180-
dict()
181-
)
182-
self.def_stack: dict[
183-
EventFile, dict[int, dict[tuple[str, int], list[tuple[int, DefEvent]]]]
201+
self.id_to_def: dict[EventFile, dict[tuple[str, int], DefEvent]] = dict()
202+
self.id_to_def_thread: dict[
203+
EventFile, dict[int, dict[tuple[str, int], DefEvent]]
184204
] = dict()
205+
self.def_stack: dict[EventFile, dict[int, DefUseFactory.DefScope]] = dict()
206+
207+
def reset(self, event_file: EventFile):
208+
if event_file in self.id_to_def:
209+
del self.id_to_def[event_file]
210+
if event_file in self.id_to_def_thread:
211+
del self.id_to_def_thread[event_file]
212+
if event_file in self.def_stack:
213+
del self.def_stack[event_file]
185214

186215
def _find_def_event(
187216
self,
188217
event_file: EventFile,
218+
thread_id: int,
189219
var_name: str,
190220
var_id: int,
191-
scope_id: int,
192-
thread_id: int,
193221
) -> DefEvent:
194-
# Strategy 1: Exact match in current thread and scope
195-
exact_key = (var_name, var_id, scope_id, thread_id)
196-
if exact_key in self.id_to_def.get(event_file, {}):
197-
return self.id_to_def[event_file][exact_key]
198-
199-
# Strategy 2: Look up the scope stack in the current thread
200-
thread_stack = self.def_stack.get(event_file, {}).get(thread_id, {})
201-
var_key = (var_name, var_id)
202-
if var_key in thread_stack and thread_stack[var_key]:
203-
# Return the most recent (top of stack) DEF event
204-
return thread_stack[var_key][-1][1]
205-
206-
# Strategy 3: Look for the same var_id in other threads (shared objects)
207-
# This handles cross-thread variable sharing
208-
for tid, thread_data in self.def_stack.get(event_file, {}).items():
209-
if tid != thread_id and var_key in thread_data and thread_data[var_key]:
210-
# Return the most recent DEF from another thread
211-
return thread_data[var_key][-1][1]
212-
213-
return None
222+
# Strategy 1: Check in scope stack (from innermost to outermost)
223+
def_event = None
224+
if event_file in self.def_stack and thread_id in self.def_stack[event_file]:
225+
def_event = self.def_stack[event_file][thread_id].event(var_name, var_id)
226+
227+
# Strategy 2: Look up in the thread-specific DEF stack
228+
if def_event is None:
229+
if (
230+
event_file in self.id_to_def_thread
231+
and thread_id in self.id_to_def_thread[event_file]
232+
):
233+
def_event = self.id_to_def_thread[event_file][thread_id].get(
234+
(var_name, var_id), None
235+
)
236+
237+
# Strategy 3: Look up in the global DEF stack (other threads)
238+
if def_event is None:
239+
if event_file in self.id_to_def:
240+
def_event = self.id_to_def[event_file].get((var_name, var_id), None)
241+
242+
return def_event
214243

215244
def get_analysis(
216245
self, event, event_file: EventFile, scope: Scope = None
217246
) -> List[AnalysisObject]:
218247
thread_id = event.thread_id
219-
scope_id = scope.id if scope else 0
220248

221249
if event.event_type == EventType.DEF:
222-
var_key = (event.var, event.var_id)
223-
full_key = (event.var, event.var_id, scope_id, thread_id)
224-
225-
with self._lock:
226-
# Initialize structures if needed
227-
if event_file not in self.id_to_def:
228-
self.id_to_def[event_file] = dict()
229-
if event_file not in self.def_stack:
230-
self.def_stack[event_file] = dict()
231-
if thread_id not in self.def_stack[event_file]:
232-
self.def_stack[event_file][thread_id] = dict()
233-
if var_key not in self.def_stack[event_file][thread_id]:
234-
self.def_stack[event_file][thread_id][var_key] = []
235-
236-
# Store the DEF event
237-
self.id_to_def[event_file][full_key] = event
238-
239-
# Add to stack for this thread
240-
self.def_stack[event_file][thread_id][var_key].append((scope_id, event))
250+
key = (event.var, event.var_id)
251+
252+
# Initialize structures if needed
253+
if event_file not in self.id_to_def:
254+
self.id_to_def[event_file] = dict()
255+
if event_file not in self.id_to_def_thread:
256+
self.id_to_def_thread[event_file] = dict()
257+
if thread_id not in self.id_to_def_thread[event_file]:
258+
self.id_to_def_thread[event_file][thread_id] = dict()
259+
if event_file not in self.def_stack:
260+
self.def_stack[event_file] = dict()
261+
if thread_id not in self.def_stack[event_file]:
262+
self.def_stack[event_file][thread_id] = DefUseFactory.DefScope()
263+
264+
# Store the DEF event
265+
self.id_to_def[event_file][key] = event
266+
self.id_to_def_thread[event_file][thread_id][key] = event
267+
self.def_stack[event_file][thread_id].add(event.var, event.var_id, event)
241268

242269
elif event.event_type == EventType.USE:
243-
with self._lock:
244-
def_event = self._find_def_event(
245-
event_file, event.var, event.var_id, scope_id, thread_id
246-
)
270+
def_event = self._find_def_event(
271+
event_file, thread_id, event.var, event.var_id
272+
)
247273

248274
if def_event:
249275
key = (
@@ -258,7 +284,24 @@ def get_analysis(
258284
if key not in self.objects:
259285
self.objects[key] = DefUse(def_event, event)
260286
return [self.objects[key]]
261-
return None
287+
elif event.event_type == EventType.FUNCTION_ENTER:
288+
if event_file not in self.def_stack:
289+
self.def_stack[event_file] = dict()
290+
if thread_id not in self.def_stack[event_file]:
291+
self.def_stack[event_file][thread_id] = DefUseFactory.DefScope()
292+
else:
293+
self.def_stack[event_file][thread_id] = self.def_stack[event_file][
294+
thread_id
295+
].enter()
296+
elif (
297+
event.event_type == EventType.FUNCTION_EXIT
298+
or event.event_type == EventType.FUNCTION_ERROR
299+
):
300+
if event_file in self.def_stack and thread_id in self.def_stack[event_file]:
301+
self.def_stack[event_file][thread_id] = self.def_stack[event_file][
302+
thread_id
303+
].exit()
304+
return []
262305

263306

264307
class ConditionFactory(AnalysisFactory):
@@ -359,7 +402,7 @@ def get_analysis(
359402
)
360403
objects.append(self.objects[key])
361404
return objects
362-
return None
405+
return []
363406

364407

365408
class VariableFactory(ComparisonFactory):
@@ -386,7 +429,7 @@ def get_analysis(
386429
self.objects[key] = VariablePredicate(event, comp)
387430
objects.append(self.objects[key])
388431
return objects
389-
return None
432+
return []
390433

391434

392435
class ReturnFactory(ComparisonFactory):
@@ -456,7 +499,7 @@ def get_analysis(
456499
)
457500
objects.append(self.objects[key])
458501
return objects
459-
return None
502+
return []
460503

461504

462505
class ConstantCompFactory(AnalysisFactory):
@@ -483,7 +526,7 @@ def get_analysis(
483526
self.objects[key] = self.class_(event)
484527
objects.append(self.objects[key])
485528
return objects
486-
return None
529+
return []
487530

488531

489532
class NoneFactory(ConstantCompFactory):
@@ -521,7 +564,7 @@ def get_analysis(
521564
# noinspection PyArgumentList
522565
self.objects[key] = self.class_(event)
523566
return [self.objects[key]]
524-
return None
567+
return []
525568

526569

527570
class IsAsciiFactory(PredicateFunctionFactory):
@@ -572,7 +615,7 @@ def get_analysis(
572615
Length(event, Length.evaluate_length_more)
573616
)
574617
return self.objects[key][:]
575-
return None
618+
return []
576619

577620

578621
class FunctionErrorFactory(AnalysisFactory):
@@ -584,11 +627,9 @@ def get_analysis(
584627
self, event, event_file: EventFile, scope: Scope = None
585628
) -> List[AnalysisObject]:
586629
if event.event_type == EventType.FUNCTION_ENTER:
587-
with self._lock:
588-
self.function_mapping[event.function_id] = event.line
630+
self.function_mapping[event.function_id] = event.line
589631
if event.event_type in (EventType.FUNCTION_ERROR, EventType.FUNCTION_EXIT):
590-
with self._lock:
591-
line = self.function_mapping.get(event.function_id, event.line)
632+
line = self.function_mapping.get(event.function_id, event.line)
592633
key = (
593634
FunctionErrorPredicate.analysis_type(),
594635
event.file,
@@ -601,7 +642,7 @@ def get_analysis(
601642
event.file, line, event.function
602643
)
603644
return [self.objects[key]]
604-
return None
645+
return []
605646

606647

607648
analysis_factory_mapping = {

src/sflkit/model/scope.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -51,23 +51,31 @@ def exit(self):
5151
else:
5252
return self
5353

54+
def __contains__(self, var: str) -> bool:
55+
current = self
56+
while current is not None:
57+
if var in current.variables:
58+
return True
59+
current = current.parent
60+
return False
61+
5462
def value(self, var: str) -> Var:
55-
if var in self.variables:
56-
return self.variables[var].value
57-
elif self.parent is not None:
58-
return self.parent.value(var)
59-
else:
60-
return None
63+
current = self
64+
while current is not None:
65+
if var in current.variables:
66+
return current.variables[var].value
67+
current = current.parent
68+
return None
6169

6270
def add(self, var, value, type_, id_: int = None):
6371
self.variables[var] = Var(var, value, type_, id_)
6472

6573
def get_all_vars_dict(self):
66-
if self.parent is not None:
67-
variables = self.parent.get_all_vars_dict()
68-
else:
69-
variables = dict()
70-
variables.update(self.variables)
74+
current = self
75+
variables = dict()
76+
while current is not None:
77+
variables = {**current.variables, **variables}
78+
current = current.parent
7179
return variables
7280

7381
def get_all_vars(self) -> List[Var]:

tests/test_scope.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,3 +18,26 @@ def test_var_eq(self):
1818
def test_scope_without_parent(self):
1919
scope = Scope()
2020
self.assertIs(scope, scope.exit())
21+
22+
def test_scope_with_parent(self):
23+
parent_scope = Scope()
24+
child_scope = parent_scope.enter()
25+
self.assertIs(parent_scope, child_scope.exit())
26+
27+
def test_scope_contains(self):
28+
scope = Scope()
29+
scope.variables["x"] = Var("x", 10, int)
30+
self.assertIn("x", scope)
31+
self.assertNotIn("y", scope)
32+
33+
def test_scope_all_vars(self):
34+
parent_scope = Scope()
35+
parent_scope.variables["x"] = Var("x", 10, int)
36+
child_scope = parent_scope.enter()
37+
child_scope.variables["y"] = Var("y", 20, int)
38+
39+
all_vars = child_scope.get_all_vars_dict()
40+
self.assertIn("x", all_vars)
41+
self.assertIn("y", all_vars)
42+
self.assertEqual(all_vars["x"].value, 10)
43+
self.assertEqual(all_vars["y"].value, 20)

0 commit comments

Comments
 (0)