Skip to content

Commit 810d352

Browse files
committed
add visualization into sdk + tests
Signed-off-by: simonselbig <simon.selbig@gmx.de>
1 parent 0ccca56 commit 810d352

File tree

22 files changed

+6939
-135
lines changed

22 files changed

+6939
-135
lines changed

amos_team_resources/shell/pipeline_shell_data.py

Lines changed: 277 additions & 133 deletions
Large diffs are not rendered by default.

environment.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ dependencies:
7272
- great-expectations>=0.18.8,<1.0.0
7373
- statsmodels>=0.14.1,<0.15.0
7474
- pmdarima>=2.0.4
75+
- plotly>=5.0.0
76+
- kaleido>=0.2.0
7577
- pip:
7678
# protobuf installed via pip to avoid libabseil conflicts with conda libarrow
7779
- protobuf>=5.29.0,<5.30.0
@@ -98,6 +100,4 @@ dependencies:
98100
- autogluon.timeseries>=1.1.1,<2.0.0
99101
- scikit-learn>=1.3.0,<2.0.0
100102
- xgboost>=2.0.0,<3.0.0
101-
- plotly>=5.0.0
102-
- kaleido>=0.2.0
103103
#- tensorflow>=2.13.0,<3.0.0
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 RTDIP
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
RTDIP Visualization Module.
17+
18+
This module provides standardized visualization components for time series forecasting,
19+
anomaly detection, and model comparison. It supports both Matplotlib (static) and
20+
Plotly (interactive) backends.
21+
22+
Submodules:
23+
- matplotlib: Static visualization using Matplotlib/Seaborn
24+
- plotly: Interactive visualization using Plotly
25+
26+
Example:
27+
```python
28+
from rtdip_sdk.pipelines.visualization.matplotlib.forecasting import ForecastPlot
29+
from rtdip_sdk.pipelines.visualization.plotly.forecasting import ForecastPlotInteractive
30+
31+
# Static plot
32+
plot = ForecastPlot(historical_df, forecast_df, forecast_start)
33+
fig = plot.plot()
34+
plot.save("forecast.png")
35+
36+
# Interactive plot
37+
plot_interactive = ForecastPlotInteractive(historical_df, forecast_df, forecast_start)
38+
fig = plot_interactive.plot()
39+
plot_interactive.save("forecast.html")
40+
```
41+
"""
42+
43+
from . import config
44+
from . import utils
45+
from . import validation
46+
from .interfaces import VisualizationBaseInterface
47+
from .validation import VisualizationDataError
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright 2025 RTDIP
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Standardized visualization configuration for RTDIP time series forecasting.
17+
18+
This module defines standard colors, styles, and settings to ensure consistent
19+
visualizations across all forecasting, anomaly detection, and model comparison tasks.
20+
21+
Supports both Matplotlib (static) and Plotly (interactive) backends.
22+
23+
Example
24+
--------
25+
```python
26+
from rtdip_sdk.pipelines.visualization import config
27+
28+
# Use predefined colors
29+
historical_color = config.COLORS['historical']
30+
31+
# Get model-specific color
32+
model_color = config.get_model_color('autogluon')
33+
34+
# Get figure size for grid
35+
figsize = config.get_figsize_for_grid(6)
36+
```
37+
"""
38+
39+
from typing import Dict, Tuple
40+
41+
# BACKEND CONFIGURATION
42+
VISUALIZATION_BACKEND: str = "matplotlib" # Options: 'matplotlib' or 'plotly'
43+
44+
# COLOR SCHEMES
45+
46+
# Primary colors for different data types
47+
COLORS: Dict[str, str] = {
48+
# Time series data
49+
"historical": "#2C3E50", # historical data
50+
"forecast": "#27AE60", # predictions
51+
"actual": "#2980B9", # ground truth
52+
"anomaly": "#E74C3C", # anomalies/errors
53+
# Confidence intervals
54+
"ci_60": "#27AE60", # alpha=0.3
55+
"ci_80": "#27AE60", # alpha=0.15
56+
"ci_90": "#27AE60", # alpha=0.1
57+
# Special markers
58+
"forecast_start": "#E74C3C", # forecast start line
59+
"threshold": "#F39C12", # thresholds
60+
}
61+
62+
# Model-specific colors (for comparison plots)
63+
MODEL_COLORS: Dict[str, str] = {
64+
"autogluon": "#2ECC71",
65+
"lstm": "#E74C3C",
66+
"xgboost": "#3498DB",
67+
"arima": "#9B59B6",
68+
"prophet": "#F39C12",
69+
"ensemble": "#1ABC9C",
70+
}
71+
72+
# Confidence interval alpha values
73+
CI_ALPHA: Dict[int, float] = {
74+
60: 0.3, # 60% - most opaque
75+
80: 0.2, # 80% - medium
76+
90: 0.1, # 90% - most transparent
77+
}
78+
79+
# FIGURE SIZES
80+
81+
FIGSIZE: Dict[str, Tuple[float, float]] = {
82+
"single": (12, 6), # Single time series plot
83+
"single_tall": (12, 8), # Single plot with more vertical space
84+
"comparison": (14, 6), # Side-by-side comparison
85+
"grid_small": (14, 8), # 2-3 subplot grid
86+
"grid_medium": (16, 10), # 4-6 subplot grid
87+
"grid_large": (18, 12), # 6-9 subplot grid
88+
"dashboard": (20, 16), # Full dashboard with 9+ subplots
89+
"wide": (16, 5), # Wide single plot
90+
}
91+
92+
# EXPORT SETTINGS
93+
94+
EXPORT: Dict[str, any] = {
95+
"dpi": 300, # High resolution
96+
"format": "png", # Default format
97+
"bbox_inches": "tight", # Tight bounding box
98+
"facecolor": "white", # White background
99+
"edgecolor": "none", # No edge color
100+
}
101+
102+
# STYLE SETTINGS
103+
104+
STYLE: str = "seaborn-v0_8-whitegrid"
105+
106+
FONT_SIZES: Dict[str, int] = {
107+
"title": 14,
108+
"subtitle": 12,
109+
"axis_label": 12,
110+
"tick_label": 10,
111+
"legend": 10,
112+
"annotation": 9,
113+
}
114+
115+
LINE_SETTINGS: Dict[str, float] = {
116+
"linewidth": 2.0, # Default line width
117+
"linewidth_thin": 1.5, # Thin lines (for CI, grids)
118+
"marker_size": 4, # Default marker size for line plots
119+
"scatter_size": 80, # Scatter plot marker size
120+
"anomaly_size": 100, # Anomaly marker size
121+
}
122+
123+
GRID: Dict[str, any] = {
124+
"alpha": 0.3, # Grid transparency
125+
"linestyle": "--", # Dashed grid lines
126+
"linewidth": 0.5, # Thin grid lines
127+
}
128+
129+
TIME_FORMATS: Dict[str, str] = {
130+
"hourly": "%Y-%m-%d %H:%M",
131+
"daily": "%Y-%m-%d",
132+
"monthly": "%Y-%m",
133+
"display": "%m/%d %H:%M",
134+
}
135+
136+
METRICS: Dict[str, Dict[str, str]] = {
137+
"mae": {"name": "MAE", "format": ".3f"},
138+
"mse": {"name": "MSE", "format": ".3f"},
139+
"rmse": {"name": "RMSE", "format": ".3f"},
140+
"mape": {"name": "MAPE (%)", "format": ".2f"},
141+
"smape": {"name": "SMAPE (%)", "format": ".2f"},
142+
"r2": {"name": "R²", "format": ".4f"},
143+
"mae_p50": {"name": "MAE (P50)", "format": ".3f"},
144+
"mae_p90": {"name": "MAE (P90)", "format": ".3f"},
145+
}
146+
147+
# Metric display order (left to right, top to bottom)
148+
METRIC_ORDER: list = ["mae", "rmse", "mse", "mape", "smape", "r2"]
149+
150+
# OUTPUT DIRECTORY SETTINGS
151+
DEFAULT_OUTPUT_DIR: str = "output_images"
152+
153+
# COLORBLIND-FRIENDLY PALETTE
154+
155+
COLORBLIND_PALETTE: list = [
156+
"#0173B2",
157+
"#DE8F05",
158+
"#029E73",
159+
"#CC78BC",
160+
"#CA9161",
161+
"#949494",
162+
"#ECE133",
163+
"#56B4E9",
164+
]
165+
166+
167+
# HELPER FUNCTIONS
168+
169+
170+
def get_grid_layout(n_plots: int) -> Tuple[int, int]:
171+
"""
172+
Calculate optimal subplot grid layout (rows, cols) for n_plots.
173+
174+
Prioritizes 3 columns for better horizontal space usage.
175+
176+
Args:
177+
n_plots: Number of subplots needed
178+
179+
Returns:
180+
Tuple of (n_rows, n_cols)
181+
182+
Example
183+
--------
184+
```python
185+
from rtdip_sdk.pipelines.visualization.config import get_grid_layout
186+
187+
rows, cols = get_grid_layout(5) # Returns (2, 3)
188+
```
189+
"""
190+
if n_plots <= 0:
191+
return (0, 0)
192+
elif n_plots == 1:
193+
return (1, 1)
194+
elif n_plots == 2:
195+
return (1, 2)
196+
elif n_plots <= 3:
197+
return (1, 3)
198+
elif n_plots <= 6:
199+
return (2, 3)
200+
elif n_plots <= 9:
201+
return (3, 3)
202+
elif n_plots <= 12:
203+
return (4, 3)
204+
else:
205+
n_cols = 3
206+
n_rows = (n_plots + n_cols - 1) // n_cols
207+
return (n_rows, n_cols)
208+
209+
210+
def get_model_color(model_name: str) -> str:
211+
"""
212+
Get color for a specific model, with fallback to colorblind palette.
213+
214+
Args:
215+
model_name: Model name (case-insensitive)
216+
217+
Returns:
218+
Hex color code string
219+
220+
Example
221+
--------
222+
```python
223+
from rtdip_sdk.pipelines.visualization.config import get_model_color
224+
225+
color = get_model_color('AutoGluon') # Returns '#2ECC71'
226+
color = get_model_color('custom_model') # Returns color from palette
227+
```
228+
"""
229+
model_name_lower = model_name.lower()
230+
231+
if model_name_lower in MODEL_COLORS:
232+
return MODEL_COLORS[model_name_lower]
233+
234+
idx = hash(model_name) % len(COLORBLIND_PALETTE)
235+
return COLORBLIND_PALETTE[idx]
236+
237+
238+
def get_figsize_for_grid(n_plots: int) -> Tuple[float, float]:
239+
"""
240+
Get appropriate figure size for a grid of n plots.
241+
242+
Args:
243+
n_plots: Number of subplots
244+
245+
Returns:
246+
Tuple of (width, height) in inches
247+
248+
Example
249+
--------
250+
```python
251+
from rtdip_sdk.pipelines.visualization.config import get_figsize_for_grid
252+
253+
figsize = get_figsize_for_grid(4) # Returns (16, 10) for grid_medium
254+
```
255+
"""
256+
if n_plots <= 1:
257+
return FIGSIZE["single"]
258+
elif n_plots <= 3:
259+
return FIGSIZE["grid_small"]
260+
elif n_plots <= 6:
261+
return FIGSIZE["grid_medium"]
262+
elif n_plots <= 9:
263+
return FIGSIZE["grid_large"]
264+
else:
265+
return FIGSIZE["dashboard"]

0 commit comments

Comments
 (0)