Skip to content

Commit 88c3c53

Browse files
authored
Merge pull request #135 from amosproj/fix_ad_empty_return
FIX: MAD AD now reliably returns an empty Spark frame if no anomalies…
2 parents 17393d7 + 98e7d8b commit 88c3c53

File tree

2 files changed

+122
-4
lines changed

2 files changed

+122
-4
lines changed

src/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/mad/mad_anomaly_detection.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pandas as pd
1616

1717
from pyspark.sql import DataFrame
18+
from pyspark.sql.types import StructField, StructType, DoubleType, BooleanType
1819
from typing import Optional, List, Union
1920

2021
from ...._pipeline_utils.models import (
@@ -210,6 +211,22 @@ def libraries() -> Libraries:
210211
def settings() -> dict:
211212
return {}
212213

214+
@staticmethod
215+
def _build_result_schema(df: DataFrame) -> StructType:
216+
return StructType(
217+
list(df.schema.fields)
218+
+ [
219+
StructField("mad_zscore", DoubleType(), True),
220+
StructField("is_anomaly", BooleanType(), True),
221+
]
222+
)
223+
224+
@staticmethod
225+
def _empty_result_df(df: DataFrame, schema: StructType) -> DataFrame:
226+
"""Create an empty DataFrame with the correct schema using pandas."""
227+
empty_pdf = pd.DataFrame(columns=schema.fieldNames())
228+
return df.sparkSession.createDataFrame(empty_pdf, schema=schema)
229+
213230
def detect(self, df: DataFrame) -> DataFrame:
214231
"""
215232
Detects anomalies in the input DataFrame using the configured MAD scorer.
@@ -228,13 +245,25 @@ def detect(self, df: DataFrame) -> DataFrame:
228245
- `is_anomaly`: Boolean anomaly flag.
229246
"""
230247

248+
result_schema = self._build_result_schema(df)
249+
231250
pdf = df.toPandas()
251+
if pdf.empty:
252+
return self._empty_result_df(df, result_schema)
232253

233254
scores = self.scorer.score(pdf["value"])
234255
pdf["mad_zscore"] = scores
235256
pdf["is_anomaly"] = self.scorer.is_anomaly(scores)
236257

237-
return df.sparkSession.createDataFrame(pdf[pdf["is_anomaly"]].copy())
258+
anomalies_pdf = pdf[pdf["is_anomaly"]].copy()
259+
anomalies_pdf = anomalies_pdf[result_schema.fieldNames()]
260+
261+
if anomalies_pdf.empty:
262+
return self._empty_result_df(df, result_schema)
263+
264+
# Ensure correct column order matches schema
265+
anomalies_pdf = anomalies_pdf[result_schema.fieldNames()]
266+
return df.sparkSession.createDataFrame(anomalies_pdf, schema=result_schema)
238267

239268

240269
class DecompositionMadAnomalyDetection(AnomalyDetectionInterface):
@@ -368,6 +397,16 @@ def _decompose(self, df: DataFrame) -> DataFrame:
368397
else:
369398
raise ValueError(f"Unsupported decomposition method: {self.decomposition}")
370399

400+
@staticmethod
401+
def _build_result_schema(df: DataFrame) -> StructType:
402+
return StructType(
403+
list(df.schema.fields)
404+
+ [
405+
StructField("mad_zscore", DoubleType(), True),
406+
StructField("is_anomaly", BooleanType(), True),
407+
]
408+
)
409+
371410
def detect(self, df: DataFrame) -> DataFrame:
372411
"""
373412
Detects anomalies by scoring the decomposition residuals using the configured MAD scorer.
@@ -385,12 +424,25 @@ def detect(self, df: DataFrame) -> DataFrame:
385424
- `mad_zscore`: MAD-based anomaly score computed on `residual`.
386425
- `is_anomaly`: Boolean anomaly flag.
387426
"""
388-
427+
389428
decomposed_df = self._decompose(df)
429+
result_schema = self._build_result_schema(decomposed_df)
430+
390431
pdf = decomposed_df.toPandas().sort_values(self.timestamp_column)
391432

433+
if pdf.empty:
434+
return MadAnomalyDetection._empty_result_df(decomposed_df, result_schema)
435+
392436
scores = self.scorer.score(pdf["residual"])
393437
pdf["mad_zscore"] = scores
394438
pdf["is_anomaly"] = self.scorer.is_anomaly(scores)
395439

396-
return df.sparkSession.createDataFrame(pdf[pdf["is_anomaly"]].copy())
440+
anomalies_pdf = pdf[pdf["is_anomaly"]].copy()
441+
anomalies_pdf = anomalies_pdf[result_schema.fieldNames()]
442+
443+
if anomalies_pdf.empty:
444+
return MadAnomalyDetection._empty_result_df(decomposed_df, result_schema)
445+
446+
# Ensure correct column order matches schema
447+
anomalies_pdf = anomalies_pdf[result_schema.fieldNames()]
448+
return df.sparkSession.createDataFrame(anomalies_pdf, schema=result_schema)

tests/sdk/python/rtdip_sdk/pipelines/anomaly_detection/spark/test_mad.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numpy as np
16+
import pandas as pd
1517
import pytest
1618

1719
from src.sdk.python.rtdip_sdk.pipelines.anomaly_detection.spark.mad.mad_anomaly_detection import (
@@ -22,6 +24,21 @@
2224
)
2325

2426

27+
@pytest.fixture
28+
def spark_dataframe_without_anomalies(spark_session):
29+
data = [(i, float(10.0 + 0.05 * np.sin(i))) for i in range(1, 31)]
30+
columns = ["timestamp", "value"]
31+
return spark_session.createDataFrame(data, columns)
32+
33+
34+
@pytest.fixture
35+
def spark_dataframe_without_anomalies_timestamp(spark_session):
36+
timestamps = pd.date_range("2025-02-01", periods=72, freq="h")
37+
values = 10.0 + 0.1 * np.sin(np.arange(72))
38+
pdf = pd.DataFrame({"timestamp": timestamps, "value": values})
39+
return spark_session.createDataFrame(pdf)
40+
41+
2542
@pytest.fixture
2643
def spark_dataframe_with_anomalies(spark_session):
2744
data = [
@@ -40,6 +57,28 @@ def spark_dataframe_with_anomalies(spark_session):
4057
return spark_session.createDataFrame(data, columns)
4158

4259

60+
def test_mad_anomaly_detection_global_no_anomalies(
61+
spark_dataframe_without_anomalies,
62+
):
63+
mad_detector = MadAnomalyDetection()
64+
65+
result_df = mad_detector.detect(spark_dataframe_without_anomalies)
66+
67+
assert result_df.count() == 0
68+
assert result_df.columns == ["timestamp", "value", "mad_zscore", "is_anomaly"]
69+
70+
71+
def test_mad_anomaly_detection_rolling_no_anomalies(
72+
spark_dataframe_without_anomalies,
73+
):
74+
mad_detector = MadAnomalyDetection(scorer=RollingMadScorer(window_size=5))
75+
76+
result_df = mad_detector.detect(spark_dataframe_without_anomalies)
77+
78+
assert result_df.count() == 0
79+
assert result_df.columns == ["timestamp", "value", "mad_zscore", "is_anomaly"]
80+
81+
4382
def test_mad_anomaly_detection_global(spark_dataframe_with_anomalies):
4483
mad_detector = MadAnomalyDetection()
4584

@@ -136,7 +175,7 @@ def spark_dataframe_synthetic_stl(spark_session):
136175
n = 500
137176
period = 24
138177

139-
timestamps = pd.date_range("2025-01-01", periods=n, freq="H")
178+
timestamps = pd.date_range("2025-01-01", periods=n, freq="h")
140179
trend = 0.02 * np.arange(n)
141180
seasonal = 5 * np.sin(2 * np.pi * np.arange(n) / period)
142181
noise = 0.3 * np.random.randn(n)
@@ -151,6 +190,33 @@ def spark_dataframe_synthetic_stl(spark_session):
151190
return spark_session.createDataFrame(pdf)
152191

153192

193+
@pytest.mark.parametrize(
194+
"decomposition, scorer",
195+
[
196+
("stl", GlobalMadScorer(threshold=3.5)),
197+
("mstl", RollingMadScorer(threshold=3.5, window_size=24)),
198+
],
199+
)
200+
def test_decomposition_mad_anomaly_detection_no_anomalies(
201+
spark_dataframe_without_anomalies_timestamp,
202+
decomposition,
203+
scorer,
204+
):
205+
detector = DecompositionMadAnomalyDetection(
206+
scorer=scorer,
207+
decomposition=decomposition,
208+
period=24,
209+
timestamp_column="timestamp",
210+
value_column="value",
211+
)
212+
213+
result_df = detector.detect(spark_dataframe_without_anomalies_timestamp)
214+
215+
assert result_df.count() == 0
216+
assert "mad_zscore" in result_df.columns
217+
assert "is_anomaly" in result_df.columns
218+
219+
154220
@pytest.mark.parametrize(
155221
"decomposition, period, scorer",
156222
[

0 commit comments

Comments (0)