Skip to content

Commit 2d49c4f

Browse files
authored
feat: watchdog observer for config auto loading (#119)
Signed-off-by: Nandita Koppisetty <[email protected]>
1 parent f8f2403 commit 2d49c4f

27 files changed

+421
-207
lines changed

numaprom/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33

4-
from numaprom._config import UnifiedConf, MetricConf, ServiceConf, NumapromConf
4+
from numaprom._config import UnifiedConf, MetricConf, AppConf, NumapromConf
55

66

77
def get_logger(name):
@@ -25,4 +25,4 @@ def get_logger(name):
2525
return logger
2626

2727

28-
__all__ = ["UnifiedConf", "MetricConf", "ServiceConf", "NumapromConf", "get_logger"]
28+
__all__ = ["UnifiedConf", "MetricConf", "AppConf", "NumapromConf", "get_logger"]

numaprom/_config.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ class MetricConf:
2626

2727

2828
@dataclass
29-
class ServiceConf:
30-
service: str = "default"
29+
class AppConf:
30+
app: str = "default"
3131
namespace: str = "default"
3232
metric_configs: List[MetricConf] = field(default_factory=lambda: [MetricConf()])
3333
unified_configs: List[UnifiedConf] = field(default_factory=list)
3434

3535

3636
@dataclass
3737
class NumapromConf:
38-
configs: List[ServiceConf]
38+
configs: List[AppConf]

numaprom/_constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
NUMAPROM_DIR = os.path.dirname(__file__)
44
ROOT_DIR = os.path.split(NUMAPROM_DIR)[0]
55
TESTS_DIR = os.path.join(ROOT_DIR, "tests")
6+
TESTS_RESOURCES = os.path.join(TESTS_DIR, "resources")
67
DATA_DIR = os.path.join(NUMAPROM_DIR, "data")
78
CONFIG_DIR = os.path.join(NUMAPROM_DIR, "configs")
89
DEFAULT_CONFIG_DIR = os.path.join(NUMAPROM_DIR, "default-configs")
@@ -17,3 +18,5 @@
1718
INFERENCE_VTX_KEY = "inference"
1819
THRESHOLD_VTX_KEY = "threshold"
1920
POSTPROC_VTX_KEY = "postproc"
21+
22+
CONFIG_PATHS = ["./numaprom/configs", "./numaprom/default-configs"]

numaprom/tools.py

+2-68
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,15 @@
1414
from botocore.session import get_session
1515
from mlflow.entities.model_registry import ModelVersion
1616
from mlflow.exceptions import RestException
17-
from numalogic.config import NumalogicConf, PostprocessFactory
17+
from numalogic.config import PostprocessFactory
1818
from numalogic.models.threshold import SigmoidThreshold
1919
from numalogic.registry import MLflowRegistry, ArtifactData
20-
from omegaconf import OmegaConf
2120
from pynumaflow.function import Messages, Message
2221

23-
from numaprom import get_logger, MetricConf, ServiceConf, NumapromConf, UnifiedConf
22+
from numaprom import get_logger, MetricConf
2423
from numaprom._constants import (
2524
DEFAULT_TRACKING_URI,
2625
DEFAULT_PROMETHEUS_SERVER,
27-
CONFIG_DIR,
28-
DEFAULT_CONFIG_DIR,
2926
)
3027
from numaprom.entities import TrainerPayload, StreamPayload
3128
from numaprom.clients.prometheus import Prometheus
@@ -155,69 +152,6 @@ def save_model(
155152
return version
156153

157154

158-
def get_all_configs():
159-
schema: NumapromConf = OmegaConf.structured(NumapromConf)
160-
161-
conf = OmegaConf.load(os.path.join(CONFIG_DIR, "config.yaml"))
162-
given_configs = OmegaConf.merge(schema, conf).configs
163-
164-
conf = OmegaConf.load(os.path.join(DEFAULT_CONFIG_DIR, "config.yaml"))
165-
default_configs = OmegaConf.merge(schema, conf).configs
166-
167-
conf = OmegaConf.load(os.path.join(DEFAULT_CONFIG_DIR, "numalogic_config.yaml"))
168-
schema: NumalogicConf = OmegaConf.structured(NumalogicConf)
169-
default_numalogic = OmegaConf.merge(schema, conf)
170-
171-
return given_configs, default_configs, default_numalogic
172-
173-
174-
def get_service_config(metric: str, namespace: str):
175-
given_configs, default_configs, default_numalogic = get_all_configs()
176-
177-
# search and load from given configs
178-
service_config = list(filter(lambda conf: (conf.namespace == namespace), given_configs))
179-
180-
# if not search and load from default configs
181-
if not service_config:
182-
for _conf in default_configs:
183-
if metric in _conf.unified_configs[0].unified_metrics:
184-
service_config = [_conf]
185-
break
186-
187-
# if not in default configs, initialize Namespace conf with default values
188-
if not service_config:
189-
service_config = OmegaConf.structured(ServiceConf)
190-
else:
191-
service_config = service_config[0]
192-
193-
# loading and setting default numalogic config
194-
for metric_config in service_config.metric_configs:
195-
if OmegaConf.is_missing(metric_config, "numalogic_conf"):
196-
metric_config.numalogic_conf = default_numalogic
197-
198-
return service_config
199-
200-
201-
def get_metric_config(metric: str, namespace: str) -> Optional[MetricConf]:
202-
service_config = get_service_config(metric, namespace)
203-
metric_config = list(
204-
filter(lambda conf: (conf.metric == metric), service_config.metric_configs)
205-
)
206-
if not metric_config:
207-
return service_config.metric_configs[0]
208-
return metric_config[0]
209-
210-
211-
def get_unified_config(metric: str, namespace: str) -> Optional[UnifiedConf]:
212-
service_config = get_service_config(metric, namespace)
213-
unified_config = list(
214-
filter(lambda conf: (metric in conf.unified_metrics), service_config.unified_configs)
215-
)
216-
if not unified_config:
217-
return None
218-
return unified_config[0]
219-
220-
221155
def fetch_data(
222156
payload: TrainerPayload, metric_config: MetricConf, labels: dict, return_labels=None
223157
) -> pd.DataFrame:

numaprom/udf/inference.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from numaprom.entities import Status, StreamPayload, Header
1515
from numaprom.tools import (
1616
load_model,
17-
get_metric_config,
1817
msg_forward,
1918
)
19+
from numaprom.watcher import ConfigManager
2020

2121
_LOGGER = get_logger(__name__)
2222

@@ -73,9 +73,7 @@ def inference(_: str, datum: Datum) -> bytes:
7373
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
7474

7575
# Load config
76-
metric_config = get_metric_config(
77-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
78-
)
76+
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
7977
numalogic_conf = metric_config.numalogic_conf
8078

8179
# Load inference model

numaprom/udf/postprocess.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@
1212
from numaprom.clients.redis import get_redis_client
1313
from numaprom.tools import (
1414
msgs_forward,
15-
get_unified_config,
16-
get_metric_config,
1715
WindowScorer,
1816
)
17+
from numaprom.watcher import ConfigManager
1918

2019
_LOGGER = get_logger(__name__)
2120

@@ -130,9 +129,7 @@ def __construct_unified_payload(
130129

131130

132131
def _publish(final_score: float, payload: StreamPayload) -> List[bytes]:
133-
unified_config = get_unified_config(
134-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
135-
)
132+
unified_config = ConfigManager().get_unified_config(payload.composite_keys)
136133

137134
publisher_json = __construct_publisher_payload(payload, final_score).as_json()
138135
_LOGGER.info("%s - Payload sent to publisher: %s", payload.uuid, publisher_json)
@@ -181,9 +178,7 @@ def postprocess(_: str, datum: Datum) -> List[bytes]:
181178
payload = StreamPayload(**orjson.loads(_in_msg))
182179

183180
# Load config
184-
metric_config = get_metric_config(
185-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
186-
)
181+
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
187182

188183
_LOGGER.debug("%s - Received Payload: %r ", payload.uuid, payload)
189184

numaprom/udf/preprocess.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from numaprom import get_logger
77
from numaprom.entities import Status, StreamPayload, Header
8-
from numaprom.tools import msg_forward, load_model, get_metric_config
8+
from numaprom.tools import msg_forward, load_model
9+
from numaprom.watcher import ConfigManager
910

1011
_LOGGER = get_logger(__name__)
1112

@@ -19,9 +20,7 @@ def preprocess(_: str, datum: Datum) -> bytes:
1920
_LOGGER.info("%s - Received Payload: %r ", payload.uuid, payload)
2021

2122
# Load config
22-
metric_config = get_metric_config(
23-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
24-
)
23+
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
2524
preprocess_cfgs = metric_config.numalogic_conf.preprocess
2625

2726
# Load preprocess artifact

numaprom/udf/threshold.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
conditional_forward,
1212
calculate_static_thresh,
1313
load_model,
14-
get_metric_config,
1514
)
15+
from numaprom.watcher import ConfigManager
1616

1717
_LOGGER = get_logger(__name__)
1818

@@ -44,9 +44,8 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]:
4444
)
4545

4646
# Load config
47-
metric_config = get_metric_config(
48-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
49-
)
47+
cm = ConfigManager()
48+
metric_config = cm.get_metric_config(payload.composite_keys)
5049
thresh_cfg = metric_config.numalogic_conf.threshold
5150

5251
# Check if payload needs static inference

numaprom/udf/window.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from numaprom import get_logger
1313
from numaprom.entities import StreamPayload, Status, Header
1414
from numaprom.clients.redis import get_redis_client
15-
from numaprom.tools import msg_forward, create_composite_keys, get_metric_config
15+
from numaprom.tools import msg_forward, create_composite_keys
16+
from numaprom.watcher import ConfigManager
1617

1718
_LOGGER = get_logger(__name__)
1819

@@ -68,7 +69,9 @@ def window(_: str, datum: Datum) -> Optional[bytes]:
6869
_start_time = time.perf_counter()
6970
msg = orjson.loads(datum.value)
7071

71-
metric_config = get_metric_config(metric=msg["name"], namespace=msg["labels"]["namespace"])
72+
metric_config = ConfigManager().get_metric_config(
73+
{"name": msg["name"], "namespace": msg["labels"]["namespace"]}
74+
)
7275
win_size = metric_config.numalogic_conf.model.conf["seq_len"]
7376
buff_size = int(os.getenv("BUFF_SIZE", 10 * win_size))
7477

numaprom/udsink/train.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from numaprom import get_logger
1616
from numaprom.entities import TrainerPayload
1717
from numaprom.clients.redis import get_redis_client
18-
from numaprom.tools import get_metric_config, save_model, fetch_data
18+
from numaprom.tools import save_model, fetch_data
19+
from numaprom.watcher import ConfigManager
1920

2021
_LOGGER = get_logger(__name__)
2122

@@ -103,9 +104,7 @@ def train(datums: List[Datum]) -> Responses:
103104
responses.append(Response.as_success(_datum.id))
104105
continue
105106

106-
metric_config = get_metric_config(
107-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
108-
)
107+
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
109108
model_cfg = metric_config.numalogic_conf.model
110109

111110
train_df = fetch_data(

numaprom/udsink/train_rollout.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
from numaprom import get_logger
1616
from numaprom.entities import TrainerPayload
1717
from numaprom.clients.redis import get_redis_client
18-
from numaprom.tools import get_metric_config, save_model, fetch_data
18+
from numaprom.tools import save_model, fetch_data
19+
from numaprom.watcher import ConfigManager
1920

2021
_LOGGER = get_logger(__name__)
2122

@@ -117,10 +118,7 @@ def train_rollout(datums: Iterator[Datum]) -> Responses:
117118
responses.append(Response.as_success(_datum.id))
118119
continue
119120

120-
metric_config = get_metric_config(
121-
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
122-
)
123-
121+
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
124122
model_cfg = metric_config.numalogic_conf.model
125123

126124
# ToDo: standardize the label name

0 commit comments

Comments
 (0)