Skip to content

Commit bc2bf6b

Browse files
authored
Replace enums to be extendable (#151)
1 parent 28e0e27 commit bc2bf6b

File tree

8 files changed

+57
-38
lines changed

8 files changed

+57
-38
lines changed

modelscan/error.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from enum import Enum
21
from modelscan.model import Model
32
import abc
43
from pathlib import Path

modelscan/issues.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from collections import defaultdict
88

9+
from modelscan.settings import Property
10+
911
logger = logging.getLogger("modelscan")
1012

1113

@@ -16,8 +18,8 @@ class IssueSeverity(Enum):
1618
CRITICAL = 4
1719

1820

19-
class IssueCode(Enum):
20-
UNSAFE_OPERATOR = 1
21+
class IssueCode:
22+
UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1)
2123

2224

2325
class IssueDetails(metaclass=abc.ABCMeta):
@@ -40,14 +42,14 @@ class Issue:
4042

4143
def __init__(
4244
self,
43-
code: IssueCode,
45+
code: Property,
4446
severity: IssueSeverity,
4547
details: IssueDetails,
4648
) -> None:
4749
"""
4850
Create a issue with given information
4951
50-
:param code: Code of the issue from the issue code enum.
52+
:param code: Code of the issue from the issue code class.
5153
:param severity: The severity level of the issue from Severity enum.
5254
:param details: An implementation of the IssueDetails object.
5355
"""
@@ -82,7 +84,7 @@ def __hash__(self) -> int:
8284

8385
def print(self) -> None:
8486
issue_description = self.code.name
85-
if self.code == IssueCode.UNSAFE_OPERATOR:
87+
if self.code.value == IssueCode.UNSAFE_OPERATOR.value:
8688
issue_description = "Unsafe operator"
8789
else:
8890
logger.error("No issue description for issue code %s", self.code)

modelscan/scanners/h5/scan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from modelscan.scanners.scan import ScanResults
1919
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
2020
from modelscan.model import Model
21-
from modelscan.settings import DefaultModelFormats
21+
from modelscan.settings import SupportedModelFormats
2222

2323
logger = logging.getLogger("modelscan")
2424

@@ -28,7 +28,9 @@ def scan(
2828
self,
2929
model: Model,
3030
) -> Optional[ScanResults]:
31-
if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"):
31+
if SupportedModelFormats.KERAS_H5.value not in [
32+
format_property.value for format_property in model.get_context("formats")
33+
]:
3234
return None
3335

3436
dep_error = self.handle_binary_dependencies()

modelscan/scanners/keras/scan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
from modelscan.scanners.scan import ScanResults
1010
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
1111
from modelscan.model import Model
12-
from modelscan.settings import DefaultModelFormats
12+
from modelscan.settings import SupportedModelFormats
1313

1414

1515
logger = logging.getLogger("modelscan")
1616

1717

1818
class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
1919
def scan(self, model: Model) -> Optional[ScanResults]:
20-
if DefaultModelFormats.KERAS not in model.get_context("formats"):
20+
if SupportedModelFormats.KERAS.value not in [
21+
format_property.value for format_property in model.get_context("formats")
22+
]:
2123
return None
2224

2325
dep_error = self.handle_binary_dependencies()

modelscan/scanners/pickle/scan.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
scan_pytorch,
1010
)
1111
from modelscan.model import Model
12-
from modelscan.settings import DefaultModelFormats
12+
from modelscan.settings import SupportedModelFormats
1313

1414
logger = logging.getLogger("modelscan")
1515

@@ -19,7 +19,9 @@ def scan(
1919
self,
2020
model: Model,
2121
) -> Optional[ScanResults]:
22-
if DefaultModelFormats.PYTORCH not in model.get_context("formats"):
22+
if SupportedModelFormats.PYTORCH.value not in [
23+
format_property.value for format_property in model.get_context("formats")
24+
]:
2325
return None
2426

2527
if _is_zipfile(model.get_source(), model.get_stream()):
@@ -46,7 +48,9 @@ def scan(
4648
self,
4749
model: Model,
4850
) -> Optional[ScanResults]:
49-
if DefaultModelFormats.NUMPY not in model.get_context("formats"):
51+
if SupportedModelFormats.NUMPY.value not in [
52+
format_property.value for format_property in model.get_context("formats")
53+
]:
5054
return None
5155

5256
results = scan_numpy(
@@ -70,7 +74,9 @@ def scan(
7074
self,
7175
model: Model,
7276
) -> Optional[ScanResults]:
73-
if DefaultModelFormats.PICKLE not in model.get_context("formats"):
77+
if SupportedModelFormats.PICKLE.value not in [
78+
format_property.value for format_property in model.get_context("formats")
79+
]:
7480
return None
7581

7682
results = scan_pickle_bytes(

modelscan/scanners/saved_model/scan.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
2323
from modelscan.scanners.scan import ScanBase, ScanResults
2424
from modelscan.model import Model
25-
from modelscan.settings import DefaultModelFormats
25+
from modelscan.settings import SupportedModelFormats
2626

2727
logger = logging.getLogger("modelscan")
2828

@@ -32,7 +32,9 @@ def scan(
3232
self,
3333
model: Model,
3434
) -> Optional[ScanResults]:
35-
if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"):
35+
if SupportedModelFormats.TENSORFLOW.value not in [
36+
format_property.value for format_property in model.get_context("formats")
37+
]:
3638
return None
3739

3840
dep_error = self.handle_binary_dependencies()

modelscan/settings.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
import tomlkit
22

3-
from enum import Enum
43
from typing import Any
54

65
from modelscan._version import __version__
76

87

9-
class DefaultModelFormats(Enum):
10-
TENSORFLOW = "tensorflow"
11-
KERAS_H5 = "keras_h5"
12-
KERAS = "keras"
13-
NUMPY = "numpy"
14-
PYTORCH = "pytorch"
15-
PICKLE = "pickle"
8+
class Property:
9+
def __init__(self, name: str, value: Any) -> None:
10+
self.name = name
11+
self.value = value
12+
13+
14+
class SupportedModelFormats:
15+
TENSORFLOW = Property("TENSORFLOW", "tensorflow")
16+
KERAS_H5 = Property("KERAS_H5", "keras_h5")
17+
KERAS = Property("KERAS", "keras")
18+
NUMPY = Property("NUMPY", "numpy")
19+
PYTORCH = Property("PYTORCH", "pytorch")
20+
PICKLE = Property("PICKLE", "pickle")
1621

1722

1823
DEFAULT_REPORTING_MODULES = {
@@ -70,12 +75,12 @@ class DefaultModelFormats(Enum):
7075
"middlewares": {
7176
"modelscan.middlewares.FormatViaExtensionMiddleware": {
7277
"formats": {
73-
DefaultModelFormats.TENSORFLOW: [".pb"],
74-
DefaultModelFormats.KERAS_H5: [".h5"],
75-
DefaultModelFormats.KERAS: [".keras"],
76-
DefaultModelFormats.NUMPY: [".npy"],
77-
DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
78-
DefaultModelFormats.PICKLE: [
78+
SupportedModelFormats.TENSORFLOW: [".pb"],
79+
SupportedModelFormats.KERAS_H5: [".h5"],
80+
SupportedModelFormats.KERAS: [".keras"],
81+
SupportedModelFormats.NUMPY: [".npy"],
82+
SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
83+
SupportedModelFormats.PICKLE: [
7984
".pkl",
8085
".pickle",
8186
".joblib",

modelscan/skip.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import logging
22
from enum import Enum
33

4+
from modelscan.settings import Property
45

56
logger = logging.getLogger("modelscan")
67

78

8-
class SkipCategories(Enum):
9-
SCAN_NOT_SUPPORTED = 1
10-
BAD_ZIP = 2
11-
MODEL_CONFIG = 3
12-
H5_DATA = 4
13-
NOT_IMPLEMENTED = 5
14-
MAGIC_NUMBER = 6
9+
class SkipCategories:
10+
SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1)
11+
BAD_ZIP = Property("BAD_ZIP", 2)
12+
MODEL_CONFIG = Property("MODEL_CONFIG", 3)
13+
H5_DATA = Property("H5_DATA", 4)
14+
NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5)
15+
MAGIC_NUMBER = Property("MAGIC_NUMBER", 6)
1516

1617

1718
class Skip:
@@ -31,7 +32,7 @@ class ModelScanSkipped:
3132
def __init__(
3233
self,
3334
scan_name: str,
34-
category: SkipCategories,
35+
category: Property,
3536
message: str,
3637
source: str,
3738
) -> None:

0 commit comments

Comments
 (0)