Skip to content

Commit 30e18c0

Browse files
authored
Implement tests for IQR anomaly detection
Added tests for IqrAnomalyDetection and IqrAnomalyDetectionRollingWindow using Pytest. Signed-off-by: Mehdi-kbz <141425685+Mehdi-kbz@users.noreply.github.com>
1 parent 3149092 commit 30e18c0

1 file changed

Lines changed: 123 additions & 0 deletions

File tree

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2025 RTDIP
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.iqr_anomaly_detection import (
18+
IqrAnomalyDetection,
19+
IqrAnomalyDetectionRollingWindow,
20+
)
21+
22+
23+
@pytest.fixture
def spark_dataframe_with_anomalies(spark_session):
    """Return a ten-row DataFrame with ``timestamp`` and ``value`` columns.

    Timestamps are the integers 1..10. Every value lies between 9.8 and 12.0
    except the reading at timestamp 5 (30.0), which is the intended anomaly.
    """
    values = [
        10.0,
        12.0,
        10.5,
        11.0,
        30.0,  # Anomalous value
        10.2,
        9.8,
        10.1,
        10.3,
        10.0,
    ]
    # Pair each reading with its 1-based position to form (timestamp, value) rows.
    rows = list(enumerate(values, start=1))
    return spark_session.createDataFrame(rows, ["timestamp", "value"])
39+
40+
41+
def test_iqr_anomaly_detection(spark_dataframe_with_anomalies):
    """Check that ``IqrAnomalyDetection.detect`` returns only the 30.0 row."""
    iqr_detector = IqrAnomalyDetection()
    result_df = iqr_detector.detect(spark_dataframe_with_anomalies)

    # Collect once: calling count() and collect() separately would evaluate
    # the detection plan twice.
    detected_values = [row["value"] for row in result_df.collect()]

    # Comparing the whole list checks both the anomaly count (exactly one)
    # and its value, and shows the offending values if the assertion fails.
    assert detected_values == [30.0]
51+
52+
53+
@pytest.fixture
def spark_dataframe_with_anomalies_big(spark_session):
    """Return a fifty-row DataFrame with ``timestamp`` and ``value`` columns.

    Timestamps are the integers 1..50. The values rise gradually from 5.8 to
    31.6, except for three readings that break the local trend: 0.0 at
    timestamp 31, 30.0 at timestamp 36 and 40.0 at timestamp 41.
    """
    values = [
        # timestamps 1-10
        5.8, 6.6, 6.2, 7.5, 7.0, 8.3, 8.1, 9.7, 9.2, 10.5,
        # timestamps 11-20
        10.7, 11.4, 12.1, 11.6, 13.0, 13.6, 14.2, 14.8, 15.3, 15.0,
        # timestamps 21-30
        16.2, 16.8, 17.4, 18.1, 17.7, 18.9, 19.5, 19.2, 20.1, 20.7,
        # timestamps 31-40 (0.0 and 30.0 break the trend)
        0.0, 21.5, 22.0, 22.9, 23.4, 30.0, 23.8, 24.9, 25.1, 26.0,
        # timestamps 41-50 (40.0 breaks the trend)
        40.0, 26.5, 27.4, 28.0, 28.8, 29.1, 29.8, 30.5, 31.0, 31.6,
    ]

    # Pair each reading with its 1-based position to form (timestamp, value) rows.
    rows = list(enumerate(values, start=1))
    return spark_session.createDataFrame(rows, ["timestamp", "value"])
110+
111+
112+
def test_iqr_anomaly_detection_rolling_window(spark_dataframe_with_anomalies_big):
    """Check that the rolling-window detector flags exactly 0.0, 30.0 and 40.0."""
    # Using a smaller window size to detect anomalies in the larger dataset
    iqr_detector = IqrAnomalyDetectionRollingWindow(window_size=15)
    result_df = iqr_detector.detect(spark_dataframe_with_anomalies_big)

    # Collect once: each count()/collect() call is a separate action that
    # would re-evaluate the detection plan.
    detected_values = [row["value"] for row in result_df.collect()]

    # assert all 3 anomalies are detected
    assert len(detected_values) == 3

    # check that the detected anomalies are the expected ones; sort first so
    # the test does not depend on the row order of the returned DataFrame
    assert sorted(detected_values) == [0.0, 30.0, 40.0]

0 commit comments

Comments
 (0)