|
| 1 | +import numpy as np |
| 2 | +from scipy.linalg import lstsq |
| 3 | +from sklearn.base import BaseEstimator, TransformerMixin |
| 4 | +from sklearn.impute import SimpleImputer |
| 5 | + |
| 6 | + |
| 7 | +def find_subset_indices(X_full, X_subset, method="hash", allow_missing=False): |
| 8 | + """ |
| 9 | + Find row indices in X_full that correspond to rows in X_subset. |
| 10 | + Supports 'hash' (fast) and 'precise' (element-wise) matching. |
| 11 | + Allow_missing appends empty array for non-matching rows if True. |
| 12 | + """ |
| 13 | + if X_full.shape[1] != X_subset.shape[1]: |
| 14 | + raise ValueError( |
| 15 | + f"Feature dimensions don't match: {X_full.shape[1]} vs {X_subset.shape[1]}" |
| 16 | + ) |
| 17 | + indices = [] |
| 18 | + if method == "precise": |
| 19 | + for i, subset_row in enumerate(X_subset): |
| 20 | + matches = [ |
| 21 | + j |
| 22 | + for j, full_row in enumerate(X_full) |
| 23 | + if np.array_equal(full_row, subset_row, equal_nan=True) |
| 24 | + ] |
| 25 | + if not matches and not allow_missing: |
| 26 | + raise ValueError(f"No matching row found for subset row {i}") |
| 27 | + indices.append(matches[0] if matches else []) |
| 28 | + elif method == "hash": |
| 29 | + full_hashes = [hash(row.tobytes()) for row in X_full] |
| 30 | + for i, subset_row in enumerate(X_subset): |
| 31 | + subset_hash = hash(subset_row.tobytes()) |
| 32 | + try: |
| 33 | + indices.append(full_hashes.index(subset_hash)) |
| 34 | + except ValueError as e: |
| 35 | + if allow_missing: |
| 36 | + indices.append([]) |
| 37 | + else: |
| 38 | + raise ValueError(f"No matching row found for subset row {i}") from e |
| 39 | + else: |
| 40 | + raise ValueError(f"Unknown method '{method}'. Use 'hash' or 'precise'.") |
| 41 | + return np.array(indices) |
| 42 | + |
| 43 | + |
| 44 | +class CovariateRegressor(BaseEstimator, TransformerMixin): |
| 45 | + """ |
| 46 | + Fits covariate(s) onto each feature in X and returns their residuals. |
| 47 | + """ |
| 48 | + |
| 49 | + def __init__( |
| 50 | + self, |
| 51 | + covariate, |
| 52 | + X_full, |
| 53 | + pipeline=None, |
| 54 | + cross_validate=True, |
| 55 | + precise=False, |
| 56 | + unique_id_col_index=None, |
| 57 | + stack_intercept=True, |
| 58 | + ): |
| 59 | + """Regresses out a variable (covariate) from each feature in X. |
| 60 | +
|
| 61 | + Parameters |
| 62 | + ---------- |
| 63 | + covariate : numpy array |
| 64 | + Array of length (n_samples, n_covariates) to regress out of each |
| 65 | + feature; May have multiple columns for multiple covariates. |
| 66 | + X_full : numpy array |
| 67 | + Array of length (n_samples, n_features), from which the covariate |
| 68 | + will be regressed. This is used to determine how the |
| 69 | + covariate-models should be cross-validated (which is necessary |
| 70 | + to use in in scikit-learn Pipelines). |
| 71 | + pipeline : sklearn.pipeline.Pipeline or None, default=None |
| 72 | + Optional scikit-learn pipeline to apply to the covariate before fitting |
| 73 | + the regression model. If provided, the pipeline will be fitted on the |
| 74 | + covariate data during the fit phase and applied to transform the covariate |
| 75 | + in both fit and transform phases. This allows for preprocessing steps |
| 76 | + such as imputation, scaling, normalization, or feature engineering to be |
| 77 | + applied to the covariate consistently across train and test sets. If None, |
| 78 | + the covariate is used as-is without any preprocessing. |
| 79 | + cross_validate : bool |
| 80 | + Whether to cross-validate the covariate-parameters (y~covariate) |
| 81 | + estimated from the train-set to the test set (cross_validate=True) |
| 82 | + or whether to fit the covariate regressor separately on the test-set |
| 83 | + (cross_validate=False). |
| 84 | + precise: bool |
| 85 | + When setting precise to True, the arrays are compared feature-wise, |
| 86 | + which is accurate, but relatively slow. When setting precise to False, |
| 87 | + it will infer the index of the covariates by looking at the hash of all |
| 88 | + the features, which is much faster. Also, to aid the accuracy, we remove |
| 89 | + the features which are constant (0) across samples. |
| 90 | + stack_intercept : bool |
| 91 | + Whether to stack an intercept to the covariate (default is True) |
| 92 | +
|
| 93 | + Attributes |
| 94 | + ---------- |
| 95 | + weights_ : numpy array |
| 96 | + Array with weights for the covariate(s). |
| 97 | +
|
| 98 | + Notes |
| 99 | + ----- |
| 100 | + This is a modified version of the ConfoundRegressor from [1]_. Setting |
| 101 | + cross_validate to True is equivalent to "foldwise covariate regression" (FwCR) |
| 102 | + as described in Snoek et al. (2019). Setting this parameter to False, however, |
| 103 | + is NOT equivalent to "whole dataset covariate regression" (WDCR) as it does not |
| 104 | + apply covariate regression to the *full* dataset, but simply refits the |
| 105 | + covariate model on the test-set. We recommend setting this parameter to True. |
| 106 | + Transformer-objects in scikit-learn only allow to pass the data (X) and |
| 107 | + optionally the target (y) to the fit and transform methods. However, we need |
| 108 | + to index the covariate accordingly as well. To do so, we compare the X during |
| 109 | + initialization (self.X_full) with the X passed to fit/transform. As such, we can |
| 110 | + infer which samples are passed to the methods and index the covariate |
| 111 | + accordingly. The precise flag controls the precision of the index matching. |
| 112 | +
|
| 113 | + References |
| 114 | + ---------- |
| 115 | + .. [1] Lukas Snoek, Steven Miletić, H. Steven Scholte, |
| 116 | + "How to control for confounds in decoding analyses of neuroimaging data", |
| 117 | + NeuroImage, Volume 184, 2019, Pages 741-760, ISSN 1053-8119, |
| 118 | + https://doi.org/10.1016/j.neuroimage.2018.09.074. |
| 119 | + """ |
| 120 | + self.covariate = covariate.astype(np.float64) |
| 121 | + self.cross_validate = cross_validate |
| 122 | + self.X_full = X_full |
| 123 | + self.precise = precise |
| 124 | + self.stack_intercept = stack_intercept |
| 125 | + self.weights_ = None |
| 126 | + self.pipeline = pipeline |
| 127 | + self.imputer = SimpleImputer(strategy="median") |
| 128 | + self.X_imputer = SimpleImputer(strategy="median") |
| 129 | + self.unique_id_col_index = unique_id_col_index |
| 130 | + |
| 131 | + def _prepare_covariate(self, covariate): |
| 132 | + """Prepare covariate matrix (adds intercept if needed)""" |
| 133 | + if self.stack_intercept: |
| 134 | + return np.c_[np.ones((covariate.shape[0], 1)), covariate] |
| 135 | + return covariate |
| 136 | + |
| 137 | + def fit(self, X, y=None): |
| 138 | + """Fits the covariate-regressor to X. |
| 139 | +
|
| 140 | + Parameters |
| 141 | + ---------- |
| 142 | + X : numpy array |
| 143 | + An array of shape (n_samples, n_features), which should correspond |
| 144 | + to your train-set only! |
| 145 | + y : None |
| 146 | + Included for compatibility; does nothing. |
| 147 | + """ |
| 148 | + |
| 149 | + # Prepare covariate matrix (adds intercept if needed) |
| 150 | + covariate = self._prepare_covariate(self.covariate) |
| 151 | + |
| 152 | + # Find indices of X subset in the original X |
| 153 | + method = "precise" if self.precise else "hash" |
| 154 | + fit_idx = find_subset_indices(self.X_full, X, method=method) |
| 155 | + |
| 156 | + # Remove unique ID column if specified |
| 157 | + if self.unique_id_col_index is not None: |
| 158 | + X = np.delete(X, self.unique_id_col_index, axis=1) |
| 159 | + |
| 160 | + # Extract covariate data for the fitting subset |
| 161 | + covariate_fit = covariate[fit_idx, :] |
| 162 | + |
| 163 | + # Conditional imputation for covariate data |
| 164 | + if np.isnan(covariate_fit).any(): |
| 165 | + covariate_fit = self.imputer.fit_transform(covariate_fit) |
| 166 | + else: |
| 167 | + # Still fit the imputer for consistency in transform |
| 168 | + self.imputer.fit(covariate_fit) |
| 169 | + |
| 170 | + # Apply pipeline transformation if specified |
| 171 | + if self.pipeline is not None: |
| 172 | + X = self.pipeline.fit_transform(X) |
| 173 | + |
| 174 | + # Conditional imputation for X |
| 175 | + if np.isnan(X).any(): |
| 176 | + X = self.X_imputer.fit_transform(X) |
| 177 | + else: |
| 178 | + # Still fit the imputer for consistency in transform |
| 179 | + self.X_imputer.fit(X) |
| 180 | + |
| 181 | + # Fit linear regression: X = covariate * weights + residuals |
| 182 | + # Using scipy's lstsq for numerical stability |
| 183 | + self.weights_ = lstsq(covariate_fit, X)[0] |
| 184 | + |
| 185 | + return self |
| 186 | + |
| 187 | + def transform(self, X): |
| 188 | + """Regresses out covariate from X. |
| 189 | +
|
| 190 | + Parameters |
| 191 | + ---------- |
| 192 | + X : numpy array |
| 193 | + An array of shape (n_samples, n_features), which should correspond |
| 194 | + to your train-set only! |
| 195 | +
|
| 196 | + Returns |
| 197 | + ------- |
| 198 | + X_new : ndarray |
| 199 | + ndarray with covariate-regressed features |
| 200 | + """ |
| 201 | + |
| 202 | + if not self.cross_validate: |
| 203 | + self.fit(X) |
| 204 | + |
| 205 | + # Prepare covariate matrix (adds intercept if needed) |
| 206 | + covariate = self._prepare_covariate(self.covariate) |
| 207 | + |
| 208 | + # Find indices of X subset in the original X |
| 209 | + method = "precise" if self.precise else "hash" |
| 210 | + transform_idx = find_subset_indices(self.X_full, X, method=method) |
| 211 | + |
| 212 | + # Remove unique ID column if specified |
| 213 | + if self.unique_id_col_index is not None: |
| 214 | + X = np.delete(X, self.unique_id_col_index, axis=1) |
| 215 | + |
| 216 | + # Extract covariate data for the transform subset |
| 217 | + covariate_transform = covariate[transform_idx] |
| 218 | + |
| 219 | + # Conditional imputation for covariate data (use fitted imputer) |
| 220 | + if np.isnan(covariate_transform).any(): |
| 221 | + covariate_transform = self.imputer.transform(covariate_transform) |
| 222 | + |
| 223 | + # Apply pipeline transformation if specified |
| 224 | + if self.pipeline is not None: |
| 225 | + X = self.pipeline.transform(X) |
| 226 | + |
| 227 | + # Conditional imputation for X (use fitted imputer) |
| 228 | + if np.isnan(X).any(): |
| 229 | + X = self.X_imputer.transform(X) |
| 230 | + |
| 231 | + # Compute residuals |
| 232 | + X_new = X - covariate_transform.dot(self.weights_) |
| 233 | + |
| 234 | + # Ensure no NaNs in output |
| 235 | + X_new = np.nan_to_num(X_new) |
| 236 | + |
| 237 | + return X_new |
0 commit comments