Skip to content

Commit 104d375

Browse files
riemanliThe Meridian Authors
authored andcommitted
Add has_constant='add' to sm.add_constant in VIF computation.
PiperOrigin-RevId: 837273955
1 parent 82f910b commit 104d375

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Fixed an out-of-bounds bug in EDA's VIF check.
27+
* Added cost per media unit checks to EDA.
2628
* Add support for holdout set in `GoodnessOfFitCheck`.
2729
* Introduce modules needed for Meridian Scenario Planner and add
2830
`scenarioplanner` extra.

meridian/model/eda/eda_engine.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ def _calculate_vif(input_da: xr.DataArray, var_dim: str) -> xr.DataArray:
361361
"""
362362
num_vars = input_da.sizes[var_dim]
363363
np_data = input_da.values.reshape(-1, num_vars)
364-
np_data_with_const = sm.add_constant(np_data, prepend=True)
364+
np_data_with_const = sm.add_constant(
365+
np_data, prepend=True, has_constant='add'
366+
)
365367

366368
# Compute VIF for each variable excluding const which is the first one in the
367369
# 'variable' dimension.
@@ -2094,7 +2096,9 @@ def run_all_critical_checks(self) -> list[eda_outcome.EDAOutcome]:
20942096
except Exception as e: # pylint: disable=broad-except
20952097
error_finding = eda_outcome.EDAFinding(
20962098
severity=eda_outcome.EDASeverity.ERROR,
2097-
explanation=f'An error occurred during check {check.__name__}: {e}',
2099+
explanation=(
2100+
f'An error occurred during running {check.__name__}: {e!r}'
2101+
),
20982102
)
20992103
outcomes.append(
21002104
eda_outcome.EDAOutcome(

meridian/model/eda/eda_engine_test.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5610,6 +5610,37 @@ def test_check_national_vif_has_correct_vif_value(self):
56105610
national_artifact.vif_da.values, expected_national_vif
56115611
)
56125612

5613+
def test_check_vif_with_constant_variable(self):
5614+
meridian = model.Meridian(self.national_input_data_media_and_rf)
5615+
engine = eda_engine.EDAEngine(meridian)
5616+
shape = (_N_TIMES_VIF,)
5617+
v1 = _RNG.random(shape)
5618+
v2 = np.ones(shape)
5619+
v3 = _RNG.random(shape)
5620+
data_np = np.stack([v1, v2, v3], axis=-1)
5621+
data = (
5622+
_create_data_array_with_var_dim(data_np, "VIF", "var")
5623+
.rename({"var_dim": eda_engine._STACK_VAR_COORD_NAME})
5624+
.assign_coords(
5625+
{eda_engine._STACK_VAR_COORD_NAME: ["var_1", "var_2", "var_3"]}
5626+
)
5627+
)
5628+
self._mock_eda_engine_property(
5629+
"_stacked_national_treatment_control_scaled_da", data
5630+
)
5631+
5632+
outcome = engine.check_national_vif()
5633+
5634+
self.assertIsInstance(outcome, eda_outcome.EDAOutcome)
5635+
self.assertEqual(
5636+
outcome.check_type, eda_outcome.EDACheckType.MULTICOLLINEARITY
5637+
)
5638+
self.assertLen(outcome.analysis_artifacts, 1)
5639+
5640+
national_artifact = outcome.analysis_artifacts[0]
5641+
self.assertIsInstance(national_artifact, eda_outcome.VIFArtifact)
5642+
self.assertEqual(national_artifact.vif_da.sel(var="var_2"), 0)
5643+
56135644
@parameterized.named_parameters(
56145645
dict(
56155646
testcase_name="national_model",
@@ -6312,7 +6343,7 @@ def test_run_all_critical_checks_with_exception(self):
63126343
outcomes[1].findings[0].severity, eda_outcome.EDASeverity.ERROR
63136344
)
63146345
self.assertIn(
6315-
"An error occurred during check check_vif: Test Error",
6346+
"An error occurred during running check_vif: ValueError('Test Error')",
63166347
outcomes[1].findings[0].explanation,
63176348
)
63186349

@@ -6325,7 +6356,8 @@ def test_run_all_critical_checks_with_exception(self):
63256356
outcomes[2].findings[0].severity, eda_outcome.EDASeverity.ERROR
63266357
)
63276358
self.assertIn(
6328-
"An error occurred during check check_pairwise_corr: Another Error",
6359+
"An error occurred during running check_pairwise_corr:"
6360+
" TypeError('Another Error')",
63296361
outcomes[2].findings[0].explanation,
63306362
)
63316363

0 commit comments

Comments
 (0)