Skip to content

Commit e9eb3fb

Browse files
ez96 authored and
The Meridian Authors committed
add boxplots plotting to MeridianEDA
PiperOrigin-RevId: 831633970
1 parent b21371a commit e9eb3fb

File tree

3 files changed

+796
-192
lines changed

3 files changed

+796
-192
lines changed

meridian/model/eda/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@
1818
VARIABLE_1 = 'var1'
1919
VARIABLE_2 = 'var2'
2020
# Name of the column that labels which variable a row belongs to; the boxplot
# helper in meridian_eda.py uses it as the x and colour field.
VARIABLE = 'var'
21+
# Name of the column holding the numeric value; the boxplot helper in
# meridian_eda.py uses it as the y field.
VALUE = 'value'
2122
CORRELATION = 'correlation'
23+
# Value of the `geos` argument that makes the boxplot helper in
# meridian_eda.py pick the national data source.
NATIONALIZE = 'nationalize'

meridian/model/eda/meridian_eda.py

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,19 @@
1515
"""Module containing Meridian related exploratory data analysis (EDA) functionalities."""
1616
from __future__ import annotations
1717

18-
from typing import Literal, TYPE_CHECKING, Union
18+
from typing import Callable, Literal, TYPE_CHECKING, Union
1919

2020
import altair as alt
2121
from meridian import constants
2222
from meridian.model.eda import constants as eda_constants
23+
from meridian.model.eda import eda_engine
2324
import pandas as pd
25+
import xarray as xr
2426

2527
if TYPE_CHECKING:
2628
from meridian.model import model # pylint: disable=g-bad-import-order,g-import-not-at-top
2729

30+
2831
__all__ = [
2932
'MeridianEDA',
3033
]
@@ -55,6 +58,171 @@ def generate_and_save_report(self, filename: str, filepath: str):
5558
# TODO: Implement.
5659
raise NotImplementedError()
5760

61+
def plot_kpi_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Plots the boxplot for KPI variation.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _to_long_format(data) -> pd.DataFrame:
    # The value column is renamed and a constant label column is attached.
    # Both column names come from `eda_constants` because `_plot_boxplots`
    # encodes the chart with `eda_constants.VARIABLE` and
    # `eda_constants.VALUE`; using the same constants keeps producer and
    # consumer in sync.
    values = (
        data.to_dataframe()
        .reset_index()
        .rename(columns={data.name: eda_constants.VALUE})[
            [eda_constants.VALUE]
        ]
    )
    return values.assign(**{eda_constants.VARIABLE: constants.KPI})

  return self._plot_boxplots(
      geos,
      'Boxplots of KPI',
      self._meridian.eda_engine.national_kpi_scaled_da,
      self._meridian.eda_engine.kpi_scaled_da,
      _to_long_format,
  )
75+
76+
def plot_frequency_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart showing how frequency varies.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _melt_frequency(data) -> pd.DataFrame:
    # Discard the row index; melt then unpivots every remaining column.
    wide = data.to_pandas().reset_index(drop=True)
    melted = pd.melt(wide)
    return melted.rename(
        columns={constants.RF_CHANNEL: eda_constants.VARIABLE}
    )

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of frequency',
      engine.national_all_freq_da,
      engine.all_freq_da,
      _melt_frequency,
  )
89+
90+
def plot_reach_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart showing how reach varies.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _melt_reach(data) -> pd.DataFrame:
    # Discard the row index; melt then unpivots every remaining column.
    wide = data.to_pandas().reset_index(drop=True)
    melted = pd.melt(wide)
    return melted.rename(
        columns={constants.RF_CHANNEL: eda_constants.VARIABLE}
    )

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of reach',
      engine.national_all_reach_scaled_da,
      engine.all_reach_scaled_da,
      _melt_reach,
  )
103+
104+
def plot_non_media_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart showing how non-media treatments vary.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _melt_non_media(data) -> pd.DataFrame:
    # Discard the row index; melt then unpivots every remaining column.
    wide = data.to_pandas().reset_index(drop=True)
    melted = pd.melt(wide)
    return melted.rename(
        columns={constants.NON_MEDIA_CHANNEL: eda_constants.VARIABLE}
    )

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of non-media treatments',
      engine.national_non_media_scaled_da,
      engine.non_media_scaled_da,
      _melt_non_media,
  )
117+
118+
def plot_treatments_excl_non_media_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart for treatments other than non-media ones.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _stack_remaining(data) -> pd.DataFrame:
    # The non-media and control dimensions are removed before stacking;
    # errors='ignore' is passed so a dataset lacking either one is accepted.
    trimmed = data.drop_dims(
        [constants.NON_MEDIA_CHANNEL, constants.CONTROL_VARIABLE],
        errors='ignore',
    )
    stacked = eda_engine.stack_variables(trimmed)
    return self._process_stacked_ds(stacked)

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of paid and organic impressions',
      engine.national_treatment_control_scaled_ds,
      engine.treatment_control_scaled_ds,
      _stack_remaining,
  )
136+
137+
def plot_controls_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart showing how controls vary.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _melt_controls(data) -> pd.DataFrame:
    # Discard the row index; melt then unpivots every remaining column.
    wide = data.to_pandas().reset_index(drop=True)
    melted = pd.melt(wide)
    return melted.rename(
        columns={constants.CONTROL_VARIABLE: eda_constants.VARIABLE}
    )

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of controls',
      engine.national_controls_scaled_da,
      engine.controls_scaled_da,
      _melt_controls,
  )
150+
151+
def plot_spend_boxplot(
    self, geos: Union[int, list[str], Literal['nationalize']] = 1
) -> alt.Chart:
  """Builds the boxplot chart showing how spend varies per paid channel.

  Args:
    geos: Geo selection, forwarded unchanged to `_plot_boxplots`.

  Returns:
    The Altair chart produced by `_plot_boxplots`.
  """

  def _stack_spend(data) -> pd.DataFrame:
    # Stack first, then reduce to the columns the chart reads.
    stacked = eda_engine.stack_variables(data)
    return self._process_stacked_ds(stacked)

  engine = self._meridian.eda_engine
  return self._plot_boxplots(
      geos,
      'Boxplots of spend for each paid channel',
      engine.national_all_spend_ds,
      engine.all_spend_ds,
      _stack_spend,
  )
162+
163+
def _plot_boxplots(
    self,
    geos: Union[int, list[str], Literal['nationalize']],
    title_prefix: str,
    national_data_source: xr.DataArray | xr.Dataset,
    geo_data_source: xr.DataArray | xr.Dataset,
    processing_function: Callable[[xr.DataArray | xr.Dataset], pd.DataFrame],
) -> alt.Chart:
  """Helper function for plotting boxplots.

  Builds one boxplot panel per entry returned by
  `_validate_and_get_geos_to_plot(geos)` and concatenates the panels
  vertically.

  Args:
    geos: Geo selection; validated by `_validate_and_get_geos_to_plot`. When it
      equals `eda_constants.NATIONALIZE` the national data source is used.
    title_prefix: Text placed before ' for <geo>' in each panel title.
    national_data_source: Data used when the model is national or when `geos`
      equals `eda_constants.NATIONALIZE`.
    geo_data_source: Data used otherwise; it is indexed with `.sel(geo=...)`
      for each geo to plot.
    processing_function: Converts the selected data into a DataFrame holding
      the `eda_constants.VARIABLE` and `eda_constants.VALUE` columns that the
      chart encodes.

  Returns:
    A vertically concatenated Altair chart with one boxplot panel per geo.

  Raises:
    ValueError: If the selected data source is `None`.
  """
  geos_to_plot = self._validate_and_get_geos_to_plot(geos)

  use_national_data = (
      self._meridian.is_national or geos == eda_constants.NATIONALIZE
  )
  data_source = national_data_source if use_national_data else geo_data_source
  if data_source is None:
    raise ValueError(
        'There is no data to plot! Make sure your InputData contains the'
        ' component you are trying to plot.'
    )

  charts = []
  for geo_to_plot in geos_to_plot:
    title = f'{title_prefix} for {geo_to_plot}'

    if use_national_data:
      selected = data_source
    else:
      selected = data_source.sel(geo=geo_to_plot)
    plot_data = processing_function(selected)

    # Keep the variables in first-appearance order on the x axis. A plain
    # list is passed so the sort order does not depend on how the plotting
    # library serializes array types.
    variable_order = plot_data[eda_constants.VARIABLE].unique().tolist()

    charts.append(
        alt.Chart(plot_data)
        .mark_boxplot(ticks=True, size=40, extent=1.5)
        .encode(
            x=alt.X(
                f'{eda_constants.VARIABLE}:N',
                title=None,
                sort=variable_order,
                scale=alt.Scale(paddingInner=0.02),
            ),
            # No `sort` here: the y field is quantitative, so a list of
            # category names is not a meaningful sort order for it.
            y=alt.Y(
                f'{eda_constants.VALUE}:Q',
                title='Value',
                scale=alt.Scale(zero=True),
            ),
            color=alt.Color(f'{eda_constants.VARIABLE}:N', legend=None),
        )
        .properties(title=title, width=450, height=400)
    )

  final_chart = (
      alt.vconcat(*charts)
      .resolve_legend(color='independent')
      .configure_axis(labelAngle=315)
      .configure_title(anchor='start')
      .configure_view(stroke=None)
  )
  return final_chart
225+
58226
def plot_pairwise_correlation(
59227
self, geos: Union[int, list[str], Literal['nationalize']] = 1
60228
) -> alt.Chart:
@@ -185,6 +353,14 @@ def _plot_2d_heatmap(
185353

186354
return chart
187355

356+
def _process_stacked_ds(self, data: xr.DataArray) -> pd.DataFrame:
  """Converts stacked data into the two-column frame that Altair plots.

  Args:
    data: The array to convert. Callers in this class pass the result of
      `eda_engine.stack_variables`.

  Returns:
    A DataFrame restricted to the `eda_constants.VARIABLE` and
    `eda_constants.VALUE` columns, in that order.
  """
  # Naming the array makes its values land in the VALUE column.
  named = data.rename(eda_constants.VALUE)
  frame = named.to_dataframe().reset_index()
  return frame[[eda_constants.VARIABLE, eda_constants.VALUE]]
363+
188364
def _generate_pairwise_correlation_report(self) -> str:
189365
"""Creates the HTML snippet for Pairwise Correlation report section."""
190366
# TODO: Implement.

0 commit comments

Comments
 (0)