Skip to content

Commit 1b451e5

Browse files
committed
Cwt: Since PyWavelets#570 has not been merged yet, this is to quickly implement the precision option
1 parent d95a01f commit 1b451e5

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

pywt/_cwt.py

+31-24
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
from math import ceil, floor
22

3-
from ._extensions._pywt import (
4-
ContinuousWavelet,
5-
DiscreteContinuousWavelet,
6-
Wavelet,
7-
_check_dtype,
8-
)
3+
from ._extensions._pywt import (ContinuousWavelet, DiscreteContinuousWavelet,
4+
Wavelet, _check_dtype)
95
from ._functions import integrate_wavelet, scale2frequency
106
from ._utils import AxisError
117

@@ -16,6 +12,7 @@
1612

1713
try:
1814
import scipy
15+
1916
fftmodule = scipy.fft
2017
next_fast_len = fftmodule.next_fast_len
2118
except ImportError:
@@ -31,10 +28,19 @@ def next_fast_len(n):
3128
following this number to take advantage of FFT speedup.
3229
This fallback is less efficient than `scipy.fftpack.next_fast_len`
3330
"""
34-
return 2**ceil(np.log2(n))
35-
36-
37-
def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
31+
return 2 ** ceil(np.log2(n))
32+
33+
34+
def cwt(
35+
data,
36+
scales,
37+
wavelet,
38+
sampling_period=1.0,
39+
method="conv",
40+
axis=-1,
41+
*,
42+
precision=12,
43+
):
3844
"""
3945
cwt(data, scales, wavelet)
4046
@@ -70,6 +76,11 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
7076
axis: int, optional
7177
Axis over which to compute the CWT. If not given, the last axis is
7278
used.
79+
precision: int, optional
80+
Length of wavelet (2 ** precision) used to compute the CWT. Greater
81+
will increase resolution, especially for lower and higher scales,
82+
but compute a bit slower. Too low will distort coefficients
83+
and their norms, with a zipper-like effect; recommended >= 12.
7384
7485
Returns
7586
-------
@@ -125,16 +136,15 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
125136

126137
dt_out = dt_cplx if wavelet.complex_cwt else dt
127138
out = np.empty((np.size(scales),) + data.shape, dtype=dt_out)
128-
precision = 10
129139
int_psi, x = integrate_wavelet(wavelet, precision=precision)
130140
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
131141

132142
# convert int_psi, x to the same precision as the data
133-
dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt
143+
dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt
134144
int_psi = np.asarray(int_psi, dtype=dt_psi)
135145
x = np.asarray(x, dtype=data.real.dtype)
136146

137-
if method == 'fft':
147+
if method == "fft":
138148
size_scale0 = -1
139149
fft_data = None
140150
elif method != "conv":
@@ -156,7 +166,7 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
156166
j = np.extract(j < int_psi.size, j)
157167
int_psi_scale = int_psi[j][::-1]
158168

159-
if method == 'conv':
169+
if method == "conv":
160170
if data.ndim == 1:
161171
conv = np.convolve(data, int_psi_scale)
162172
else:
@@ -172,27 +182,24 @@ def cwt(data, scales, wavelet, sampling_period=1., method='conv', axis=-1):
172182
# - optimal FFT complexity
173183
# - to be larger than the two signals length to avoid circular
174184
# convolution
175-
size_scale = next_fast_len(
176-
data.shape[-1] + int_psi_scale.size - 1
177-
)
185+
size_scale = next_fast_len(data.shape[-1] + int_psi_scale.size - 1)
178186
if size_scale != size_scale0:
179187
# Must recompute fft_data when the padding size changes.
180188
fft_data = fftmodule.fft(data, size_scale, axis=-1)
181189
size_scale0 = size_scale
182190
fft_wav = fftmodule.fft(int_psi_scale, size_scale, axis=-1)
183191
conv = fftmodule.ifft(fft_wav * fft_data, axis=-1)
184-
conv = conv[..., :data.shape[-1] + int_psi_scale.size - 1]
192+
conv = conv[..., : data.shape[-1] + int_psi_scale.size - 1]
185193

186-
coef = - np.sqrt(scale) * np.diff(conv, axis=-1)
187-
if out.dtype.kind != 'c':
194+
coef = -np.sqrt(scale) * np.diff(conv, axis=-1)
195+
if out.dtype.kind != "c":
188196
coef = coef.real
189197
# transform axis is always -1 due to the data reshape above
190-
d = (coef.shape[-1] - data.shape[-1]) / 2.
198+
d = (coef.shape[-1] - data.shape[-1]) / 2.0
191199
if d > 0:
192-
coef = coef[..., floor(d):-ceil(d)]
200+
coef = coef[..., floor(d) : -ceil(d)]
193201
elif d < 0:
194-
raise ValueError(
195-
f"Selected scale of {scale} too small.")
202+
raise ValueError(f"Selected scale of {scale} too small.")
196203
if data.ndim > 1:
197204
# restore original data shape and axis position
198205
coef = coef.reshape(data_shape_pre)

0 commit comments

Comments
 (0)