@@ -12,7 +12,12 @@ class LinearMediation:
1212 https://gist.github.com/apoorvalal/e7dc9f3e52dcd9d51854b28b3e8a7ba4.
1313 """
1414
15- def __init__ (self ):
15+ def __init__ (self , agg , param , coefnames ):
16+ self .param = param
17+ self .coefnames = coefnames
18+ # Get the names of the mediator variables
19+ self .mediator_names = [name for name in coefnames if param not in name ]
20+ self .agg = agg
1621 pass
1722
1823 def fit (self , X , W , y , store = True ):
@@ -34,18 +39,27 @@ def fit(self, X, W, y, store=True):
3439 self .beta_tilde = np .linalg .lstsq (X , y , rcond = 1 )[0 ]
3540 self .delta_tilde = np .linalg .lstsq (X , W , rcond = 1 )[0 ]
3641 self .gamma_tilde = np .linalg .lstsq (W , y , rcond = 1 )[0 ]
37- self .total_effect , self .mediated_effect = (
38- self .beta_tilde ,
39- self .delta_tilde @ self .gamma_tilde ,
42+ self .total_effect = self .beta_tilde .flatten ()
43+ self .mediated_effect = (
44+ (self .delta_tilde @ self .gamma_tilde ).flatten ()
45+ if self .agg
46+ else self .delta_tilde .flatten () * self .gamma_tilde .flatten ()
4047 )
41- self .direct_effect = self .total_effect - self .mediated_effect
48+ self .direct_effect = self .total_effect - np . sum ( self .mediated_effect )
4249 else :
4350 beta_tilde = np .linalg .lstsq (X , y , rcond = 1 )[0 ]
4451 delta_tilde = np .linalg .lstsq (X , W , rcond = 1 )[0 ]
4552 gamma_tilde = np .linalg .lstsq (W , y , rcond = 1 )[0 ]
46- total_effect , mediated_effect = beta_tilde , delta_tilde @ gamma_tilde
47- direct_effect = total_effect - mediated_effect
48- return np .c_ [total_effect , mediated_effect , direct_effect ].flatten ()
53+ total_effect = beta_tilde .flatten ()
54+ mediated_effect = (
55+ (delta_tilde @ gamma_tilde ).flatten ()
56+ if self .agg
57+ else delta_tilde .flatten () * gamma_tilde .flatten ()
58+ )
59+ direct_effect = total_effect - np .sum (mediated_effect )
60+ return np .concatenate (
61+ [total_effect , mediated_effect , direct_effect ]
62+ ).flatten ()
4963
5064 def bootstrap (self , rng , B = 1_000 , alpha = 0.05 ):
5165 "Bootstrap Confidence Intervals for Total, Mediated and Direct Effects."
@@ -62,11 +76,15 @@ def bootstrap(self, rng, B=1_000, alpha=0.05):
6276
6377 def summary (self ):
6478 "Summary Table for Total, Mediated and Direct Effects."
65- effects = np .c_ [self .total_effect , self .mediated_effect , self .direct_effect ]
66- summary_arr = np .concatenate ([effects , self .ci ], axis = 0 )
79+ effects = np .concatenate (
80+ [self .total_effect , self .mediated_effect , self .direct_effect ], axis = 0
81+ )
82+ summary_arr = np .concatenate ([effects .reshape (1 , - 1 ), self .ci ], axis = 0 )
6783 self .summary_table = pd .DataFrame (
6884 summary_arr ,
69- columns = ["Total Effect" , "Mediated Effect" , "Direct Effect" ],
85+ columns = ["Total Effect:" ]
86+ + [f"Mediated Effect: { var } " for var in self .mediator_names ]
87+ + [f"Direct Effect: { self .param } " ],
7088 index = [
7189 "Estimate" ,
7290 f"CI Lower ({ self .alpha / 2 } )" ,
0 commit comments