This repository was archived by the owner on Mar 11, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 354
Expand file tree
/
Copy path base.py
More file actions
58 lines (48 loc) · 1.77 KB
/
base.py
File metadata and controls
58 lines (48 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
"""
Base class for post-processing rules in Merlion.
"""
from abc import abstractmethod
from copy import copy, deepcopy
import inspect
from merlion.utils import TimeSeries
from merlion.utils.misc import AutodocABCMeta
class PostRuleBase(metaclass=AutodocABCMeta):
    """
    Base class for post-processing rules in Merlion. These objects are primarily
    for post-processing the sequence of anomaly scores returned by anomaly detection
    models. All post-rules are callable objects, and they have a ``train()`` method
    which may accept additional implementation-specific keyword arguments.

    Serialization (``to_dict`` / ``from_dict``), copying, and ``repr`` are all
    derived from the signature of the subclass's ``__init__``: every named
    constructor parameter must be stored on the instance as an attribute of the
    same name, so that ``from_dict(to_dict())`` can rebuild an equivalent rule.
    """

    def to_dict(self) -> dict:
        """
        Serialize this post-rule's constructor arguments.

        :return: a dict mapping each named ``__init__`` parameter to a deep copy
            of the instance attribute with the same name, plus a ``"name"`` key
            holding the class name.
        """
        params = inspect.signature(self.__init__).parameters
        # ``*args`` / ``**kwargs`` are not stored as attributes (and a subclass that
        # inherits ``object.__init__`` has exactly that signature), so looking them
        # up with getattr would raise AttributeError. Only serialize named params.
        var_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        d = {
            k: deepcopy(getattr(self, k))
            for k, param in params.items()
            if param.kind not in var_kinds
        }
        d["name"] = type(self).__name__
        return d

    @classmethod
    def from_dict(cls, state_dict):
        """
        Construct a post-rule from a dict such as the one produced by ``to_dict``.

        :param state_dict: keyword arguments for ``cls.__init__``; an optional
            ``"name"`` key is ignored. The caller's dict is not modified.
        :return: a new instance of ``cls``.
        """
        kwargs = copy(state_dict)  # shallow copy so the pop below doesn't mutate the caller's dict
        kwargs.pop("name", None)  # informational only; not an __init__ argument
        return cls(**kwargs)

    def __copy__(self):
        # Round-trips through to_dict(), which deep-copies every constructor
        # argument, so even a "shallow" copy shares no argument objects.
        return self.from_dict(self.to_dict())

    def __deepcopy__(self, memodict=None):
        # ``memodict`` is accepted for the copy protocol but not consulted: the
        # copy is already deep because to_dict() deep-copies each constructor
        # argument (without the memo, so references shared with an enclosing
        # object being deep-copied are not preserved).
        return self.__copy__()

    def __repr__(self) -> str:
        kwargs = self.to_dict()
        name = kwargs.pop("name")
        # Sort by parameter name so the repr is deterministic.
        kwargs_str = ", ".join(f"{k}={v}" for k, v in sorted(kwargs.items()))
        return f"{name}({kwargs_str})"

    @abstractmethod
    def train(self, anomaly_scores: TimeSeries):
        """
        Train this post-rule on a sequence of anomaly scores.

        Implementations may accept additional implementation-specific keyword
        arguments.

        :param anomaly_scores: the anomaly scores to train on.
        """
        raise NotImplementedError

    @abstractmethod
    def __call__(self, anomaly_scores: TimeSeries):
        """
        Apply this post-rule to a sequence of anomaly scores.

        :param anomaly_scores: the anomaly scores to post-process.
        """
        raise NotImplementedError