1515from credsweeper .config import Config
1616from credsweeper .credentials import Candidate , CredentialManager , CandidateKey
1717from credsweeper .deep_scanner .deep_scanner import DeepScanner
18+ from credsweeper .file_handler .content_provider import ContentProvider
1819from credsweeper .file_handler .diff_content_provider import DiffContentProvider
1920from credsweeper .file_handler .file_path_extractor import FilePathExtractor
2021from credsweeper .file_handler .abstract_provider import AbstractProvider
2122from credsweeper .file_handler .text_content_provider import TextContentProvider
2223from credsweeper .scanner import Scanner
24+ from credsweeper .ml_model .ml_validator import MlValidator
2325from credsweeper .utils import Util
2426
2527logger = logging .getLogger (__name__ )
@@ -94,7 +96,7 @@ def __init__(self,
9496 log_level: str - level for pool initializer according logging levels (UPPERCASE)
9597
9698 """
97- self .pool_count : int = int ( pool_count ) if int (pool_count ) > 1 else 1
99+ self .pool_count : int = max ( 1 , int (pool_count ))
98100 if not (_severity := Severity .get (severity )):
99101 raise RuntimeError (f"Severity level provided: { severity } "
100102 f" -- must be one of: { ' | ' .join ([i .value for i in Severity ])} " )
@@ -123,9 +125,9 @@ def __init__(self,
123125 self .ml_config = ml_config
124126 self .ml_model = ml_model
125127 self .ml_providers = ml_providers
126- self .ml_validator = None
127128 self .__thrifty = thrifty
128129 self .__log_level = log_level
130+ self .__ml_validator : Optional [MlValidator ] = None
129131
130132 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
131133
@@ -182,35 +184,22 @@ def _use_ml_validation(self) -> bool:
182184
183185 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
184186
185- # the import cannot be done on top due
186- # TypeError: cannot pickle 'onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession' object
187- from credsweeper .ml_model import MlValidator
188-
189- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
190-
191187 @property
192188 def ml_validator (self ) -> MlValidator :
193189 """ml_validator getter"""
194- from credsweeper .ml_model import MlValidator
195190 if not self .__ml_validator :
196- self .__ml_validator : MlValidator = MlValidator (
191+ self .__ml_validator = MlValidator (
197192 threshold = self .ml_threshold , #
198193 ml_config = self .ml_config , #
199194 ml_model = self .ml_model , #
200195 ml_providers = self .ml_providers , #
201196 )
202- assert self .__ml_validator , "self.__ml_validator was not initialized"
197+ if not self .__ml_validator :
198+ raise RuntimeError ("MlValidator was not initialized!" )
203199 return self .__ml_validator
204200
205201 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
206202
207- @ml_validator .setter
208- def ml_validator (self , _ml_validator : Optional [MlValidator ]) -> None :
209- """ml_validator setter"""
210- self .__ml_validator = _ml_validator
211-
212- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
213-
214203 @staticmethod
215204 def pool_initializer (log_kwargs ) -> None :
216205 """Ignore SIGINT in child processes."""
@@ -219,20 +208,6 @@ def pool_initializer(log_kwargs) -> None:
219208
220209 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
221210
222- @property
223- def config (self ) -> Config :
224- """config getter"""
225- return self .__config
226-
227- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
228-
229- @config .setter
230- def config (self , config : Config ) -> None :
231- """config setter"""
232- self .__config = config
233-
234- # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
235-
236211 def run (self , content_provider : AbstractProvider ) -> int :
237212 """Run an analysis of 'content_provider' object.
238213
@@ -241,9 +216,10 @@ def run(self, content_provider: AbstractProvider) -> int:
241216
242217 """
243218 _empty_list : Sequence [Union [DiffContentProvider , TextContentProvider ]] = []
244- file_extractors : Sequence [Union [DiffContentProvider , TextContentProvider ]] = \
245- content_provider .get_scannable_files (self .config ) if content_provider else _empty_list
246- logger .info (f"Start Scanner for { len (file_extractors )} providers" )
219+ file_extractors = content_provider .get_scannable_files (self .config ) if content_provider else _empty_list
220+ if not file_extractors :
221+ logger .info (f"No scannable targets for { len (content_provider .paths )} paths" )
222+ return 0
247223 self .scan (file_extractors )
248224 self .post_processing ()
249225 # PatchesProvider has the attribute. Circular import error appears with using the isinstance
@@ -260,7 +236,7 @@ def scan(self, content_providers: Sequence[Union[DiffContentProvider, TextConten
260236 content_providers: file objects to scan
261237
262238 """
263- if 1 < self .pool_count :
239+ if 1 < self .pool_count and 1 < len ( content_providers ) :
264240 self .__multi_jobs_scan (content_providers )
265241 else :
266242 self .__single_job_scan (content_providers )
@@ -269,6 +245,7 @@ def scan(self, content_providers: Sequence[Union[DiffContentProvider, TextConten
269245
270246 def __single_job_scan (self , content_providers : Sequence [Union [DiffContentProvider , TextContentProvider ]]) -> None :
271247 """Performs scan in main thread"""
248+ logger .info (f"Scan for { len (content_providers )} providers" )
272249 all_cred = self .files_scan (content_providers )
273250 self .credential_manager .set_credentials (all_cred )
274251
@@ -284,12 +261,14 @@ def __multi_jobs_scan(self, content_providers: Sequence[Union[DiffContentProvide
284261 if "SILENCE" == self .__log_level :
285262 logging .addLevelName (60 , "SILENCE" )
286263 log_kwargs ["level" ] = self .__log_level
287- with multiprocessing .get_context ("spawn" ).Pool (processes = self .pool_count ,
288- initializer = self .pool_initializer ,
264+ pool_count = min (self .pool_count , len (content_providers ))
265+ logger .info (f"Scan in { pool_count } processes for { len (content_providers )} providers" )
266+ with multiprocessing .get_context ("spawn" ).Pool (processes = pool_count ,
267+ initializer = CredSweeper .pool_initializer ,
289268 initargs = (log_kwargs , )) as pool :
290269 try :
291- for scan_results in pool .imap_unordered (self .files_scan , ( content_providers [ x :: self . pool_count ]
292- for x in range (self . pool_count ))):
270+ for scan_results in pool .imap_unordered (self .files_scan ,
271+ ( content_providers [ x :: pool_count ] for x in range (pool_count ))):
293272 for cred in scan_results :
294273 self .credential_manager .add_credential (cred )
295274 except KeyboardInterrupt :
@@ -301,9 +280,7 @@ def __multi_jobs_scan(self, content_providers: Sequence[Union[DiffContentProvide
301280
302281 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
303282
304- def files_scan (
305- self , #
306- content_providers : Sequence [Union [DiffContentProvider , TextContentProvider ]]) -> List [Candidate ]:
283+ def files_scan (self , content_providers : Sequence [ContentProvider ]) -> List [Candidate ]:
307284 """Auxiliary method for scan one sequence"""
308285 all_cred : List [Candidate ] = []
309286 for provider in content_providers :
@@ -316,7 +293,7 @@ def files_scan(
316293
317294 # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
318295
319- def file_scan (self , content_provider : Union [ DiffContentProvider , TextContentProvider ] ) -> List [Candidate ]:
296+ def file_scan (self , content_provider : ContentProvider ) -> List [Candidate ]:
320297 """Run scanning of file from 'file_provider'.
321298
322299 Args:
0 commit comments