1
1
from math import ceil , floor
2
2
3
- from ._extensions ._pywt import (
4
- ContinuousWavelet ,
5
- DiscreteContinuousWavelet ,
6
- Wavelet ,
7
- _check_dtype ,
8
- )
3
+ from ._extensions ._pywt import (ContinuousWavelet , DiscreteContinuousWavelet ,
4
+ Wavelet , _check_dtype )
9
5
from ._functions import integrate_wavelet , scale2frequency
10
6
from ._utils import AxisError
11
7
16
12
17
13
try :
18
14
import scipy
15
+
19
16
fftmodule = scipy .fft
20
17
next_fast_len = fftmodule .next_fast_len
21
18
except ImportError :
@@ -31,10 +28,19 @@ def next_fast_len(n):
31
28
following this number to take advantage of FFT speedup.
32
29
This fallback is less efficient than `scipy.fftpack.next_fast_len`
33
30
"""
34
- return 2 ** ceil (np .log2 (n ))
35
-
36
-
37
- def cwt (data , scales , wavelet , sampling_period = 1. , method = 'conv' , axis = - 1 ):
31
+ return 2 ** ceil (np .log2 (n ))
32
+
33
+
34
+ def cwt (
35
+ data ,
36
+ scales ,
37
+ wavelet ,
38
+ sampling_period = 1.0 ,
39
+ method = "conv" ,
40
+ axis = - 1 ,
41
+ * ,
42
+ precision = 12 ,
43
+ ):
38
44
"""
39
45
cwt(data, scales, wavelet)
40
46
@@ -70,6 +76,11 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
70
76
axis: int, optional
71
77
Axis over which to compute the CWT. If not given, the last axis is
72
78
used.
79
+ precision: int, optional
80
+ Length of wavelet (2 ** precision) used to compute the CWT. Greater
81
+ will increase resolution, especially for lower and higher scales,
82
+ but compute a bit slower. Too low will distort coefficients
83
+ and their norms, with a zipper-like effect; recommended >= 12.
73
84
74
85
Returns
75
86
-------
@@ -125,16 +136,15 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
125
136
126
137
dt_out = dt_cplx if wavelet .complex_cwt else dt
127
138
out = np .empty ((np .size (scales ),) + data .shape , dtype = dt_out )
128
- precision = 10
129
139
int_psi , x = integrate_wavelet (wavelet , precision = precision )
130
140
int_psi = np .conj (int_psi ) if wavelet .complex_cwt else int_psi
131
141
132
142
# convert int_psi, x to the same precision as the data
133
- dt_psi = dt_cplx if int_psi .dtype .kind == 'c' else dt
143
+ dt_psi = dt_cplx if int_psi .dtype .kind == "c" else dt
134
144
int_psi = np .asarray (int_psi , dtype = dt_psi )
135
145
x = np .asarray (x , dtype = data .real .dtype )
136
146
137
- if method == ' fft' :
147
+ if method == " fft" :
138
148
size_scale0 = - 1
139
149
fft_data = None
140
150
elif method != "conv" :
@@ -156,7 +166,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
156
166
j = np .extract (j < int_psi .size , j )
157
167
int_psi_scale = int_psi [j ][::- 1 ]
158
168
159
- if method == ' conv' :
169
+ if method == " conv" :
160
170
if data .ndim == 1 :
161
171
conv = np .convolve (data , int_psi_scale )
162
172
else :
@@ -172,27 +182,24 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
172
182
# - optimal FFT complexity
173
183
# - to be larger than the two signals length to avoid circular
174
184
# convolution
175
- size_scale = next_fast_len (
176
- data .shape [- 1 ] + int_psi_scale .size - 1
177
- )
185
+ size_scale = next_fast_len (data .shape [- 1 ] + int_psi_scale .size - 1 )
178
186
if size_scale != size_scale0 :
179
187
# Must recompute fft_data when the padding size changes.
180
188
fft_data = fftmodule .fft (data , size_scale , axis = - 1 )
181
189
size_scale0 = size_scale
182
190
fft_wav = fftmodule .fft (int_psi_scale , size_scale , axis = - 1 )
183
191
conv = fftmodule .ifft (fft_wav * fft_data , axis = - 1 )
184
- conv = conv [..., :data .shape [- 1 ] + int_psi_scale .size - 1 ]
192
+ conv = conv [..., : data .shape [- 1 ] + int_psi_scale .size - 1 ]
185
193
186
- coef = - np .sqrt (scale ) * np .diff (conv , axis = - 1 )
187
- if out .dtype .kind != 'c' :
194
+ coef = - np .sqrt (scale ) * np .diff (conv , axis = - 1 )
195
+ if out .dtype .kind != "c" :
188
196
coef = coef .real
189
197
# transform axis is always -1 due to the data reshape above
190
- d = (coef .shape [- 1 ] - data .shape [- 1 ]) / 2.
198
+ d = (coef .shape [- 1 ] - data .shape [- 1 ]) / 2.0
191
199
if d > 0 :
192
- coef = coef [..., floor (d ): - ceil (d )]
200
+ coef = coef [..., floor (d ) : - ceil (d )]
193
201
elif d < 0 :
194
- raise ValueError (
195
- f"Selected scale of { scale } too small." )
202
+ raise ValueError (f"Selected scale of { scale } too small." )
196
203
if data .ndim > 1 :
197
204
# restore original data shape and axis position
198
205
coef = coef .reshape (data_shape_pre )
0 commit comments