Skip to content

Commit 87e2f12

Browse files
Merge pull request #992 from DhairyaLGandhi/dg/941
Differentiate `push!` with implicit Params
2 parents 18a6f2a + ce8eb91 commit 87e2f12

File tree

2 files changed

+60
-4
lines changed

2 files changed

+60
-4
lines changed

src/compiler/interface.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,15 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]
6565

6666
# TODO store ids only
6767
struct Params
68-
order::Buffer{Any, Vector{Any}}
68+
order::Buffer # {Any, Vector{Any}}
6969
params::IdSet{Any}
70-
Params() = new(Buffer([], false), IdSet())
7170
end
7271

72+
Params() = Params(Buffer([], false), IdSet())
73+
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
74+
Params(ps::Params) = ps
75+
Params(xs::Tuple) = Params(collect(xs))
76+
7377
@forward Params.order Base.iterate, Base.length, Base.getindex
7478
@forward Params.params Base.in
7579

@@ -103,6 +107,20 @@ function Base.push!(ps::Params, x)
103107
return ps
104108
end
105109

110+
@adjoint! function Base.push!(xs::IdSet, x...)
111+
l = length(x)
112+
push!(xs, x...), Δ -> begin
113+
(Δ, ntuple(_ -> nothing, l)...)
114+
end
115+
end
116+
117+
@adjoint! function Base.push!(xs::Params, x::AbstractArray{T}...) where T
118+
sz_x = size.(x)
119+
push!(xs, x...), Δ -> begin
120+
(Δ, map(x -> Ones{T}(x...), sz_x)...)
121+
end
122+
end
123+
106124
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
107125

108126
function Base.delete!(ps::Params, x)
@@ -114,8 +132,6 @@ function Base.delete!(ps::Params, x)
114132
return ps
115133
end
116134

117-
Params(xs) = push!(Params(), xs...)
118-
119135
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
120136

121137
Base.:(==)(x::Params, y::Params) = x.order.data == y.order.data

test/interface.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,44 @@ end
163163
@test all(abs.(gs[w]) .<= 1e-5)
164164
@test all(abs.(gs[b]) .<= 1e-5)
165165
end
166+
167+
@testset "Params nesting" begin
168+
struct Dense{F,T,S}
169+
W::T
170+
b::S
171+
σ::F
172+
end
173+
174+
(d::Dense)(x) = d.σ.(d.W * x .+ d.b)
175+
d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
176+
ps = Zygote.Params([d.W, d.b])
177+
r = ones(Float32, 3,3)
178+
179+
gs = gradient(ps) do
180+
p, pb = pullback(ps) do
181+
sum(d(r))
182+
end
183+
g = pb(p)
184+
sum(g[d.W]) # + sum(g[d.b])
185+
end
186+
187+
@test gs[d.W] fill(81f0, (3,3))
188+
189+
# Test L2
190+
l2g = gradient(ps) do
191+
sum(sum(x .^ 2) for x in ps)
192+
end
193+
@test l2g[d.W] fill(2.f0, size(d.W))
194+
@test l2g[d.b] fill(0.f0, size(d.b))
195+
196+
# Can be safely removed - creating Params within
197+
# gradient calls may break between releases.
198+
sgs = gradient(ps) do
199+
sum(sum(x) for x in Zygote.Params([d.W, d.b]))
200+
end
201+
@test sgs[d.W] fill(1.f0, size(d.W))
202+
@test sgs[d.b] fill(1.f0, size(d.b))
203+
end
204+
205+
166206
end

0 commit comments

Comments
 (0)