Skip to content

Commit 5c31d5b

Browse files
Merge pull request #7 from YuminosukeSato/feat/dynamic-regression
feat: add dynamic regression (time-varying coefficients)
2 parents a0a912a + 2d9cb28 commit 5c31d5b

10 files changed

Lines changed: 1020 additions & 29 deletions

File tree

python/causal_impact/data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,7 @@ def _parse_periods(
185185
return pre_start, pre_end, post_start, post_end
186186

187187
@staticmethod
188-
def _validate_periods(
189-
pre_start, pre_end, post_start, post_end, time_index
190-
) -> None:
188+
def _validate_periods(pre_start, pre_end, post_start, post_end, time_index) -> None:
191189
idx_min = time_index.min()
192190
idx_max = time_index.max()
193191

python/causal_impact/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"standardize_data": True,
2626
"prior_level_sd": 0.01,
2727
"expected_model_size": 2,
28+
"dynamic_regression": False,
2829
"nseasons": None,
2930
"season_duration": None,
3031
}
@@ -100,6 +101,7 @@ def _run_sampler(self, prepared: PreparedData, args: dict):
100101
expected_model_size=float(args["expected_model_size"]),
101102
nseasons=args["nseasons"],
102103
season_duration=args["season_duration"],
104+
dynamic_regression=bool(args.get("dynamic_regression", False)),
103105
)
104106

105107
def _compute_results(self, prepared: PreparedData, samples) -> CausalImpactResults:

python/causal_impact/options.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class ModelOptions:
1919
standardize_data: bool = True
2020
prior_level_sd: float = 0.01
2121
expected_model_size: int = 1
22+
dynamic_regression: bool = False
2223
nseasons: int | None = None
2324
season_duration: int | None = None
2425

@@ -38,6 +39,12 @@ def __post_init__(self) -> None:
3839
if self.expected_model_size <= 0:
3940
msg = f"expected_model_size must be > 0, got {self.expected_model_size}"
4041
raise ValueError(msg)
42+
if not isinstance(self.dynamic_regression, bool):
43+
msg = (
44+
"dynamic_regression must be a bool, "
45+
f"got {type(self.dynamic_regression).__name__}"
46+
)
47+
raise ValueError(msg)
4148
if self.nseasons is not None:
4249
if not isinstance(self.nseasons, int):
4350
msg = f"nseasons must be an integer, got {self.nseasons}"

python/causal_impact/plot.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,18 @@ def plot(
5757
def _plot_original(ax, y, time_index, post_index, results):
5858
ax.plot(time_index, y, color="black", linewidth=1, label="Observed")
5959
ax.plot(
60-
post_index, results.predictions_mean, color="blue",
61-
linestyle="--", label="Counterfactual",
60+
post_index,
61+
results.predictions_mean,
62+
color="blue",
63+
linestyle="--",
64+
label="Counterfactual",
6265
)
6366
ax.fill_between(
64-
post_index, results.predictions_lower, results.predictions_upper,
65-
alpha=0.2, color="blue",
67+
post_index,
68+
results.predictions_lower,
69+
results.predictions_upper,
70+
alpha=0.2,
71+
color="blue",
6672
)
6773
ax.set_ylabel("Response")
6874
ax.legend(loc="upper left", fontsize=8)

0 commit comments

Comments
 (0)