1
1
from math import floor , ceil
2
+ from scipy import interpolate
2
3
3
4
from ._extensions ._pywt import (DiscreteContinuousWavelet , ContinuousWavelet ,
4
5
Wavelet , _check_dtype )
5
- from ._functions import integrate_wavelet , scale2frequency
6
+ from ._functions import evaluate_wavelet , scale2frequency
6
7
7
8
8
9
__all__ = ["cwt" ]
@@ -123,13 +124,16 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
123
124
dt_out = dt_cplx if wavelet .complex_cwt else dt
124
125
out = np .empty ((np .size (scales ),) + data .shape , dtype = dt_out )
125
126
precision = 10
126
- int_psi , x = integrate_wavelet (wavelet , precision = precision )
127
- int_psi = np .conj (int_psi ) if wavelet .complex_cwt else int_psi
127
+ psi , x = evaluate_wavelet (wavelet , precision = precision )
128
+ psi = np .conj (psi ) if wavelet .complex_cwt else psi
128
129
129
- # convert int_psi , x to the same precision as the data
130
- dt_psi = dt_cplx if int_psi .dtype .kind == 'c' else dt
131
- int_psi = np .asarray (int_psi , dtype = dt_psi )
130
+ # convert psi , x to the same precision as the data
131
+ dt_psi = dt_cplx if psi .dtype .kind == 'c' else dt
132
+ psi = np .asarray (psi , dtype = dt_psi )
132
133
x = np .asarray (x , dtype = data .real .dtype )
134
+ # FIXME: The original wavelet function could be used here, but
135
+ # interpolation is computationally more efficient.
136
+ wavefun = interpolate .interp1d (x , psi , kind = 'cubic' , assume_sorted = True )
133
137
134
138
if method == 'fft' :
135
139
size_scale0 = - 1
@@ -146,41 +150,44 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
146
150
data = data .reshape ((- 1 , data .shape [- 1 ]))
147
151
148
152
for i , scale in enumerate (scales ):
149
- step = x [1 ] - x [0 ]
150
- j = np .arange (scale * (x [- 1 ] - x [0 ]) + 1 ) / (scale * step )
151
- j = j .astype (int ) # floor
152
- if j [- 1 ] >= int_psi .size :
153
- j = np .extract (j < int_psi .size , j )
154
- int_psi_scale = int_psi [j ][::- 1 ]
153
+ # FIXME: Boundary points might be discarded erroneously
154
+ if np .sign (x [0 ])* np .sign (x [- 1 ])< 0 :
155
+ # Wavelet is sampled at 0.0 if the range includes it
156
+ xsl = np .arange (0.0 , x [0 ], - 1.0 / scale )
157
+ xsr = np .arange (0.0 , x [- 1 ], 1.0 / scale )
158
+ xs = np .concatenate ((xsl [:0 :- 1 ], xsr ))
159
+ else :
160
+ xs = np .arange (x [0 ], x [- 1 ], 1.0 / scale )
161
+ psi_scale = wavefun (xs )[::- 1 ]
155
162
156
163
if method == 'conv' :
157
164
if data .ndim == 1 :
158
- conv = np .convolve (data , int_psi_scale )
165
+ conv = np .convolve (data , psi_scale )
159
166
else :
160
167
# batch convolution via loop
161
168
conv_shape = list (data .shape )
162
- conv_shape [- 1 ] += int_psi_scale .size - 1
169
+ conv_shape [- 1 ] += psi_scale .size - 1
163
170
conv_shape = tuple (conv_shape )
164
171
conv = np .empty (conv_shape , dtype = dt_out )
165
172
for n in range (data .shape [0 ]):
166
- conv [n , :] = np .convolve (data [n ], int_psi_scale )
173
+ conv [n , :] = np .convolve (data [n ], psi_scale )
167
174
else :
168
175
# The padding is selected for:
169
176
# - optimal FFT complexity
170
177
# - to be larger than the two signals length to avoid circular
171
178
# convolution
172
179
size_scale = next_fast_len (
173
- data .shape [- 1 ] + int_psi_scale .size - 1
180
+ data .shape [- 1 ] + psi_scale .size - 1
174
181
)
175
182
if size_scale != size_scale0 :
176
183
# Must recompute fft_data when the padding size changes.
177
184
fft_data = fftmodule .fft (data , size_scale , axis = - 1 )
178
185
size_scale0 = size_scale
179
- fft_wav = fftmodule .fft (int_psi_scale , size_scale , axis = - 1 )
186
+ fft_wav = fftmodule .fft (psi_scale , size_scale , axis = - 1 )
180
187
conv = fftmodule .ifft (fft_wav * fft_data , axis = - 1 )
181
- conv = conv [..., :data .shape [- 1 ] + int_psi_scale .size - 1 ]
188
+ conv = conv [..., :data .shape [- 1 ] + psi_scale .size - 1 ]
182
189
183
- coef = - np .sqrt (scale ) * np . diff ( conv , axis = - 1 )
190
+ coef = conv / np .sqrt (scale )
184
191
if out .dtype .kind != 'c' :
185
192
coef = coef .real
186
193
# transform axis is always -1 due to the data reshape above
0 commit comments