Skip to content

Commit ff62333

Browse files
author
The Meridian Authors
committed
Refactor model fit plot to include interactive tooltips and hover effects
PiperOrigin-RevId: 876095304
1 parent ce806e9 commit ff62333

File tree

3 files changed

+119
-46
lines changed

3 files changed

+119
-46
lines changed

meridian/analysis/visualizer.py

Lines changed: 75 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -481,53 +481,88 @@ def plot_model_fit(
481481
y_axis_label = summary_text.KPI_LABEL
482482
else:
483483
y_axis_label = summary_text.REVENUE_LABEL
484-
plot = (
485-
alt.Chart(model_fit_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
486-
.mark_line()
487-
.encode(
488-
x=alt.X(
489-
f'{c.TIME}:T',
490-
title='Time period',
491-
axis=alt.Axis(
492-
format=c.QUARTER_FORMAT,
493-
grid=False,
494-
tickCount=8,
495-
domainColor=c.GREY_300,
496-
),
497-
),
498-
y=alt.Y(
499-
f'{c.MEAN}:Q',
500-
title=y_axis_label,
501-
axis=alt.Axis(
502-
ticks=False,
503-
domain=False,
504-
tickCount=5,
505-
labelPadding=c.PADDING_10,
506-
labelExpr=formatter.compact_number_expr(),
507-
**formatter.Y_AXIS_TITLE_CONFIG,
508-
),
484+
485+
base = alt.Chart(model_fit_df, width=c.VEGALITE_FACET_EXTRA_LARGE_WIDTH)
486+
487+
hover = alt.selection_point(
488+
fields=[c.TIME],
489+
nearest=True,
490+
on='pointerover',
491+
empty=False,
492+
clear='pointerout',
493+
)
494+
495+
lines = base.mark_line().encode(
496+
x=alt.X(
497+
f'{c.TIME}:T',
498+
title='Time period',
499+
axis=alt.Axis(
500+
format=c.QUARTER_FORMAT,
501+
grid=False,
502+
tickCount=8,
503+
domainColor=c.GREY_300,
509504
),
510-
color=alt.Color(
511-
'type:N', scale=alt.Scale(domain=domain, range=colors)
505+
),
506+
y=alt.Y(
507+
f'{c.MEAN}:Q',
508+
title=y_axis_label,
509+
axis=alt.Axis(
510+
ticks=False,
511+
domain=False,
512+
tickCount=5,
513+
labelPadding=c.PADDING_10,
514+
labelExpr=formatter.compact_number_expr(),
515+
**formatter.Y_AXIS_TITLE_CONFIG,
512516
),
517+
),
518+
color=alt.Color('type:N', scale=alt.Scale(domain=domain, range=colors)),
519+
)
520+
521+
tooltip = [alt.Tooltip(f'{c.TIME}:T', title='Time period')]
522+
for field in domain:
523+
tooltip.append(
524+
alt.Tooltip(
525+
f'{field}:Q',
526+
title=field.title(),
527+
format=',.0f',
528+
)
529+
)
530+
531+
tooltips = (
532+
base.transform_pivot(c.TYPE, value=c.MEAN, groupby=[c.TIME])
533+
.mark_rule(opacity=0)
534+
.encode(
535+
x=f'{c.TIME}:T',
536+
opacity=alt.condition(hover, alt.value(0.3), alt.value(0)),
537+
tooltip=tooltip,
513538
)
539+
.add_params(hover)
514540
)
515541

542+
points = base.mark_circle(size=c.MARK_CIRCLE_SIZE, filled=False).encode(
543+
x=f'{c.TIME}:T',
544+
y=f'{c.MEAN}:Q',
545+
color=alt.Color(
546+
'type:N',
547+
scale=alt.Scale(domain=domain, range=colors),
548+
legend=None,
549+
),
550+
opacity=alt.condition(hover, alt.value(1), alt.value(0)),
551+
)
552+
553+
plot = alt.layer(lines, tooltips, points)
554+
516555
if include_ci:
517556
# Only add a confidence interval area for the modeled data.
518-
confidence_band = (
519-
alt.Chart(model_fit_df)
520-
.mark_area(opacity=0.3)
521-
.encode(
522-
x=f'{c.TIME}:T',
523-
y=f'{c.CI_HI}:Q',
524-
y2=f'{c.CI_LO}:Q',
525-
color=alt.Color(
526-
'type:N',
527-
scale=alt.Scale(domain=[domain[0]], range=[colors[0]]),
528-
legend=None,
529-
),
530-
)
557+
confidence_band = base.mark_area(opacity=0.3).encode(
558+
x=f'{c.TIME}:T',
559+
y=f'{c.CI_HI}:Q',
560+
y2=f'{c.CI_LO}:Q',
561+
color=alt.Color(
562+
'type:N',
563+
scale=alt.Scale(domain=[domain[0]], range=[colors[0]]),
564+
legend=None,
565+
),
531566
)
532567
plot = (plot + confidence_band).resolve_scale(color=c.INDEPENDENT)
533568

meridian/analysis/visualizer_test.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -703,16 +703,16 @@ def test_model_fit_plots_expected_ci(self):
703703

704704
self.assertIsInstance(plot, alt.LayerChart)
705705
self.assertEqual(
706-
plot.layer[1].encoding.color["scale"]["domain"], [c.EXPECTED]
706+
plot.layer[3].encoding.color["scale"]["domain"], [c.EXPECTED]
707707
)
708-
self.assertEqual(plot.layer[1].encoding.y.shorthand, f"{c.CI_HI}:Q")
709-
self.assertEqual(plot.layer[1].encoding.y2.shorthand, f"{c.CI_LO}:Q")
708+
self.assertEqual(plot.layer[3].encoding.y.shorthand, f"{c.CI_HI}:Q")
709+
self.assertEqual(plot.layer[3].encoding.y2.shorthand, f"{c.CI_LO}:Q")
710710

711711
def test_model_fit_plots_no_ci(self):
712712
plot = self.model_fit_kpi_type_revenue.plot_model_fit(include_ci=False)
713-
self.assertIsInstance(plot, alt.Chart)
714-
self.assertEqual(plot.encoding.x.shorthand, f"{c.TIME}:T")
715-
self.assertEqual(plot.encoding.y.shorthand, f"{c.MEAN}:Q")
713+
self.assertIsInstance(plot, alt.LayerChart)
714+
self.assertEqual(plot.layer[0].encoding.x.shorthand, f"{c.TIME}:T")
715+
self.assertEqual(plot.layer[0].encoding.y.shorthand, f"{c.MEAN}:Q")
716716

717717
def test_model_fit_axis_encoding(self):
718718
plot = self.model_fit_kpi_type_revenue.plot_model_fit()
@@ -737,6 +737,43 @@ def test_model_fit_axis_encoding(self):
737737
| formatter.Y_AXIS_TITLE_CONFIG,
738738
)
739739

740+
def test_model_fit_tooltip_encoding(self):
741+
plot = self.model_fit_kpi_type_revenue.plot_model_fit()
742+
tooltip_layer = plot.layer[1]
743+
self.assertEqual(tooltip_layer.mark.type, "rule")
744+
self.assertEqual(tooltip_layer.encoding.x.shorthand, f"{c.TIME}:T")
745+
self.assertFalse(tooltip_layer.encoding.opacity["condition"].empty)
746+
expected_tooltip = [
747+
alt.Tooltip(f"{c.TIME}:T", title="Time period"),
748+
alt.Tooltip(f"{c.EXPECTED}:Q", title="Expected", format=",.0f"),
749+
alt.Tooltip(f"{c.ACTUAL}:Q", title="Actual", format=",.0f"),
750+
alt.Tooltip(f"{c.BASELINE}:Q", title="Baseline", format=",.0f"),
751+
]
752+
self.assertEqual(tooltip_layer.encoding.tooltip, expected_tooltip)
753+
754+
def test_model_fit_points_encoding(self):
755+
plot = self.model_fit_kpi_type_revenue.plot_model_fit()
756+
points_layer = plot.layer[2]
757+
758+
with self.subTest("test_mark_properties"):
759+
self.assertEqual(points_layer.mark.type, "circle")
760+
self.assertFalse(points_layer.mark.filled)
761+
self.assertEqual(points_layer.mark.size, c.MARK_CIRCLE_SIZE)
762+
763+
with self.subTest("test_encoding_shorthands"):
764+
self.assertEqual(points_layer.encoding.x.shorthand, f"{c.TIME}:T")
765+
self.assertEqual(points_layer.encoding.y.shorthand, f"{c.MEAN}:Q")
766+
767+
with self.subTest("test_color_legend"):
768+
self.assertIsNone(points_layer.encoding.color["legend"])
769+
770+
def test_model_fit_selection_points(self):
771+
plot = self.model_fit_kpi_type_revenue.plot_model_fit()
772+
selection_points = plot.params[0].select
773+
self.assertTrue(selection_points["nearest"])
774+
self.assertEqual(selection_points["on"], "pointerover")
775+
self.assertEqual(selection_points["clear"], "pointerout")
776+
740777
def test_model_fit_correct_config(self):
741778
plot = self.model_fit_kpi_type_revenue.plot_model_fit()
742779
self.assertEqual(plot.config.axis.to_dict(), formatter.TEXT_CONFIG)

meridian/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,7 @@
758758
INDEPENDENT = 'independent'
759759
RESPONSE_CURVE_STEP_SIZE = 0.01
760760
OUTLIER_CLIP_FACTOR = 1.2
761+
MARK_CIRCLE_SIZE = 36
761762

762763

763764
# Font names.

0 commit comments

Comments
 (0)