Skip to content

Commit dc47bc5

Browse files
committed
Fix type-stability of optimizer
1 parent 09c734d commit dc47bc5

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "NerfUtils"
22
uuid = "99c1d5ce-7c61-4a25-a107-a5ade2e2a8e4"
33
authors = ["Anton Smirnov <tonysmn97@gmail.com>"]
4-
version = "0.1.3"
4+
version = "0.1.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/nn/adam.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
_eltype(θ::T) where T <: Union{Tuple, NamedTuple} = _eltype(first(θ))
2+
_eltype(θ::T) where T <: AbstractVector = T
3+
_eltype(θ::T) where T = _eltype(reshape(θ, :))
4+
15
"""
26
Adam(kab, θ; kwargs...)
37
@@ -20,7 +24,8 @@ end
2024
KernelAbstractions.get_backend(opt::Adam) = get_backend(first(opt.μ))
2125

2226
function Adam(kab, θ; kwargs...)
23-
μ, ν = [], [] # FIXME unstable
27+
T = _eltype(θ)
28+
μ, ν = T[], T[]
2429
_add_moments!(μ, ν, θ, kab)
2530
Adam(; μ, ν, kwargs...)
2631
end

0 commit comments

Comments
 (0)