forked from twitter/caladrius
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloader.py
More file actions
126 lines (94 loc) · 4.32 KB
/
loader.py
File metadata and controls
126 lines (94 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
""" This module contains functions for loading classes from configuration file
variables. """
import logging
import warnings
from importlib import import_module
from typing import Type, Dict, Any, List
import yaml
# Logger shared by every function in this module, named after the module so
# its output can be filtered/configured via the standard logging hierarchy.
LOG: logging.Logger = logging.getLogger(name=__name__)
def get_class(class_path: str) -> Type:
    """ Method for loading a class from an absolute class path string.

    Arguments:
        class_path (str): The absolute import path to the class. For example:
                          pkg.module.ClassName

    Returns:
        The class object referred to in the supplied class path.

    Raises:
        ValueError: If the class path contains no "." separating the module
                    path from the class name.
        ModuleNotFoundError: If the module part of the class path could not
                             be found.
        AttributeError: If the module could be found but the specified class
                        name was not defined within it.
    """
    LOG.info("Loading class: %s", class_path)

    # Split on the LAST dot only: everything before it is the (possibly
    # dotted) module path, everything after it is the attribute name.
    module_path, separator, class_name = class_path.rpartition(".")
    if not separator:
        # Without a dot there is no module to import; fail with a message
        # that names the offending path instead of a tuple-unpacking error.
        LOG.error("Class path %s does not include a module path", class_path)
        raise ValueError(f"Class path {class_path!r} must be an absolute "
                         f"import path of the form 'pkg.module.ClassName'")

    try:
        module = import_module(module_path)
    except ModuleNotFoundError:
        LOG.error("Module %s could not be found", module_path)
        raise

    try:
        found_class = getattr(module, class_name)
    except AttributeError:
        LOG.error("Class %s is not part of module %s", class_name, module_path)
        raise

    LOG.info("Successfully loaded class: %s from module: %s", class_name,
             module_path)
    return found_class
def load_config(file_path: str) -> Dict[str, Any]:
    """ Converts the yaml file at the supplied path to a dictionary.

    Arguments:
        file_path (str): The path to the yaml formatted configuration file.

    Returns:
        A dictionary formed from the supplied yaml file. An empty file yields
        an empty dictionary.
    """
    LOG.info("Loading yaml file at: %s", file_path)

    # safe_load only constructs plain YAML types (mappings, lists, strings,
    # numbers, ...). A bare yaml.load() without an explicit Loader can
    # instantiate arbitrary Python objects, is deprecated since PyYAML 5.1 and
    # raises TypeError on PyYAML >= 6.0.
    with open(file_path, "r", encoding="utf-8") as yaml_file:
        loaded: Any = yaml.safe_load(yaml_file)

    # An empty YAML document parses to None; normalise it so the result
    # matches the declared dictionary return type.
    if loaded is None:
        return {}
    return loaded
def get_model_classes(config: Dict[str, Any], dsps_name: str,
                      model_type: str) -> List[Type]:
    """ This method loads model classes from lists in the config dictionary and
    checks for name and description properties.

    Arguments:
        config (dict): The main configuration dictionary containing the model
                       class paths under "{dsps_name}.{model_type}.models"
                       key.
        dsps_name (str): The name of the streaming system whose models are
                         to be loaded.
        model_type (str): The model type, traffic, topology etc.

    Returns:
        List[Type]: The loaded model classes, in the order they are listed in
                    the configuration.

    Raises:
        KeyError: If the "{dsps_name}.{model_type}.models" key is not present
                  in config.
        RuntimeError: If a model class does not have a name class property
                      set (it is missing or still "base") or if the name
                      property of one model is the same as another in the
                      list.

    Warns:
        UserWarning: If a model class does not have the description class
                     property set (it is missing or still "base").
    """
    model_classes: List[Type] = []
    # Maps each accepted model name to the class that claimed it, giving O(1)
    # duplicate detection and direct access to the clashing class.
    classes_by_name: Dict[str, Type] = {}

    for model in config[f"{dsps_name}.{model_type}.models"]:
        model_class: Type = get_class(model)

        # "base" is the placeholder value checked for here; a class that does
        # not define the attribute at all is treated the same way rather than
        # crashing with an AttributeError.
        model_name = getattr(model_class, "name", "base")
        if model_name == "base":
            name_msg: str = (f"Model {model_class} does not have a "
                             "'name' class property defined. This is required"
                             " for it to be correctly identified in the API.")
            LOG.error(name_msg)
            raise RuntimeError(name_msg)

        if model_name in classes_by_name:
            dup_msg: str = (f"The model {model_class} has the same 'name'"
                            " class property as "
                            f"{classes_by_name[model_name]}. The "
                            "names of models should be unique.")
            LOG.error(dup_msg)
            raise RuntimeError(dup_msg)

        # A missing description is only a recommendation, so warn instead of
        # failing.
        if getattr(model_class, "description", "base") == "base":
            desc_msg: str = (f"Model {model_class} does not have a "
                             "'description' class property defined. This is "
                             "recommended for use in the API.")
            LOG.warning(desc_msg)
            warnings.warn(desc_msg)

        model_classes.append(model_class)
        classes_by_name[model_name] = model_class

    return model_classes