Skip to content

Commit a7051b7

Browse files
authored
Multiprocessing refactoring (#713)
* migration to ProcessPoolExecutor * style+fix * lazyload InferenceSession * log info for scan and test * Removed extra * custom BM * [skip actions] [mlval] 2025-05-06T15:11:31+03:00 * rollback to multiprocessing * testfix * style * small fixes * Linter fix * startswith a tuple * Simplified * add works faster * doc upd * logging optimization * test fix * style * typization
1 parent 4336e09 commit a7051b7

File tree

9 files changed

+142
-84
lines changed

9 files changed

+142
-84
lines changed

.flake8

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[flake8]
22
max-line-length = 120
3-
extend-ignore = E203,E303,E131,E402
3+
extend-ignore = E402
44
per-file-ignores = __init__.py:F401

credsweeper/app.py

Lines changed: 21 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
from credsweeper.config import Config
1616
from credsweeper.credentials import Candidate, CredentialManager, CandidateKey
1717
from credsweeper.deep_scanner.deep_scanner import DeepScanner
18+
from credsweeper.file_handler.content_provider import ContentProvider
1819
from credsweeper.file_handler.diff_content_provider import DiffContentProvider
1920
from credsweeper.file_handler.file_path_extractor import FilePathExtractor
2021
from credsweeper.file_handler.abstract_provider import AbstractProvider
2122
from credsweeper.file_handler.text_content_provider import TextContentProvider
2223
from credsweeper.scanner import Scanner
24+
from credsweeper.ml_model.ml_validator import MlValidator
2325
from credsweeper.utils import Util
2426

2527
logger = logging.getLogger(__name__)
@@ -94,7 +96,7 @@ def __init__(self,
9496
log_level: str - level for pool initializer according logging levels (UPPERCASE)
9597
9698
"""
97-
self.pool_count: int = int(pool_count) if int(pool_count) > 1 else 1
99+
self.pool_count: int = max(1, int(pool_count))
98100
if not (_severity := Severity.get(severity)):
99101
raise RuntimeError(f"Severity level provided: {severity}"
100102
f" -- must be one of: {' | '.join([i.value for i in Severity])}")
@@ -123,9 +125,9 @@ def __init__(self,
123125
self.ml_config = ml_config
124126
self.ml_model = ml_model
125127
self.ml_providers = ml_providers
126-
self.ml_validator = None
127128
self.__thrifty = thrifty
128129
self.__log_level = log_level
130+
self.__ml_validator: Optional[MlValidator] = None
129131

130132
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
131133

@@ -182,35 +184,22 @@ def _use_ml_validation(self) -> bool:
182184

183185
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
184186

185-
# the import cannot be done on top due
186-
# TypeError: cannot pickle 'onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession' object
187-
from credsweeper.ml_model import MlValidator
188-
189-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
190-
191187
@property
192188
def ml_validator(self) -> MlValidator:
193189
"""ml_validator getter"""
194-
from credsweeper.ml_model import MlValidator
195190
if not self.__ml_validator:
196-
self.__ml_validator: MlValidator = MlValidator(
191+
self.__ml_validator = MlValidator(
197192
threshold=self.ml_threshold, #
198193
ml_config=self.ml_config, #
199194
ml_model=self.ml_model, #
200195
ml_providers=self.ml_providers, #
201196
)
202-
assert self.__ml_validator, "self.__ml_validator was not initialized"
197+
if not self.__ml_validator:
198+
raise RuntimeError("MlValidator was not initialized!")
203199
return self.__ml_validator
204200

205201
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
206202

207-
@ml_validator.setter
208-
def ml_validator(self, _ml_validator: Optional[MlValidator]) -> None:
209-
"""ml_validator setter"""
210-
self.__ml_validator = _ml_validator
211-
212-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
213-
214203
@staticmethod
215204
def pool_initializer(log_kwargs) -> None:
216205
"""Ignore SIGINT in child processes."""
@@ -219,20 +208,6 @@ def pool_initializer(log_kwargs) -> None:
219208

220209
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
221210

222-
@property
223-
def config(self) -> Config:
224-
"""config getter"""
225-
return self.__config
226-
227-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
228-
229-
@config.setter
230-
def config(self, config: Config) -> None:
231-
"""config setter"""
232-
self.__config = config
233-
234-
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
235-
236211
def run(self, content_provider: AbstractProvider) -> int:
237212
"""Run an analysis of 'content_provider' object.
238213
@@ -241,9 +216,10 @@ def run(self, content_provider: AbstractProvider) -> int:
241216
242217
"""
243218
_empty_list: Sequence[Union[DiffContentProvider, TextContentProvider]] = []
244-
file_extractors: Sequence[Union[DiffContentProvider, TextContentProvider]] = \
245-
content_provider.get_scannable_files(self.config) if content_provider else _empty_list
246-
logger.info(f"Start Scanner for {len(file_extractors)} providers")
219+
file_extractors = content_provider.get_scannable_files(self.config) if content_provider else _empty_list
220+
if not file_extractors:
221+
logger.info(f"No scannable targets for {len(content_provider.paths)} paths")
222+
return 0
247223
self.scan(file_extractors)
248224
self.post_processing()
249225
# PatchesProvider has the attribute. Circular import error appears with using the isinstance
@@ -260,7 +236,7 @@ def scan(self, content_providers: Sequence[Union[DiffContentProvider, TextConten
260236
content_providers: file objects to scan
261237
262238
"""
263-
if 1 < self.pool_count:
239+
if 1 < self.pool_count and 1 < len(content_providers):
264240
self.__multi_jobs_scan(content_providers)
265241
else:
266242
self.__single_job_scan(content_providers)
@@ -269,6 +245,7 @@ def scan(self, content_providers: Sequence[Union[DiffContentProvider, TextConten
269245

270246
def __single_job_scan(self, content_providers: Sequence[Union[DiffContentProvider, TextContentProvider]]) -> None:
271247
"""Performs scan in main thread"""
248+
logger.info(f"Scan for {len(content_providers)} providers")
272249
all_cred = self.files_scan(content_providers)
273250
self.credential_manager.set_credentials(all_cred)
274251

@@ -284,12 +261,14 @@ def __multi_jobs_scan(self, content_providers: Sequence[Union[DiffContentProvide
284261
if "SILENCE" == self.__log_level:
285262
logging.addLevelName(60, "SILENCE")
286263
log_kwargs["level"] = self.__log_level
287-
with multiprocessing.get_context("spawn").Pool(processes=self.pool_count,
288-
initializer=self.pool_initializer,
264+
pool_count = min(self.pool_count, len(content_providers))
265+
logger.info(f"Scan in {pool_count} processes for {len(content_providers)} providers")
266+
with multiprocessing.get_context("spawn").Pool(processes=pool_count,
267+
initializer=CredSweeper.pool_initializer,
289268
initargs=(log_kwargs, )) as pool:
290269
try:
291-
for scan_results in pool.imap_unordered(self.files_scan, (content_providers[x::self.pool_count]
292-
for x in range(self.pool_count))):
270+
for scan_results in pool.imap_unordered(self.files_scan,
271+
(content_providers[x::pool_count] for x in range(pool_count))):
293272
for cred in scan_results:
294273
self.credential_manager.add_credential(cred)
295274
except KeyboardInterrupt:
@@ -301,9 +280,7 @@ def __multi_jobs_scan(self, content_providers: Sequence[Union[DiffContentProvide
301280

302281
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
303282

304-
def files_scan(
305-
self, #
306-
content_providers: Sequence[Union[DiffContentProvider, TextContentProvider]]) -> List[Candidate]:
283+
def files_scan(self, content_providers: Sequence[ContentProvider]) -> List[Candidate]:
307284
"""Auxiliary method for scan one sequence"""
308285
all_cred: List[Candidate] = []
309286
for provider in content_providers:
@@ -316,7 +293,7 @@ def files_scan(
316293

317294
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
318295

319-
def file_scan(self, content_provider: Union[DiffContentProvider, TextContentProvider]) -> List[Candidate]:
296+
def file_scan(self, content_provider: ContentProvider) -> List[Candidate]:
320297
"""Run scanning of file from 'file_provider'.
321298
322299
Args:

credsweeper/filters/value_base64_encoded_pem_check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def run(self, line_data: LineData, target: AnalysisTarget) -> bool:
3030
with contextlib.suppress(Exception):
3131
text = Util.decode_base64(line_data.value, padding_safe=True, urlsafe_detect=True)
3232
lines = text.decode(ASCII).splitlines()
33-
lines_pos = [x for x in range(len(lines))]
33+
lines_pos = list(range(len(lines)))
3434
for line_pos, line in zip(lines_pos, lines):
3535
if PEM_BEGIN_PATTERN in line:
3636
new_target = AnalysisTarget(line_pos, lines, lines_pos, target.descriptor)

credsweeper/ml_model/ml_validator.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import hashlib
2+
import json
23
import logging
34
from pathlib import Path
45
from typing import List, Tuple, Union, Optional, Dict
56

67
import numpy as np
7-
import onnxruntime as ort
8+
from onnxruntime import InferenceSession
89

910
import credsweeper.ml_model.features as features
1011
from credsweeper.common.constants import ThresholdPreset, ML_HUNK
@@ -22,6 +23,8 @@ class MlValidator:
2223
# applied for unknown characters
2324
FAKE_CHAR = '\x01'
2425

26+
_dir_path = Path(__file__).parent
27+
2528
def __init__(
2629
self, #
2730
threshold: Union[float, ThresholdPreset], #
@@ -36,35 +39,36 @@ def __init__(
3639
ml_model: path to ml model
3740
ml_providers: coma separated list of providers https://onnxruntime.ai/docs/execution-providers/
3841
"""
39-
dir_path = Path(__file__).parent
42+
self.__session: Optional[InferenceSession] = None
4043

4144
if ml_config:
4245
ml_config_path = Path(ml_config)
4346
else:
44-
ml_config_path = dir_path / "ml_config.json"
47+
ml_config_path = MlValidator._dir_path / "ml_config.json"
4548
with open(ml_config_path, "rb") as f:
46-
md5_config = hashlib.md5(f.read()).hexdigest()
49+
__ml_config_data = f.read()
50+
51+
model_config = json.loads(__ml_config_data)
4752

4853
if ml_model:
4954
ml_model_path = Path(ml_model)
5055
else:
51-
ml_model_path = dir_path / "ml_model.onnx"
56+
ml_model_path = MlValidator._dir_path / "ml_model.onnx"
5257
with open(ml_model_path, "rb") as f:
53-
md5_model = hashlib.md5(f.read()).hexdigest()
58+
self.__ml_model_data = f.read()
5459

5560
if ml_providers:
56-
providers = ml_providers.split(',')
61+
self.providers = ml_providers.split(',')
5762
else:
58-
providers = ["CPUExecutionProvider"]
59-
self.model_session = ort.InferenceSession(ml_model_path, providers=providers)
63+
self.providers = ["CPUExecutionProvider"]
6064

61-
model_config = Util.json_load(ml_config_path)
6265
if isinstance(threshold, float):
6366
self.threshold = threshold
6467
elif isinstance(threshold, ThresholdPreset) and "thresholds" in model_config:
6568
self.threshold = model_config["thresholds"][threshold.value]
6669
else:
6770
self.threshold = 0.5
71+
logger.warning(f"Use fallback threshold value: {self.threshold}")
6872

6973
char_set = set(model_config["char_set"])
7074
if len(char_set) != len(model_config["char_set"]):
@@ -80,26 +84,44 @@ def __init__(
8084

8185
self.common_feature_list = []
8286
self.unique_feature_list = []
83-
logger.info("Init ML validator with %s provider; config:'%s' md5:%s model:'%s' md5:%s", providers,
84-
ml_config_path, md5_config, ml_model_path, md5_model)
85-
logger.debug("ML validator details: %s", model_config)
87+
if logger.isEnabledFor(logging.INFO):
88+
config_dbg = str(model_config) if logger.isEnabledFor(logging.DEBUG) else ''
89+
config_md5 = hashlib.md5(__ml_config_data).hexdigest()
90+
model_md5 = hashlib.md5(self.__ml_model_data).hexdigest()
91+
logger.info("Init ML validator with providers: '%s' ; model:'%s' md5:%s ; config:'%s' md5:%s ; %s",
92+
self.providers, ml_config_path, config_md5, ml_model_path, model_md5, config_dbg)
8693
for feature_definition in model_config["features"]:
8794
feature_class = feature_definition["type"]
8895
kwargs = feature_definition.get("kwargs", {})
8996
feature_constructor = getattr(features, feature_class, None)
9097
if feature_constructor is None:
91-
raise ValueError(f'Error while parsing model details. Cannot create feature "{feature_class}"')
98+
raise ValueError(f"Error while parsing model details. Cannot create feature '{feature_class}'"
99+
f" from {feature_definition}")
92100
try:
93101
feature = feature_constructor(**kwargs)
94102
except TypeError:
95-
logger.error(f'Error while parsing model details. Cannot create feature "{feature_class}"'
96-
f' with kwargs "{kwargs}"')
103+
logger.error(f"Error while parsing model details. Cannot create feature '{feature_class}'"
104+
f" from {feature_definition}")
97105
raise
98106
if feature_definition["type"] in ["RuleName"]:
99107
self.unique_feature_list.append(feature)
100108
else:
101109
self.common_feature_list.append(feature)
102110

111+
def __reduce__(self):
112+
# TypeError: cannot pickle 'onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession' object
113+
self.__session = None
114+
return super().__reduce__()
115+
116+
@property
117+
def session(self) -> InferenceSession:
118+
"""session getter to prevent pickle error"""
119+
if not self.__session:
120+
self.__session = InferenceSession(self.__ml_model_data, providers=self.providers)
121+
if not self.__session:
122+
raise RuntimeError("InferenceSession was not initialized!")
123+
return self.__session
124+
103125
def encode(self, text: str, limit: int) -> np.ndarray:
104126
"""Encodes prepared text to array"""
105127
result_array: np.ndarray = np.zeros(shape=(limit, self.num_classes), dtype=np.float32)
@@ -136,7 +158,7 @@ def _call_model(self, line_input: np.ndarray, variable_input: np.ndarray, value_
136158
"value_input": value_input.astype(np.float32),
137159
"feature_input": feature_input.astype(np.float32),
138160
}
139-
result = self.model_session.run(output_names=None, input_feed=input_feed)
161+
result = self.session.run(output_names=None, input_feed=input_feed)
140162
if result and isinstance(result[0], np.ndarray):
141163
return result[0]
142164
raise RuntimeError(f"Unexpected type {type(result[0])}")
@@ -178,8 +200,8 @@ def get_group_features(self, candidates: List[Candidate]) -> Tuple[np.ndarray, n
178200
default_candidate = candidates[0]
179201
line_input = self.encode_line(default_candidate.line_data_list[0].line,
180202
default_candidate.line_data_list[0].value_start)[np.newaxis]
181-
variable = ""
182-
value = ""
203+
variable = ''
204+
value = ''
183205
for candidate in candidates:
184206
if not variable and candidate.line_data_list[0].variable:
185207
variable = candidate.line_data_list[0].variable
@@ -251,8 +273,8 @@ def validate_groups(self, group_list: List[Tuple[CandidateKey, List[Candidate]]]
251273
features_list)
252274
is_cred = probability > self.threshold
253275
if logger.isEnabledFor(logging.DEBUG):
254-
for i, _ in enumerate(is_cred):
255-
logger.debug("ML decision: %s with prediction: %s for value: %s", is_cred[i], probability[i],
276+
for i, decision in enumerate(is_cred):
277+
logger.debug("ML decision: %s with prediction: %s for value: %s", decision, probability[i],
256278
group_list[i][0])
257279
# apply cast to float to avoid json export issue
258280
return is_cred, probability.astype(float)

credsweeper/utils/pem_key_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def sanitize_line(cls, line: str, recurse_level: int = 5) -> str:
126126
line = line.strip(string.whitespace)
127127
if line.startswith("//"):
128128
# simplify first condition for speed-up of doxygen style processing
129-
if line.startswith("// ") or line.startswith("/// "):
129+
if line.startswith(("// ", "/// ")):
130130
# Assume that the commented line is to be separated from base64 code, it may be a part of PEM, otherwise
131131
line = line[3:]
132132
if line.startswith("/*"):

credsweeper/utils/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def is_jks(data: Union[bytes, bytearray]) -> bool:
465465
def is_lzma(data: Union[bytes, bytearray]) -> bool:
466466
"""According https://en.wikipedia.org/wiki/List_of_file_signatures - lzma also xz"""
467467
if isinstance(data, (bytes, bytearray)) and 6 <= len(data):
468-
if data.startswith(b"\xFD\x37\x7A\x58\x5A\x00") or data.startswith(b"\x5D\x00\x00"):
468+
if data.startswith((b"\xFD\x37\x7A\x58\x5A\x00", b"\x5D\x00\x00")):
469469
return True
470470
return False
471471

docs/source/credsweeper.deep_scanner.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ credsweeper.deep\_scanner.bzip2\_scanner module
2828
:undoc-members:
2929
:show-inheritance:
3030

31+
credsweeper.deep\_scanner.deb\_scanner module
32+
---------------------------------------------
33+
34+
.. automodule:: credsweeper.deep_scanner.deb_scanner
35+
:members:
36+
:undoc-members:
37+
:show-inheritance:
38+
3139
credsweeper.deep\_scanner.deep\_scanner module
3240
----------------------------------------------
3341

tests/test_app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -732,8 +732,10 @@ def test_external_ml_n(self) -> None:
732732
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
733733

734734
def test_external_ml_p(self) -> None:
735-
log_pattern = re.compile(
736-
r".*Init ML validator with .+ provider; config:'.+' md5:([0-9a-f]{32}) model:'.+' md5:([0-9a-f]{32})")
735+
log_pattern = re.compile(r".*Init ML validator with providers: \S+ ;"
736+
r" model:'.+' md5:([0-9a-f]{32}) ;"
737+
r" config:'.+' md5:([0-9a-f]{32}) ;"
738+
r" .*")
737739
_stdout, _stderr = self._m_credsweeper(["--path", str(APP_PATH), "--log", "INFO"])
738740
self.assertEqual(0, len(_stderr))
739741
self.assertNotIn("CRITICAL", _stdout)

0 commit comments

Comments
 (0)