-
-
Notifications
You must be signed in to change notification settings - Fork 187
Expand file tree
/
Copy path scinet_anomalydetection_example.py
More file actions
67 lines (56 loc) · 1.94 KB
/
Copy path scinet_anomalydetection_example.py
File metadata and controls
67 lines (56 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
A minimalist, standalone example of the PyPOTS SCINet model for time-series anomaly detection.
This script is auto-generated by extracting hyperparameters from the test code.
"""
from benchpots.datasets import preprocess_random_walk
from pypots.anomaly_detection import SCINet
from pypots.nn.functional import calc_acc, calc_precision_recall_f1
def main():
    """Train PyPOTS SCINet on a synthetic random-walk dataset and report metrics.

    Steps:
      1. Build a random-walk dataset with injected anomalies and missing values.
      2. Split it into train / validation / test inputs.
      3. Configure a small SCINet detector (2 epochs, CPU) for a quick demo run.
      4. Fit on the train split, validating on the validation split.
      5. Predict anomaly labels for the test split.
      6. Print accuracy, F1, precision and recall against the test labels.
    """
    n_steps = 48
    n_features = 35
    # Single source of truth for the anomaly proportion: the same value is
    # injected into the synthetic data and handed to the detector, so keeping
    # one variable prevents the two from silently drifting apart.
    anomaly_rate = 0.05

    # 1. Generate a random walk time-series dataset
    dataset = preprocess_random_walk(
        n_steps=n_steps,
        n_features=n_features,
        n_classes=5,
        n_samples_each_class=40,
        missing_rate=0.1,
        anomaly_rate=anomaly_rate,
    )

    # 2. Extract training, validation, and test sets
    train_set = {"X": dataset["train_X"], "anomaly_y": dataset["train_anomaly_y"]}
    val_set = {
        "X": dataset["val_X"],
        "X_ori": dataset["val_X_ori"],
        "anomaly_y": dataset["val_anomaly_y"],
    }
    # Test labels are withheld from the model and kept, flattened, for scoring.
    test_set = {"X": dataset["test_X"]}
    test_anomaly_y = dataset["test_anomaly_y"].flatten()

    # 3. Initialize the model
    model = SCINet(
        n_steps,
        n_features,
        anomaly_rate=anomaly_rate,
        n_stacks=2,
        n_levels=2,
        n_groups=1,
        n_decoder_layers=2,
        d_hidden=64,
        kernel_size=5,
        dropout=0,
        concat_len=0,
        pos_enc=True,
        epochs=2,
        device="cpu",
    )

    # 4. Train the model
    print("🚀 Training the SCINet anomaly detection model...")
    model.fit(train_set, val_set)

    # 5. Predict anomalies on the test set
    print("🔮 Calculating anomaly scores for the test set...")
    results = model.predict(test_set)
    # These are used as predicted labels, not raw scores: they are compared
    # directly with ground-truth labels by the classification metrics below.
    # NOTE(review): presumably thresholded from internal scores via
    # `anomaly_rate` -- confirm against the PyPOTS detector docs.
    predictions = results["anomaly_detection"]

    # 6. Evaluate
    accuracy = calc_acc(predictions, test_anomaly_y)
    precision, recall, f1 = calc_precision_recall_f1(predictions, test_anomaly_y)
    print(
        f"✅ SCINet anomaly detection - Accuracy: {accuracy:.4f}, F1: {f1:.4f}, "
        f"Precision: {precision:.4f}, Recall: {recall:.4f}"
    )
if __name__ == "__main__":
    # Run the example only when this file is executed directly, not on import.
    main()