Skip to content

Commit a2b479e

Browse files
committed
Merge PR #72: Add Tikhonov regularization for MVAR matrix inversion
Replaces direct matrix inverse with Tikhonov-regularized solve to prevent LinAlgError exceptions when computing Granger causality with near-singular matrices. Changes: - Replaced xp.linalg.inv() with regularized solve in MVAR computations - Added TIKHONOV_REGULARIZATION_FACTOR constant (1e-12) - Scale-aware regularization: λ = factor × mean(||H||²) - Added stress test for near-singular matrices Impact: - Granger causality measures handle highly correlated signals gracefully - No more LinAlgError crashes with near-singular transfer functions - Maintains numerical accuracy for well-conditioned cases Resolves: #72 # Conflicts: # CHANGELOG.md # tests/test_connectivity.py
2 parents 752aecd + 44666ed commit a2b479e

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
8383
- New (correct): 513 bins, includes 500 Hz (Nyquist)
8484
- See PR #71 for detailed analysis and discussion
8585

86+
- **Tikhonov regularization for MVAR matrix inversion stability**:
87+
- Replaced direct matrix inverse (`xp.linalg.inv()`) with Tikhonov-regularized solve in MVAR computations
88+
- Prevents `LinAlgError` exceptions when computing Granger causality with near-singular matrices
89+
- Affected functions: `_MVAR_Fourier_coefficients` property and `_estimate_transfer_function` function
90+
- Uses scale-aware regularization: λ = `TIKHONOV_REGULARIZATION_FACTOR` × mean(||H||²)
91+
- Added module-level constant `TIKHONOV_REGULARIZATION_FACTOR = 1e-12` for consistency
92+
- Solves `(H + λI)x = I` instead of computing `inv(H)` for better numerical stability
93+
- Added stress test `test_mvar_regularized_inverse_near_singular()` validating near-singular cases
94+
- All Granger causality measures now handle highly correlated signals gracefully
95+
- See PR #72 for detailed numerical analysis
96+
8697
- CHANGELOG.md to track version changes following Keep a Changelog format
8798
- Ruff linter configuration for faster, more comprehensive Python linting
8899
- Enhanced package metadata with additional project URLs (Changelog, Source Code, Issue Tracker)

spectral_connectivity/connectivity.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@
7171
"time_trials_tapers": partial(xp.mean, axis=(0, 1, 2)),
7272
}
7373

74+
# Tikhonov regularization factor for stabilizing matrix inversions
75+
# Used to prevent numerical instability with near-singular matrices
76+
TIKHONOV_REGULARIZATION_FACTOR = 1e-12
77+
7478

7579
def _asnumpy(connectivity_measure: Callable) -> Callable:
7680
"""Transform cupy array to numpy array.
@@ -569,7 +573,13 @@ def _noise_covariance(self) -> NDArray[np.floating]:
569573

570574
@property
571575
def _MVAR_Fourier_coefficients(self) -> NDArray[np.complexfloating]:
572-
return xp.linalg.inv(self._transfer_function)
576+
H = self._transfer_function
577+
# Tikhonov regularization: solve(H + λI, I) instead of inv(H)
578+
# Scale-aware regularization parameter
579+
lam = TIKHONOV_REGULARIZATION_FACTOR * xp.mean(xp.real(xp.conj(H) * H))
580+
identity = xp.eye(H.shape[-1], dtype=H.dtype)
581+
regularized_H = H + lam * identity
582+
return xp.linalg.solve(regularized_H, identity)
573583

574584
@property
575585
def _expectation(self) -> Callable:
@@ -1711,9 +1721,13 @@ def _estimate_transfer_function(
17111721
17121722
"""
17131723
inverse_fourier_coefficients = ifft(minimum_phase, axis=-3).real
1714-
return xp.matmul(
1715-
minimum_phase, xp.linalg.inv(inverse_fourier_coefficients[..., 0:1, :, :])
1716-
)
1724+
H_0 = inverse_fourier_coefficients[..., 0:1, :, :]
1725+
# Tikhonov regularization: solve(H_0 + λI, I) instead of inv(H_0)
1726+
lam = TIKHONOV_REGULARIZATION_FACTOR * xp.mean(H_0 * H_0) # Scale-aware regularization for real matrix
1727+
identity = xp.eye(H_0.shape[-1], dtype=H_0.dtype)
1728+
regularized_H_0 = H_0 + lam * identity
1729+
H_0_inv = xp.linalg.solve(regularized_H_0, identity)
1730+
return xp.matmul(minimum_phase, H_0_inv)
17171731

17181732

17191733
def _estimate_predictive_power(

tests/test_connectivity.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,59 @@ def test_nyquist_bin_odd_n():
652652
), f"Expected {expected_n_frequencies} frequencies, got {coherence.shape[-3]}"
653653

654654

655+
def test_mvar_regularized_inverse_near_singular():
656+
"""Test regularized inverse handles near-singular frequency bins."""
657+
np.random.seed(42)
658+
n_time_samples, n_trials, n_tapers, n_fft_samples, n_signals = (
659+
1, 10, 1, 5, 3
660+
)
661+
662+
# Create nearly singular Fourier coefficients by making signals
663+
# highly correlated
664+
fourier_coefficients = np.zeros(
665+
(n_time_samples, n_trials, n_tapers, n_fft_samples, n_signals),
666+
dtype=complex,
667+
)
668+
669+
# Base signal
670+
base_signal = np.random.randn(
671+
n_time_samples, n_trials, n_tapers, n_fft_samples
672+
) + 1j * np.random.randn(
673+
n_time_samples, n_trials, n_tapers, n_fft_samples
674+
)
675+
676+
# Create near-singular cross-spectral matrix by making signals
677+
# nearly dependent
678+
fourier_coefficients[..., 0] = base_signal
679+
fourier_coefficients[..., 1] = base_signal + 1e-10 * (
680+
np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
681+
+ 1j * np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
682+
)
683+
fourier_coefficients[..., 2] = base_signal + 1e-10 * (
684+
np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
685+
+ 1j * np.random.randn(n_time_samples, n_trials, n_tapers, n_fft_samples)
686+
)
687+
688+
# This should not raise LinAlgError with regularized inverse
689+
conn = Connectivity(fourier_coefficients=fourier_coefficients)
690+
691+
# Test that MVAR coefficients are computed without error
692+
mvar_coeffs = conn._MVAR_Fourier_coefficients
693+
assert mvar_coeffs is not None
694+
assert np.all(np.isfinite(mvar_coeffs))
695+
696+
# Test that transfer function is computed without error
697+
transfer_func = conn._transfer_function
698+
assert transfer_func is not None
699+
assert np.all(np.isfinite(transfer_func))
700+
701+
# Test connectivity measures that depend on MVAR work
702+
dtf = conn.directed_transfer_function()
703+
assert np.all(np.isfinite(dtf))
704+
assert np.all(dtf >= 0) # DTF should be non-negative
705+
assert np.all(dtf <= 1) # DTF should be bounded by 1
706+
707+
655708
def test_connectivity_rejects_wrong_ndim():
656709
"""Test that Connectivity rejects inputs with wrong number of dimensions."""
657710
import pytest

0 commit comments

Comments
 (0)