diff --git a/doc/source/ref/thresholding-functions.rst b/doc/source/ref/thresholding-functions.rst
index b0424cad3..b09ecb92a 100644
--- a/doc/source/ref/thresholding-functions.rst
+++ b/doc/source/ref/thresholding-functions.rst
@@ -12,6 +12,7 @@ Thresholding
 
 .. autofunction:: threshold
 .. autofunction:: threshold_firm
+.. autofunction:: estimate_noise
 
 The left panel of the figure below illustrates that non-negative Garotte
 thresholding is intermediate between soft and hard thresholding. Firm
diff --git a/pywt/_thresholding.py b/pywt/_thresholding.py
index 286f61179..989971905 100644
--- a/pywt/_thresholding.py
+++ b/pywt/_thresholding.py
@@ -11,7 +11,10 @@ from __future__ import division, print_function, absolute_import
 
 import numpy as np
 
-__all__ = ['threshold', 'threshold_firm']
+from ._multidim import dwtn
+
+
+__all__ = ['threshold', 'threshold_firm', 'estimate_noise']
 
 
 def soft(data, value, substitute=0):
@@ -248,3 +251,107 @@ def threshold_firm(data, value_low, value_high):
     if np.any(large_vals[0]):
         thresholded[large_vals] = data[large_vals]
     return thresholded
+
+
+def estimate_noise(data, distribution='Gaussian', **kwargs):
+    """
+    Robust wavelet-based estimator of the noise standard deviation.
+
+    Parameters
+    ----------
+    data : array_like
+        The data used to estimate sigma.
+    distribution : str or object with ppf method, optional
+        The underlying noise distribution. Either the string ``'Gaussian'``
+        (case-insensitive, the default) or an object providing a ``ppf``
+        (percent point function) method.
+    \\**kwargs : \\**kwargs
+        Keyword arguments to pass into distribution ppf method. If ``q`` is
+        not given, it defaults to 0.75. Ignored when `distribution` is the
+        string ``'Gaussian'``.
+
+    Returns
+    -------
+    sigma : float
+        Estimated noise standard deviation.
+
+    Raises
+    ------
+    ValueError
+        If `distribution` is neither ``'Gaussian'`` nor an object with a
+        ``ppf`` method.
+
+    Notes
+    -----
+    By default the noise is assumed to follow a Gaussian distribution. The
+    estimation algorithm is based on the median absolute deviation of the
+    wavelet detail coefficients as described in section 4.2 of [1]_.
+
+    References
+    ----------
+    .. [1] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
+       by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
+       DOI:10.1093/biomet/81.3.425
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> import pywt
+    >>> data = np.sin(np.linspace(0,10,100))
+    >>> np.random.seed(42)
+    >>> noise = 0.5 * np.random.normal(0,1,100)
+    >>> pywt.estimate_noise(data + noise)
+    0.45634925413327504
+    """
+
+    # Accept any array-like input; ``ndim`` is needed to build the key below.
+    data = np.asarray(data)
+    coeffs = dwtn(data, wavelet='db2')
+    # The all-'d' key holds the detail coefficients along every axis.
+    detail_coeffs = coeffs['d' * data.ndim]
+    return _sigma_est_dwt(detail_coeffs, distribution=distribution, **kwargs)
+
+
+def _sigma_est_dwt(detail_coeffs, distribution='Gaussian', **kwargs):
+    """Calculate the robust median estimator of the noise standard deviation.
+
+    Parameters
+    ----------
+    detail_coeffs : ndarray
+        The detail coefficients corresponding to the discrete wavelet
+        transform of an image.
+    distribution : str or object with ppf method
+        The underlying noise distribution.
+    \\**kwargs : \\**kwargs
+        Keyword arguments to pass into distribution ppf method. If ``q`` is
+        not given, it defaults to 0.75.
+
+    Returns
+    -------
+    sigma : float
+        The estimated noise standard deviation (see section 4.2 of [1]_).
+
+    References
+    ----------
+    .. [1] D. L. Donoho and I. M. Johnstone. "Ideal spatial adaptation
+       by wavelet shrinkage." Biometrika 81.3 (1994): 425-455.
+       DOI:10.1093/biomet/81.3.425
+    """
+    # Consider regions with detail coefficients exactly zero to be masked out
+    detail_coeffs = detail_coeffs[np.nonzero(detail_coeffs)]
+
+    if hasattr(distribution, 'ppf'):
+        # Fill in the default quantile without discarding any other
+        # keyword arguments (e.g. loc/scale) supplied by the caller.
+        kwargs.setdefault('q', 0.75)
+        denom = distribution.ppf(**kwargs)
+    elif str(distribution).lower() == 'gaussian':
+        # 75th percentile of the underlying, symmetric noise distribution,
+        # i.e. scipy.stats.norm.ppf(0.75), hard-coded so that scipy is not
+        # required.
+        denom = 0.6744897501960817
+    else:
+        raise ValueError("Only Gaussian noise estimation or objects with"
+                         " ppf method currently supported")
+    sigma = np.median(np.abs(detail_coeffs)) / denom
+    return sigma
diff --git a/pywt/tests/test_thresholding.py b/pywt/tests/test_thresholding.py
index abe69fadf..40edc8be9 100644
--- a/pywt/tests/test_thresholding.py
+++ b/pywt/tests/test_thresholding.py
@@ -167,3 +167,38 @@ def test_threshold_firm():
     mt_abs_firm = np.abs(d_firm[mt])
     assert_(np.all(mt_abs_firm < np.abs(d_hard[mt])))
     assert_(np.all(mt_abs_firm > np.abs(d_soft[mt])))
+
+
+def test_estimate_noise():
+    data = np.sin(np.linspace(0, 10, 1000))
+    np.random.seed(42)
+    noise = 0.5 * np.random.normal(0, 1, 1000)
+    sigma = pywt.estimate_noise(data + noise)
+    assert_allclose(sigma, 0.4867884459318056)
+
+    # array-like input that is not an ndarray is accepted
+    assert_allclose(pywt.estimate_noise(list(data + noise)), sigma)
+
+    assert_raises(ValueError, pywt.estimate_noise, data,
+                  distribution='not_a_distribution')
+    assert_raises(ValueError, pywt.estimate_noise, data,
+                  distribution=42)
+
+
+def test_estimate_noise_custom_distribution():
+    class ScaledQuantile(object):
+        # minimal stand-in for an object exposing a ppf method
+        def ppf(self, q, scale=1.0):
+            return scale * q
+
+    np.random.seed(42)
+    data = np.random.normal(0, 1, 1000)
+    dist = ScaledQuantile()
+
+    # q defaults to 0.75 when no keyword arguments are given ...
+    sigma = pywt.estimate_noise(data, distribution=dist)
+    assert_allclose(sigma, pywt.estimate_noise(data, distribution=dist,
+                                               q=0.75))
+    # ... and also when only other ppf keyword arguments are given
+    sigma_scaled = pywt.estimate_noise(data, distribution=dist, scale=2.0)
+    assert_allclose(sigma_scaled, sigma / 2.0)