Conversation
…`s of offsets (also simplifying `_aux_children`); fix broken test for issue FluxML#62
src/destructure.jl
Outdated
| if p isa ProjectTo # e.g. Array, NamedTuple | ||
| p(y) | ||
| else # p === identity for unknown structs | ||
| y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields |
There was a problem hiding this comment.
Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.
There was a problem hiding this comment.
Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.
There was a problem hiding this comment.
Ok, I think I finally understand what's going on. Sorry it took a while.
re constructs another Skip containing the gradient, and backing turns that into a NamedTuple with the same field names, which is what Tangent wants.
The only way I can see this failing is this: If the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only Skip(::AbstractLayer), and re tries to make one with a Tangent.
There was a problem hiding this comment.
No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined functor(m::MyModel) = (m.w,), w -> .... Then:
- In general there's no way to reconstruct
MyModel(or even aNamedTupleof fields/values) withoutre, as you do not know the corresponding field name given only(m.w,), but - As you say, if the primal constructor isn't sufficiently generic then it won't be able to store
Tangent/Nothing/etc. values in it's fields and will error beforebackingcan unpack it again
Avoiding re would be ideal, but I think that would require functor to always return NamedTuples on custom structs. I noticed that this is the default in @functor, though, so maybe it's not such a painful requirement? In the mean time I can at least add a branch that would avoid re for structs that are functored to NamedTuples.
There was a problem hiding this comment.
In fact there's another problem I didn't spot before, what a mess:
julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]); # from tests: a,c are functor-ed, and only a is trainable
julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))
julia> gradient(ac) do x # with Tangent{typeof(x), typeof(y)}(y)
w2, _ = destructure(x)
w2[2]^2
end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),)
# Same, with z = backing(re(y)) :
julia> gradient(ac) do x
w2, _ = destructure(x)
w2[2]^2
end
┌ Info: last case
│ x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│ y = (a = [0.0, 4.0], c = [4.0, 5.0])
└ z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)There was a problem hiding this comment.
Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be ((a = [0.0, 4.0], b = nothing, c = nothing),), right?
I think the problem is the _trainmap above; it populates the nothing values from _trainable (non-trainable fields) with the primal values, when they should be NoT. That's how the b and/or c values get back in there.
There was a problem hiding this comment.
Yes, I think _trainmap needs to do something isnothing(t) ? NoT : f(t, a) here. That's where c = [4.0, 5.0] is coming from.
But b = [3.0] is coming from this PR's trick of calling the reconstructor made by @functor:
julia> ch, re = Functors.functor(ac)
((a = [1.0, 2.0], c = [4.0, 5.0]), var"#1#2"{TwoThirds}(TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])))
julia> re((a = [10, 20], c = nothing))
TwoThirds([10, 20], [3.0], nothing)
There was a problem hiding this comment.
Gotcha. So on top of the modified _trainmap to fix c, one would still have to filter backing(re(y)) to replace repopulated primal values which aren't functor-ed with NoT in order to fix b.
EDIT: But, based on the output of Tangent{typeof(x), typeof(y)}(y), maybe the modified _trainmap alone would be enough and backing(re(y)) isn't needed after all, as Tangent will assign NoT to omitted fields in y automatically.
EDIT 2: Never mind, that would still fail for children which aren't fields, like Skip.
There was a problem hiding this comment.
Alright pushed something that works for both Skip and your TwoThirds example (modified _trainmap + filtering backing(re(y))). But since it uses re it would still fail for fussy constructors.
…h are not `trainable`; filter primal values from `backing(re(y))`
This adds a couple small changes on top of this draft PR in order to fix #62:
Offsetto fix the issue mentioned in Attempt to fix #62 #63 for array of arrays. For example, the offset structure forx = [[1.0, 2.0]]is now something likeo = [Offset(4)]which is not leaflike, compared too = [4]previously. This also opens the door to storing more information in this wrapper struct (original array size? eltype?), but that doesn't seem necessary at this timey = backing(re(y))allows forfunctor(x)to return children which aren't its own fields:yis first restructured to match the structure ofx, and then theNamedTuplebacking forre(y)is extracted and passed toTangent. It has the added benefit of adding some symmetry with_trainable_biwalkwhich naturally restructures the output of_trainmap, whereas_Tangent_biwalkpreviously did notCloses #63 (replaces).