Skip to content

Commit

Permalink
Merge pull request #2054 from christiangnrd/0.13.5_caching_fix
Browse files Browse the repository at this point in the history
Make params non-differentiable (Closes #2040 & #2048)
  • Loading branch information
ToucheSir authored Aug 30, 2022
2 parents c04210c + f5793b5 commit 31e4dd0
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/functor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ function params(m...)
return ps
end

# Allows caching of the parameters when params is called within gradient() to fix #2040.
@non_differentiable params(m...)

struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
Expand Down

0 comments on commit 31e4dd0

Please sign in to comment.