Skip to content

Commit be9c1c8

Browse files
authored
Fix NaN gradient in spectrogram (#617)
1 parent 81e6cd1 commit be9c1c8

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

src/audio/spectrogram.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ See [`stft`](@ref) for other arguments.
2525
Spectrogram in the shape `(T, F, B)`, where
2626
`T` is the number of window hops and `F = n_fft ÷ 2 + 1`.
2727
"""
28-
function spectrogram(waveform;
28+
function spectrogram(waveform::AbstractArray{T};
2929
pad::Int = 0, n_fft::Int, hop_length::Int, window,
3030
center::Bool = true, power::Real = 2.0,
3131
normalized::Bool = false, window_normalized::Bool = false,
32-
)
32+
) where T
3333
pad > 0 && (waveform = pad_zeros(waveform, pad; dims=1);)
3434

3535
# Pack batch dimensions.
@@ -41,8 +41,8 @@ function spectrogram(waveform;
4141
window_normalized && (spec = spec .* inv(norm(window));)
4242

4343
if power > 0
44-
p = eltype(waveform)(power)
45-
spec = abs.(spec).^p
44+
p = T(power)
45+
spec = abs.(spec .+ eps(T)).^p
4646
end
4747
return spec
4848
end

test/testsuite/spectral.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,19 @@ function spectral_testsuite(Backend)
110110
spec = spectrogram(x;
111111
n_fft=1024, hop_length=128, window,
112112
center=true, normalized=false)
113-
114113
@test abs.(y).^2 spec
115114

115+
# Gradient with `0`s in spectrogram.
116+
# We add small ϵ to spectrogram before computing power
117+
# to prevent `NaN` in gradient due to `abs(0)`.
118+
x = device(ones(Float32, 1024))
119+
g = Zygote.gradient(x) do x
120+
sum(spectrogram(x;
121+
n_fft=1024, hop_length=128, window,
122+
center=true, normalized=false))
123+
end
124+
@test !any(isnan.(g[1]))
125+
116126
# Batched.
117127
x = device(rand(Float32, 1024, 3))
118128
spec = spectrogram(x;

0 commit comments

Comments
 (0)