|
1 | 1 | import tomlkit |
2 | 2 |
|
3 | | -from enum import Enum |
4 | 3 | from typing import Any |
5 | 4 |
|
6 | 5 | from modelscan._version import __version__ |
7 | 6 |
|
8 | 7 |
|
9 | | -class DefaultModelFormats(Enum): |
10 | | - TENSORFLOW = "tensorflow" |
11 | | - KERAS_H5 = "keras_h5" |
12 | | - KERAS = "keras" |
13 | | - NUMPY = "numpy" |
14 | | - PYTORCH = "pytorch" |
15 | | - PICKLE = "pickle" |
| 8 | +class Property: |
| 9 | + def __init__(self, name: str, value: Any) -> None: |
| 10 | + self.name = name |
| 11 | + self.value = value |
| 12 | + |
| 13 | + |
| 14 | +class SupportedModelFormats: |
| 15 | + TENSORFLOW = Property("TENSORFLOW", "tensorflow") |
| 16 | + KERAS_H5 = Property("KERAS_H5", "keras_h5") |
| 17 | + KERAS = Property("KERAS", "keras") |
| 18 | + NUMPY = Property("NUMPY", "numpy") |
| 19 | + PYTORCH = Property("PYTORCH", "pytorch") |
| 20 | + PICKLE = Property("PICKLE", "pickle") |
16 | 21 |
|
17 | 22 |
|
18 | 23 | DEFAULT_REPORTING_MODULES = { |
@@ -70,12 +75,12 @@ class DefaultModelFormats(Enum): |
70 | 75 | "middlewares": { |
71 | 76 | "modelscan.middlewares.FormatViaExtensionMiddleware": { |
72 | 77 | "formats": { |
73 | | - DefaultModelFormats.TENSORFLOW: [".pb"], |
74 | | - DefaultModelFormats.KERAS_H5: [".h5"], |
75 | | - DefaultModelFormats.KERAS: [".keras"], |
76 | | - DefaultModelFormats.NUMPY: [".npy"], |
77 | | - DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], |
78 | | - DefaultModelFormats.PICKLE: [ |
| 78 | + SupportedModelFormats.TENSORFLOW: [".pb"], |
| 79 | + SupportedModelFormats.KERAS_H5: [".h5"], |
| 80 | + SupportedModelFormats.KERAS: [".keras"], |
| 81 | + SupportedModelFormats.NUMPY: [".npy"], |
| 82 | + SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], |
| 83 | + SupportedModelFormats.PICKLE: [ |
79 | 84 | ".pkl", |
80 | 85 | ".pickle", |
81 | 86 | ".joblib", |
|
0 commit comments