Skip to content

Commit 6a5a6bc

Browse files
Proor of concept on how to fix issues PyWavelets#531, PyWavelets#535 and PyWavelets#570
1 parent 196b5d3 commit 6a5a6bc

File tree

2 files changed

+79
-21
lines changed

2 files changed

+79
-21
lines changed

pywt/_cwt.py

+26-19
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from math import floor, ceil
2+
from scipy import interpolate
23

34
from ._extensions._pywt import (DiscreteContinuousWavelet, ContinuousWavelet,
45
Wavelet, _check_dtype)
5-
from ._functions import integrate_wavelet, scale2frequency
6+
from ._functions import evaluate_wavelet, scale2frequency
67

78

89
__all__ = ["cwt"]
@@ -123,13 +124,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
123124
dt_out = dt_cplx if wavelet.complex_cwt else dt
124125
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
125126
precision = 10
126-
int_psi, x = integrate_wavelet(wavelet, precision=precision)
127-
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
127+
psi, x = evaluate_wavelet(wavelet, precision=precision)
128+
psi = np.conj(psi) if wavelet.complex_cwt else psi
128129

129-
# convert int_psi, x to the same precision as the data
130-
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
131-
int_psi = np.asarray(int_psi, dtype=dt_psi)
130+
# convert psi, x to the same precision as the data
131+
dt_psi = dt_cplx if psi.dtype.kind == 'c' else dt
132+
psi = np.asarray(psi, dtype=dt_psi)
132133
x = np.asarray(x, dtype=data.real.dtype)
134+
# FIXME: The original wavelet function could be used here, but
135+
# interpolation is computationally more efficient.
136+
wavefun = interpolate.interp1d(x, psi, kind='cubic', assume_sorted=True)
133137

134138
if method == 'fft':
135139
size_scale0 = -1
@@ -146,41 +150,44 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
146150
data = data.reshape((-1, data.shape[-1]))
147151

148152
for i, scale in enumerate(scales):
149-
step = x[1] - x[0]
150-
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
151-
j = j.astype(int) # floor
152-
if j[-1] >= int_psi.size:
153-
j = np.extract(j < int_psi.size, j)
154-
int_psi_scale = int_psi[j][::-1]
153+
# FIXME: Boundary points might be discarded erroneously
154+
if np.sign(x[0])*np.sign(x[-1])<0:
155+
# Wavelet is sampled at 0.0 if the range includes it
156+
xsl = np.arange(0.0, x[0], -1.0/scale)
157+
xsr = np.arange(0.0, x[-1], 1.0/scale)
158+
xs = np.concatenate((xsl[:0:-1], xsr))
159+
else:
160+
xs = np.arange(x[0], x[-1], 1.0/scale)
161+
psi_scale = wavefun(xs)[::-1]
155162

156163
if method == 'conv':
157164
if data.ndim == 1:
158-
conv = np.convolve(data, int_psi_scale)
165+
conv = np.convolve(data, psi_scale)
159166
else:
160167
# batch convolution via loop
161168
conv_shape = list(data.shape)
162-
conv_shape[-1] += int_psi_scale.size - 1
169+
conv_shape[-1] += psi_scale.size - 1
163170
conv_shape = tuple(conv_shape)
164171
conv = np.empty(conv_shape, dtype=dt_out)
165172
for n in range(data.shape[0]):
166-
conv[n, :] = np.convolve(data[n], int_psi_scale)
173+
conv[n, :] = np.convolve(data[n], psi_scale)
167174
else:
168175
# The padding is selected for:
169176
# - optimal FFT complexity
170177
# - to be larger than the two signals length to avoid circular
171178
# convolution
172179
size_scale = next_fast_len(
173-
data.shape[-1] + int_psi_scale.size - 1
180+
data.shape[-1] + psi_scale.size - 1
174181
)
175182
if size_scale != size_scale0:
176183
# Must recompute fft_data when the padding size changes.
177184
fft_data = fftmodule.fft(data, size_scale, axis=-1)
178185
size_scale0 = size_scale
179-
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
186+
fft_wav = fftmodule.fft(psi_scale, size_scale, axis=-1)
180187
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
181-
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
188+
conv = conv[..., :data.shape[-1] + psi_scale.size - 1]
182189

183-
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
190+
coef = conv / np.sqrt(scale)
184191
if out.dtype.kind != 'c':
185192
coef = coef.real
186193
# transform axis is always -1 due to the data reshape above

pywt/_functions.py

+53-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from ._extensions._pywt import DiscreteContinuousWavelet, Wavelet, ContinuousWavelet
1818

1919

20-
__all__ = ["integrate_wavelet", "central_frequency", "scale2frequency", "qmf",
21-
"orthogonal_filter_bank",
20+
__all__ = ["integrate_wavelet", "evaluate_wavelet", "central_frequency",
21+
"scale2frequency", "qmf", "orthogonal_filter_bank",
2222
"intwave", "centrfrq", "scal2frq", "orthfilt"]
2323

2424

@@ -119,6 +119,57 @@ def integrate_wavelet(wavelet, precision=8):
119119
return _integrate(psi_d, step), _integrate(psi_r, step), x
120120

121121

122+
def evaluate_wavelet(wavelet, precision=8):
123+
"""
124+
Evaluate `psi` wavelet function between lower and upper bound.
125+
126+
Parameters
127+
----------
128+
wavelet : Wavelet instance or str
129+
Wavelet to evaluate. If a string, should be the name of a wavelet.
130+
precision : int, optional
131+
Number of wavelet function points computed with Wavelet's
132+
wavefun(level=precision) method (default: 8).
133+
134+
Returns
135+
-------
136+
[psi, x] :
137+
for orthogonal wavelets
138+
[psi_d, psi_r, x] :
139+
for other wavelets
140+
141+
142+
Examples
143+
--------
144+
>>> from pywt import Wavelet, evaluate_wavelet
145+
>>> wavelet1 = Wavelet('db2')
146+
>>> [psi, x] = evaluate_wavelet(wavelet1, precision=5)
147+
>>> wavelet2 = Wavelet('bior1.3')
148+
>>> [psi_d, psi_r, x] = evaluate_wavelet(wavelet2, precision=5)
149+
150+
"""
151+
152+
if type(wavelet) in (tuple, list):
153+
psi, x = np.asarray(wavelet[0]), np.asarray(wavelet[1])
154+
return psi, x
155+
elif not isinstance(wavelet, (Wavelet, ContinuousWavelet)):
156+
wavelet = DiscreteContinuousWavelet(wavelet)
157+
158+
functions_approximations = wavelet.wavefun(precision)
159+
160+
if len(functions_approximations) == 2: # continuous wavelet
161+
psi, x = functions_approximations
162+
return psi, x
163+
164+
elif len(functions_approximations) == 3: # orthogonal wavelet
165+
phi, psi, x = functions_approximations
166+
return psi, x
167+
168+
else: # biorthogonal wavelet
169+
phi_d, psi_d, phi_r, psi_r, x = functions_approximations
170+
return psi_d, psi_r, x
171+
172+
122173
def central_frequency(wavelet, precision=8):
123174
"""
124175
Computes the central frequency of the `psi` wavelet function.

0 commit comments

Comments
 (0)