1+ import abc
12from abc import abstractmethod
23from typing import List , Type , Set
34
@@ -75,13 +76,17 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
7576
7677
7778class BranchFactory (AnalysisFactory ):
79+ def __init__ (self , else_ : bool = True ):
80+ super ().__init__ ()
81+ self .else_ = else_
82+
7883 def get_analysis (self , event , scope : Scope = None ) -> List [AnalysisObject ]:
7984 if event .event_type == EventType .BRANCH :
8085 key = (Branch .analysis_type (), event .file , event .line , event .then_id )
8186 then = event .then_id < event .else_id
8287 if key not in self .objects :
8388 self .objects [key ] = Branch (event , then = then )
84- if event .else_id >= 0 :
89+ if self . else_ and event .else_id >= 0 :
8590 else_key = (
8691 Branch .analysis_type (),
8792 event .file ,
@@ -104,6 +109,12 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
104109
105110
106111class LoopFactory (AnalysisFactory ):
112+ def __init__ (self , hit_0 : bool = True , hit_1 : bool = True , hit_more : bool = True ):
113+ super ().__init__ ()
114+ self .hit_0 = hit_0
115+ self .hit_1 = hit_1
116+ self .hit_more = hit_more
117+
107118 def get_all (self ) -> Set [AnalysisObject ]:
108119 return set (obj for value in self .objects .values () for obj in value )
109120
@@ -115,11 +126,13 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
115126 ):
116127 key = (Loop .analysis_type (), event .file , event .line , event .loop_id )
117128 if key not in self .objects :
118- self .objects [key ] = [
119- Loop (event , Loop .evaluate_hit_0 ),
120- Loop (event , Loop .evaluate_hit_1 ),
121- Loop (event , Loop .evaluate_hit_more ),
122- ]
129+ self .objects [key ] = []
130+ if self .hit_0 :
131+ self .objects [key ].append (Loop (event , Loop .evaluate_hit_0 )),
132+ if self .hit_1 :
133+ self .objects [key ].append (Loop (event , Loop .evaluate_hit_1 )),
134+ if self .hit_more :
135+ self .objects [key ].append (Loop (event , Loop .evaluate_hit_more )),
123136 if event .event_type == EventType .LOOP_BEGIN :
124137 list (map (Loop .start_loop , self .objects [key ]))
125138 elif event .event_type == EventType .LOOP_HIT :
@@ -164,7 +177,33 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
164177 return [self .objects [key ]]
165178
166179
167- class ScalarPairFactory (AnalysisFactory ):
180+ class ComparisonFactory (AnalysisFactory , abc .ABC ):
181+ def __init__ (
182+ self ,
183+ eq : bool = True ,
184+ ne : bool = True ,
185+ lt : bool = True ,
186+ le : bool = True ,
187+ gt : bool = True ,
188+ ge : bool = True ,
189+ ):
190+ super ().__init__ ()
191+ self .comparators = []
192+ if eq :
193+ self .comparators .append (Comp .EQ )
194+ if ne :
195+ self .comparators .append (Comp .NE )
196+ if lt :
197+ self .comparators .append (Comp .LT )
198+ if le :
199+ self .comparators .append (Comp .LE )
200+ if gt :
201+ self .comparators .append (Comp .GT )
202+ if ge :
203+ self .comparators .append (Comp .GE )
204+
205+
206+ class ScalarPairFactory (ComparisonFactory ):
168207 def get_analysis (self , event , scope : Scope = None ) -> List [AnalysisObject ]:
169208 if event .event_type == EventType .DEF :
170209 variables = scope .get_all_vars ()
@@ -175,7 +214,7 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
175214 for variable in variables :
176215 if variable .var != event .var :
177216 if variable .type_ in types :
178- for comp in Comp :
217+ for comp in self . comparators :
179218 key = (
180219 ScalarPair .analysis_type (),
181220 event .file ,
@@ -194,32 +233,33 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
194233 for variable in variables :
195234 if variable .type_ == event .type_ :
196235 for comp in (Comp .EQ , Comp .NE ):
197- key = (
198- ScalarPair .analysis_type (),
199- event .file ,
200- event .line ,
201- event .var ,
202- variable .var ,
203- comp ,
204- event .type_ ,
205- )
206- if key not in self .objects :
207- self .objects [key ] = ScalarPair (
208- event , comp , variable .var
236+ if comp in self .comparators :
237+ key = (
238+ ScalarPair .analysis_type (),
239+ event .file ,
240+ event .line ,
241+ event .var ,
242+ variable .var ,
243+ comp ,
244+ event .type_ ,
209245 )
210- objects .append (self .objects [key ])
246+ if key not in self .objects :
247+ self .objects [key ] = ScalarPair (
248+ event , comp , variable .var
249+ )
250+ objects .append (self .objects [key ])
211251 return objects
212252
213253
214- class VariableFactory (AnalysisFactory ):
254+ class VariableFactory (ComparisonFactory ):
215255 def get_analysis (self , event , scope : Scope = None ) -> List [AnalysisObject ]:
216256 if event .event_type == EventType .DEF and event .type_ in [
217257 "int" ,
218258 "float" ,
219259 "bool" ,
220260 ]:
221261 objects = list ()
222- for comp in Comp :
262+ for comp in self . comparators :
223263 key = (
224264 VariablePredicate .analysis_type (),
225265 event .file ,
@@ -234,7 +274,7 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
234274 return objects
235275
236276
237- class ReturnFactory (AnalysisFactory ):
277+ class ReturnFactory (ComparisonFactory ):
238278 def get_analysis (self , event , scope : Scope = None ) -> List [AnalysisObject ]:
239279 if event .event_type == EventType .FUNCTION_EXIT :
240280 objects = list ()
@@ -249,42 +289,45 @@ def get_analysis(self, event, scope: Scope = None) -> List[AnalysisObject]:
249289 type_ , tr = "bytes" , b""
250290 compare = Comp .EQ , Comp .NE
251291 for comp in compare :
252- key = (
253- ReturnPredicate .analysis_type (),
254- event .file ,
255- event .line ,
256- event .function ,
257- comp ,
258- type_ ,
259- )
260- if key not in self .objects :
261- self .objects [key ] = ReturnPredicate (event , comp , value = tr )
262- objects .append (self .objects [key ])
292+ if comp in self .comparators :
293+ key = (
294+ ReturnPredicate .analysis_type (),
295+ event .file ,
296+ event .line ,
297+ event .function ,
298+ comp ,
299+ type_ ,
300+ )
301+ if key not in self .objects :
302+ self .objects [key ] = ReturnPredicate (event , comp , value = tr )
303+ objects .append (self .objects [key ])
263304 if event .type_ == "NoneType" :
264305 for comp in Comp .EQ , Comp .NE :
265- key = (
266- ReturnPredicate .analysis_type (),
267- event .file ,
268- event .line ,
269- event .function ,
270- comp ,
271- event .type_ ,
272- )
273- if key not in self .objects :
274- self .objects [key ] = ReturnPredicate (event , comp , value = None )
275- objects .append (self .objects [key ])
306+ if comp in self .comparators :
307+ key = (
308+ ReturnPredicate .analysis_type (),
309+ event .file ,
310+ event .line ,
311+ event .function ,
312+ comp ,
313+ event .type_ ,
314+ )
315+ if key not in self .objects :
316+ self .objects [key ] = ReturnPredicate (event , comp , value = None )
317+ objects .append (self .objects [key ])
276318 else :
277319 for comp in Comp .EQ , Comp .NE :
278- key = (
279- ReturnPredicate .analysis_type (),
280- event .file ,
281- event .line ,
282- event .function ,
283- comp ,
284- "NoneType" ,
285- )
286- if key in self .objects :
287- objects .append (self .objects [key ])
320+ if comp in self .comparators :
321+ key = (
322+ ReturnPredicate .analysis_type (),
323+ event .file ,
324+ event .line ,
325+ event .function ,
326+ comp ,
327+ "NoneType" ,
328+ )
329+ if key in self .objects :
330+ objects .append (self .objects [key ])
288331 return objects
289332
290333
0 commit comments