-
Notifications
You must be signed in to change notification settings - Fork 43
Implement automatic ridge regression #124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
91aa8b8
7f8a19d
f4ade9b
acad4d5
49c189c
720edbd
5feb309
d85e80e
eccb46a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,13 @@ | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| import scipy | ||
| from scipy.linalg import sqrtm | ||
| from sklearn.linear_model import RidgeCV | ||
| from tqdm import tqdm | ||
| from mne import BaseEpochs | ||
|
|
||
| from mne.utils import logger, verbose | ||
| from mne.utils import logger, verbose, warn | ||
|
|
||
| from ..utils import fill_doc | ||
| from ..base import Connectivity, EpochConnectivity, EpochTemporalConnectivity | ||
|
|
@@ -13,7 +16,7 @@ | |
| @verbose | ||
| @fill_doc | ||
| def vector_auto_regression( | ||
| data, times=None, names=None, lags=1, l2_reg=0.0, | ||
| data, times=None, names=None, lags=1, l2_reg='auto', | ||
| compute_fb_operator=False, model='dynamic', n_jobs=1, verbose=None): | ||
| """Compute vector auto-regressive (VAR) model. | ||
|
|
||
|
|
@@ -29,8 +32,14 @@ def vector_auto_regression( | |
| %(names)s | ||
| lags : int, optional | ||
| Autoregressive model order, by default 1. | ||
| l2_reg : float, optional | ||
| Ridge penalty (l2-regularization) parameter, by default 0.0. | ||
| l2_reg : str | array-like, shape=(n_alphas,) | float | None, optional | ||
| Ridge penalty (l2-regularization) parameter, by default 'auto'. If | ||
| ``data`` has a condition number greater than 1e6, then automatic | ||
| regularization is performed using RidgeCV with a pre-defined array of | ||
| alphas: ``np.logspace(-15, 5, 11)``; otherwise no regularization is | ||
| applied. A user-defined array of alphas (must be positive floats) can | ||
| be passed to cross-validate over instead, or a single float to fix the | ||
| Ridge penalty (l2-regularization) parameter. If ``l2_reg`` is set to 0 | ||
| or None, then no regularization will be performed. | ||
| compute_fb_operator : bool | ||
| Whether to compute the backwards operator and average with | ||
| the forward operator. Addresses bias in the least-square | ||
|
|
@@ -151,9 +160,32 @@ def vector_auto_regression( | |
| # 1. determine shape of the window of data | ||
| n_epochs, n_nodes, _ = data.shape | ||
|
|
||
| cv_alphas = None | ||
| if isinstance(l2_reg, str) and l2_reg == 'auto': | ||
| # reset l2_reg for downstream functions | ||
| l2_reg = 0 | ||
| # determine condition of matrix across all epochs | ||
| conds = np.linalg.cond(data) | ||
| if np.any(conds > 1e6): | ||
| # matrix is ill-conditioned, so regularization must be used, with | ||
| # the alpha value chosen by cross-validation over a default grid | ||
| cv_alphas = np.logspace(-15, 5, 11) | ||
| warn('Input data matrix exceeds condition threshold of 1e6. ' | ||
| 'Automatic regularization will be performed.') | ||
| elif isinstance(l2_reg, (list, tuple, set, np.ndarray)): | ||
| cv_alphas = l2_reg | ||
| l2_reg = 0 | ||
|
|
||
| # cases where OLS is used | ||
| if (l2_reg in [0, None]) and (cv_alphas is None): | ||
| use_ridge = False | ||
| else: | ||
| use_ridge = True | ||
|
|
||
| model_params = { | ||
| 'lags': lags, | ||
| 'l2_reg': l2_reg, | ||
| 'use_ridge': use_ridge, | ||
| 'cv_alphas': cv_alphas | ||
| } | ||
|
|
||
| if verbose: | ||
|
|
@@ -165,12 +197,20 @@ def vector_auto_regression( | |
| # sample of the multivariate time-series of interest | ||
| # ordinary least squares or regularized least squares | ||
| # (ridge regression) | ||
| X, Y = _construct_var_eqns(data, **model_params) | ||
|
|
||
| b, res, rank, s = scipy.linalg.lstsq(X, Y) | ||
| X, Y = _construct_var_eqns(data, lags=lags, l2_reg=l2_reg) | ||
|
|
||
| # get the coefficients | ||
| coef = b.transpose() | ||
| if cv_alphas is not None: | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings( | ||
| action='ignore', | ||
| message="Ill-conditioned matrix" | ||
| ) | ||
|
Comment on lines
+204
to
+208
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need this?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. RidgeCV tests out an array of alpha values, and some of them do not regularize the matrix enough to avoid an ill-conditioned matrix warning. If the user sees many of these messages pop up, they may think that something is going wrong, when in fact the function is behaving as expected. RidgeCV will choose the best alpha value, and that will be from an instance when this warning was not raised. |
||
| reg = RidgeCV(alphas=cv_alphas, cv=5).fit(X, Y) | ||
| coef = reg.coef_ | ||
| else: | ||
| b, res, rank, s = scipy.linalg.lstsq(X, Y) | ||
| coef = b.transpose() | ||
|
|
||
| # create connectivity | ||
| coef = coef.flatten() | ||
|
|
@@ -187,8 +227,9 @@ def vector_auto_regression( | |
| # linear system | ||
| A_mats = _system_identification( | ||
| data=data, lags=lags, | ||
| l2_reg=l2_reg, n_jobs=n_jobs, | ||
| compute_fb_operator=compute_fb_operator) | ||
| l2_reg=l2_reg, cv_alphas=cv_alphas, | ||
| n_jobs=n_jobs, compute_fb_operator=compute_fb_operator | ||
| ) | ||
| # create connectivity | ||
| if lags > 1: | ||
| conn = EpochTemporalConnectivity(data=A_mats, | ||
|
|
@@ -261,7 +302,7 @@ def _construct_var_eqns(data, lags, l2_reg=None): | |
| X[:n, i * lags + k - | ||
| 1] = np.reshape(data[:, i, lags - k:-k].T, n) | ||
|
|
||
| if l2_reg is not None: | ||
| if l2_reg: | ||
| np.fill_diagonal(X[n:, :], l2_reg) | ||
|
|
||
| # Construct vectors yi (response variables for each channel i) | ||
|
|
@@ -272,7 +313,7 @@ def _construct_var_eqns(data, lags, l2_reg=None): | |
| return X, Y | ||
|
|
||
|
|
||
| def _system_identification(data, lags, l2_reg=0, | ||
| def _system_identification(data, lags, l2_reg=0, cv_alphas=None, | ||
| n_jobs=-1, compute_fb_operator=False): | ||
| """Solve system identification using least-squares over all epochs. | ||
|
|
||
|
|
@@ -290,6 +331,7 @@ def _system_identification(data, lags, l2_reg=0, | |
| model_params = { | ||
| 'l2_reg': l2_reg, | ||
| 'lags': lags, | ||
| 'cv_alphas': cv_alphas, | ||
| 'compute_fb_operator': compute_fb_operator | ||
| } | ||
|
|
||
|
|
@@ -346,7 +388,7 @@ def _system_identification(data, lags, l2_reg=0, | |
| return A_mats | ||
|
|
||
|
|
||
| def _compute_lds_func(data, lags, l2_reg, compute_fb_operator): | ||
| def _compute_lds_func(data, lags, l2_reg, cv_alphas, compute_fb_operator): | ||
| """Compute linear system using VAR model. | ||
|
|
||
| Allows for parallelization over epochs. | ||
|
|
@@ -372,20 +414,21 @@ def _compute_lds_func(data, lags, l2_reg, compute_fb_operator): | |
| # get time-shifted versions | ||
| X = data[:, :] | ||
| A, resid, omega = _estimate_var(X, lags=lags, offset=0, | ||
| l2_reg=l2_reg) | ||
| l2_reg=l2_reg, cv_alphas=cv_alphas) | ||
|
|
||
| if compute_fb_operator: | ||
| # compute backward linear operator | ||
| # original method | ||
| back_A, back_resid, back_omega = _estimate_var( | ||
| X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg) | ||
| X[::-1, :], lags=lags, offset=0, l2_reg=l2_reg, cv_alphas=cv_alphas | ||
| ) | ||
| A = sqrtm(A.dot(np.linalg.inv(back_A))) | ||
| A = A.real # remove numerical noise | ||
|
|
||
| return A, resid, omega | ||
|
|
||
|
|
||
| def _estimate_var(X, lags, offset=0, l2_reg=0): | ||
| def _estimate_var(X, lags, offset=0, l2_reg=0, cv_alphas=None): | ||
| """Estimate a VAR model. | ||
|
|
||
| Parameters | ||
|
|
@@ -397,8 +440,10 @@ def _estimate_var(X, lags, offset=0, l2_reg=0): | |
| offset : int, optional | ||
| Periods to drop from the beginning of the time-series, by default 0. | ||
| Used for order selection, so it's an apples-to-apples comparison | ||
| l2_reg : int | ||
| l2_reg : int, optional | ||
| The amount of l2-regularization to use. Default of 0. | ||
| cv_alphas : array-like | None, optional | ||
| RidgeCV regularization cross-validation alpha values. Defaults to None. | ||
|
|
||
| Returns | ||
| ------- | ||
|
|
@@ -432,10 +477,25 @@ def _estimate_var(X, lags, offset=0, l2_reg=0): | |
| y_sample = endog[lags:] | ||
| del endog, X | ||
| # Lütkepohl p75, about 5x faster than stated formula | ||
| if l2_reg != 0: | ||
| params = np.linalg.lstsq(z.T @ z + l2_reg * np.eye(n_equations * lags), | ||
| z.T @ y_sample, rcond=1e-15)[0] | ||
|
|
||
| if (l2_reg is not None) and (l2_reg != 0): | ||
| # use pre-specified l2 regularization value | ||
| params = np.linalg.lstsq( | ||
| z.T @ z + l2_reg * np.eye(n_equations * lags), | ||
| z.T @ y_sample, | ||
| rcond=1e-15 | ||
| )[0] | ||
| elif cv_alphas is not None: | ||
| # use ridge regression with built-in cross validation of alpha values | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings( | ||
| action='ignore', | ||
| message="Ill-conditioned matrix" | ||
| ) | ||
| reg = RidgeCV(alphas=cv_alphas, cv=5).fit(z, y_sample) | ||
| params = reg.coef_.T | ||
| else: | ||
| # use OLS regression | ||
| params = np.linalg.lstsq(z, y_sample, rcond=1e-15)[0] | ||
|
|
||
| # (n_samples - lags, n_channels) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.