Skip to content

Commit 620d1f3

Browse files
authored
fix: training only when model loading from registry source (#159)
Signed-off-by: s0nicboOm <[email protected]>
1 parent 3e2a40a commit 620d1f3

10 files changed

+156
-31
lines changed

numaprom/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def emit(self, record):
2323

2424
def __get_logger() -> logger:
2525
# Collect logs from logging library
26+
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)
27+
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
2628
logging.basicConfig(handlers=[InterceptHandler()], level=0)
2729
logger.remove()
2830

numaprom/default-configs/numalogic_config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ preprocess:
1616
- name: "StandardScaler"
1717
threshold:
1818
name: "StdDevThreshold"
19+
conf:
20+
min_threshold: 0.01
1921
postprocess:
2022
name: "TanhNorm"
2123
stateful: false

numaprom/udf/inference.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,17 @@ def inference(_: list[str], datum: Datum) -> bytes:
9797
payload.set_status(Status.ARTIFACT_NOT_FOUND)
9898
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
9999

100-
# Check if current model is stale
101-
if RedisRegistry.is_artifact_stale(artifact_data, int(metric_config.retrain_freq_hr)):
100+
LOGGER.info(
101+
"{uuid} - Loaded artifact data from {source} ",
102+
uuid=payload.uuid,
103+
source=artifact_data.extras.get("source"),
104+
)
105+
106+
# Check if current model is stale and source is 'registry'
107+
if (
108+
RedisRegistry.is_artifact_stale(artifact_data, int(metric_config.retrain_freq_hr))
109+
and artifact_data.extras.get("source") == "registry"
110+
):
102111
payload.set_header(Header.MODEL_STALE)
103112

104113
# Generate predictions

0 commit comments

Comments
 (0)