Skip to content

Commit a2afe7a

Browse files
authored
Merge pull request #27 from chiuhoward/add-covariate-regressor
[ENH] Add CovariateRegressor class
2 parents 0ea86cb + b4f8978 commit a2afe7a

File tree

3 files changed

+494
-0
lines changed

3 files changed

+494
-0
lines changed

LICENSE

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,27 @@ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
2929
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
3030
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
3131
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32+
33+
Code used in the covariate_regressor module is distributed with the following license:
34+
35+
MIT License
36+
37+
Copyright (c) 2017 Lukas Snoek
38+
39+
Permission is hereby granted, free of charge, to any person obtaining a copy
40+
of this software and associated documentation files (the "Software"), to deal
41+
in the Software without restriction, including without limitation the rights
42+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
43+
copies of the Software, and to permit persons to whom the Software is
44+
furnished to do so, subject to the following conditions:
45+
46+
The above copyright notice and this permission notice shall be included in all
47+
copies or substantial portions of the Software.
48+
49+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
50+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
51+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
52+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
53+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
54+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
55+
SOFTWARE.

afqinsight/covariate_regressor.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import numpy as np
2+
from scipy.linalg import lstsq
3+
from sklearn.base import BaseEstimator, TransformerMixin
4+
from sklearn.impute import SimpleImputer
5+
6+
7+
def find_subset_indices(X_full, X_subset, method="hash", allow_missing=False):
8+
"""
9+
Find row indices in X_full that correspond to rows in X_subset.
10+
Supports 'hash' (fast) and 'precise' (element-wise) matching.
11+
Allow_missing appends empty array for non-matching rows if True.
12+
"""
13+
if X_full.shape[1] != X_subset.shape[1]:
14+
raise ValueError(
15+
f"Feature dimensions don't match: {X_full.shape[1]} vs {X_subset.shape[1]}"
16+
)
17+
indices = []
18+
if method == "precise":
19+
for i, subset_row in enumerate(X_subset):
20+
matches = [
21+
j
22+
for j, full_row in enumerate(X_full)
23+
if np.array_equal(full_row, subset_row, equal_nan=True)
24+
]
25+
if not matches and not allow_missing:
26+
raise ValueError(f"No matching row found for subset row {i}")
27+
indices.append(matches[0] if matches else [])
28+
elif method == "hash":
29+
full_hashes = [hash(row.tobytes()) for row in X_full]
30+
for i, subset_row in enumerate(X_subset):
31+
subset_hash = hash(subset_row.tobytes())
32+
try:
33+
indices.append(full_hashes.index(subset_hash))
34+
except ValueError as e:
35+
if allow_missing:
36+
indices.append([])
37+
else:
38+
raise ValueError(f"No matching row found for subset row {i}") from e
39+
else:
40+
raise ValueError(f"Unknown method '{method}'. Use 'hash' or 'precise'.")
41+
return np.array(indices)
42+
43+
44+
class CovariateRegressor(BaseEstimator, TransformerMixin):
45+
"""
46+
Fits covariate(s) onto each feature in X and returns their residuals.
47+
"""
48+
49+
def __init__(
50+
self,
51+
covariate,
52+
X_full,
53+
pipeline=None,
54+
cross_validate=True,
55+
precise=False,
56+
unique_id_col_index=None,
57+
stack_intercept=True,
58+
):
59+
"""Regresses out a variable (covariate) from each feature in X.
60+
61+
Parameters
62+
----------
63+
covariate : numpy array
64+
Array of length (n_samples, n_covariates) to regress out of each
65+
feature; May have multiple columns for multiple covariates.
66+
X_full : numpy array
67+
Array of length (n_samples, n_features), from which the covariate
68+
will be regressed. This is used to determine how the
69+
covariate-models should be cross-validated (which is necessary
70+
to use in in scikit-learn Pipelines).
71+
pipeline : sklearn.pipeline.Pipeline or None, default=None
72+
Optional scikit-learn pipeline to apply to the covariate before fitting
73+
the regression model. If provided, the pipeline will be fitted on the
74+
covariate data during the fit phase and applied to transform the covariate
75+
in both fit and transform phases. This allows for preprocessing steps
76+
such as imputation, scaling, normalization, or feature engineering to be
77+
applied to the covariate consistently across train and test sets. If None,
78+
the covariate is used as-is without any preprocessing.
79+
cross_validate : bool
80+
Whether to cross-validate the covariate-parameters (y~covariate)
81+
estimated from the train-set to the test set (cross_validate=True)
82+
or whether to fit the covariate regressor separately on the test-set
83+
(cross_validate=False).
84+
precise: bool
85+
When setting precise to True, the arrays are compared feature-wise,
86+
which is accurate, but relatively slow. When setting precise to False,
87+
it will infer the index of the covariates by looking at the hash of all
88+
the features, which is much faster. Also, to aid the accuracy, we remove
89+
the features which are constant (0) across samples.
90+
stack_intercept : bool
91+
Whether to stack an intercept to the covariate (default is True)
92+
93+
Attributes
94+
----------
95+
weights_ : numpy array
96+
Array with weights for the covariate(s).
97+
98+
Notes
99+
-----
100+
This is a modified version of the ConfoundRegressor from [1]_. Setting
101+
cross_validate to True is equivalent to "foldwise covariate regression" (FwCR)
102+
as described in Snoek et al. (2019). Setting this parameter to False, however,
103+
is NOT equivalent to "whole dataset covariate regression" (WDCR) as it does not
104+
apply covariate regression to the *full* dataset, but simply refits the
105+
covariate model on the test-set. We recommend setting this parameter to True.
106+
Transformer-objects in scikit-learn only allow to pass the data (X) and
107+
optionally the target (y) to the fit and transform methods. However, we need
108+
to index the covariate accordingly as well. To do so, we compare the X during
109+
initialization (self.X_full) with the X passed to fit/transform. As such, we can
110+
infer which samples are passed to the methods and index the covariate
111+
accordingly. The precise flag controls the precision of the index matching.
112+
113+
References
114+
----------
115+
.. [1] Lukas Snoek, Steven Miletić, H. Steven Scholte,
116+
"How to control for confounds in decoding analyses of neuroimaging data",
117+
NeuroImage, Volume 184, 2019, Pages 741-760, ISSN 1053-8119,
118+
https://doi.org/10.1016/j.neuroimage.2018.09.074.
119+
"""
120+
self.covariate = covariate.astype(np.float64)
121+
self.cross_validate = cross_validate
122+
self.X_full = X_full
123+
self.precise = precise
124+
self.stack_intercept = stack_intercept
125+
self.weights_ = None
126+
self.pipeline = pipeline
127+
self.imputer = SimpleImputer(strategy="median")
128+
self.X_imputer = SimpleImputer(strategy="median")
129+
self.unique_id_col_index = unique_id_col_index
130+
131+
def _prepare_covariate(self, covariate):
132+
"""Prepare covariate matrix (adds intercept if needed)"""
133+
if self.stack_intercept:
134+
return np.c_[np.ones((covariate.shape[0], 1)), covariate]
135+
return covariate
136+
137+
def fit(self, X, y=None):
138+
"""Fits the covariate-regressor to X.
139+
140+
Parameters
141+
----------
142+
X : numpy array
143+
An array of shape (n_samples, n_features), which should correspond
144+
to your train-set only!
145+
y : None
146+
Included for compatibility; does nothing.
147+
"""
148+
149+
# Prepare covariate matrix (adds intercept if needed)
150+
covariate = self._prepare_covariate(self.covariate)
151+
152+
# Find indices of X subset in the original X
153+
method = "precise" if self.precise else "hash"
154+
fit_idx = find_subset_indices(self.X_full, X, method=method)
155+
156+
# Remove unique ID column if specified
157+
if self.unique_id_col_index is not None:
158+
X = np.delete(X, self.unique_id_col_index, axis=1)
159+
160+
# Extract covariate data for the fitting subset
161+
covariate_fit = covariate[fit_idx, :]
162+
163+
# Conditional imputation for covariate data
164+
if np.isnan(covariate_fit).any():
165+
covariate_fit = self.imputer.fit_transform(covariate_fit)
166+
else:
167+
# Still fit the imputer for consistency in transform
168+
self.imputer.fit(covariate_fit)
169+
170+
# Apply pipeline transformation if specified
171+
if self.pipeline is not None:
172+
X = self.pipeline.fit_transform(X)
173+
174+
# Conditional imputation for X
175+
if np.isnan(X).any():
176+
X = self.X_imputer.fit_transform(X)
177+
else:
178+
# Still fit the imputer for consistency in transform
179+
self.X_imputer.fit(X)
180+
181+
# Fit linear regression: X = covariate * weights + residuals
182+
# Using scipy's lstsq for numerical stability
183+
self.weights_ = lstsq(covariate_fit, X)[0]
184+
185+
return self
186+
187+
def transform(self, X):
188+
"""Regresses out covariate from X.
189+
190+
Parameters
191+
----------
192+
X : numpy array
193+
An array of shape (n_samples, n_features), which should correspond
194+
to your train-set only!
195+
196+
Returns
197+
-------
198+
X_new : ndarray
199+
ndarray with covariate-regressed features
200+
"""
201+
202+
if not self.cross_validate:
203+
self.fit(X)
204+
205+
# Prepare covariate matrix (adds intercept if needed)
206+
covariate = self._prepare_covariate(self.covariate)
207+
208+
# Find indices of X subset in the original X
209+
method = "precise" if self.precise else "hash"
210+
transform_idx = find_subset_indices(self.X_full, X, method=method)
211+
212+
# Remove unique ID column if specified
213+
if self.unique_id_col_index is not None:
214+
X = np.delete(X, self.unique_id_col_index, axis=1)
215+
216+
# Extract covariate data for the transform subset
217+
covariate_transform = covariate[transform_idx]
218+
219+
# Conditional imputation for covariate data (use fitted imputer)
220+
if np.isnan(covariate_transform).any():
221+
covariate_transform = self.imputer.transform(covariate_transform)
222+
223+
# Apply pipeline transformation if specified
224+
if self.pipeline is not None:
225+
X = self.pipeline.transform(X)
226+
227+
# Conditional imputation for X (use fitted imputer)
228+
if np.isnan(X).any():
229+
X = self.X_imputer.transform(X)
230+
231+
# Compute residuals
232+
X_new = X - covariate_transform.dot(self.weights_)
233+
234+
# Ensure no NaNs in output
235+
X_new = np.nan_to_num(X_new)
236+
237+
return X_new

0 commit comments

Comments
 (0)