Skip to content

Commit 1c09408

Browse files
committed
add agg argument to suppres aggregation of mediators
1 parent 6355526 commit 1c09408

File tree

2 files changed

+37
-16
lines changed

2 files changed

+37
-16
lines changed

pyfixest/estimation/decomposition.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@ class LinearMediation:
1212
https://gist.github.com/apoorvalal/e7dc9f3e52dcd9d51854b28b3e8a7ba4.
1313
"""
1414

15-
def __init__(self):
15+
def __init__(self, agg, param, coefnames):
16+
self.param = param
17+
self.coefnames = coefnames
18+
# Get the names of the mediator variables
19+
self.mediator_names = [name for name in coefnames if param not in name]
20+
self.agg = agg
1621
pass
1722

1823
def fit(self, X, W, y, store=True):
@@ -34,18 +39,27 @@ def fit(self, X, W, y, store=True):
3439
self.beta_tilde = np.linalg.lstsq(X, y, rcond=1)[0]
3540
self.delta_tilde = np.linalg.lstsq(X, W, rcond=1)[0]
3641
self.gamma_tilde = np.linalg.lstsq(W, y, rcond=1)[0]
37-
self.total_effect, self.mediated_effect = (
38-
self.beta_tilde,
39-
self.delta_tilde @ self.gamma_tilde,
42+
self.total_effect = self.beta_tilde.flatten()
43+
self.mediated_effect = (
44+
(self.delta_tilde @ self.gamma_tilde).flatten()
45+
if self.agg
46+
else self.delta_tilde.flatten() * self.gamma_tilde.flatten()
4047
)
41-
self.direct_effect = self.total_effect - self.mediated_effect
48+
self.direct_effect = self.total_effect - np.sum(self.mediated_effect)
4249
else:
4350
beta_tilde = np.linalg.lstsq(X, y, rcond=1)[0]
4451
delta_tilde = np.linalg.lstsq(X, W, rcond=1)[0]
4552
gamma_tilde = np.linalg.lstsq(W, y, rcond=1)[0]
46-
total_effect, mediated_effect = beta_tilde, delta_tilde @ gamma_tilde
47-
direct_effect = total_effect - mediated_effect
48-
return np.c_[total_effect, mediated_effect, direct_effect].flatten()
53+
total_effect = beta_tilde.flatten()
54+
mediated_effect = (
55+
(delta_tilde @ gamma_tilde).flatten()
56+
if self.agg
57+
else delta_tilde.flatten() * gamma_tilde.flatten()
58+
)
59+
direct_effect = total_effect - np.sum(mediated_effect)
60+
return np.concatenate(
61+
[total_effect, mediated_effect, direct_effect]
62+
).flatten()
4963

5064
def bootstrap(self, rng, B=1_000, alpha=0.05):
5165
"Bootstrap Confidence Intervals for Total, Mediated and Direct Effects."
@@ -62,11 +76,15 @@ def bootstrap(self, rng, B=1_000, alpha=0.05):
6276

6377
def summary(self):
6478
"Summary Table for Total, Mediated and Direct Effects."
65-
effects = np.c_[self.total_effect, self.mediated_effect, self.direct_effect]
66-
summary_arr = np.concatenate([effects, self.ci], axis=0)
79+
effects = np.concatenate(
80+
[self.total_effect, self.mediated_effect, self.direct_effect], axis=0
81+
)
82+
summary_arr = np.concatenate([effects.reshape(1, -1), self.ci], axis=0)
6783
self.summary_table = pd.DataFrame(
6884
summary_arr,
69-
columns=["Total Effect", "Mediated Effect", "Direct Effect"],
85+
columns=["Total Effect:"]
86+
+ [f"Mediated Effect: {var}" for var in self.mediator_names]
87+
+ [f"Direct Effect: {self.param}"],
7088
index=[
7189
"Estimate",
7290
f"CI Lower ({self.alpha/2})",

pyfixest/estimation/feols_.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,7 +1474,7 @@ def ccv(
14741474
n_splits=n_splits,
14751475
)
14761476

1477-
def decompose(self, param, type="gelbach", reps=1000, seed=None):
1477+
def decompose(self, param, agg=True, type="gelbach", reps=1000, seed=None):
14781478
"""
14791479
Implement the Gelbach (2016) decomposition method for mediation analysis.
14801480
@@ -1494,14 +1494,17 @@ def decompose(self, param, type="gelbach", reps=1000, seed=None):
14941494
)
14951495

14961496
param_idx = self._coefnames.index(param)
1497-
X_demean = np.atleast_2d(self._X[:, param_idx]).T
1498-
W_demean = np.atleast_2d(self._X[:, ~param_idx]).T
1497+
mask = np.ones(self._X.shape[1], dtype=bool)
1498+
mask[param_idx] = False
1499+
1500+
X_demean = (self._X[:, ~param_idx]).reshape((self._N, np.sum(not mask)))
1501+
W_demean = (self._X[:, mask]).reshape((self._N, np.sum(mask)))
14991502
Y_demean = self._Y
15001503

15011504
if type == "gelbach":
1502-
med = LinearMediation()
1505+
med = LinearMediation(agg=agg, param=param, coefnames=self._coefnames)
15031506
med.fit(X=X_demean, W=W_demean, y=Y_demean)
1504-
med.bootstrap(rng=rng)
1507+
med.bootstrap(rng=rng, B=reps)
15051508
med.summary()
15061509

15071510
else:

0 commit comments

Comments
 (0)