use ForwardDiff.jacobian in place of Zygote.forward_jacobian#1468
use ForwardDiff.jacobian in place of Zygote.forward_jacobian#1468vpuri3 wants to merge 1 commit intoFluxML:masterfrom
Conversation
|
There are some enzyme related errors in NNlib integration tests but they seem unrelated to this PR. |
|
@ToucheSir, LMK I need to add more tests. here's a working MWE with Lux. This also resolves #1348 with the change in this PR, this code is working: using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff
CUDA.allowscalar(false)
#==========================#
function testhessian(
NN::Lux.AbstractExplicitLayer,
data::Tuple;
device = cpu_device(),
)
p, st = Lux.setup(Random.default_rng(), NN)
st = Lux.testmode(st)
p = ComponentArray(p)
xdata, ydata = data |> device
p, st = (p, st) |> device
function loss(optx)
ypred, _ = NN(xdata, optx, st)
sum(abs2, ydata - ypred)
end
g(p) = Zygote.gradient(loss, p)[1]
H(p) = ForwardDiff.jacobian(g, p)
Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))
data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()
H = testhessian(NN, data; device)julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
0.236781 -0.075257 -1.20583 0.31846 -0.101217 -1.62179 -0.713834 0.503548 -1.14138 1.98508
-0.075257 0.0239192 0.383253 -0.101217 0.0321702 0.515458 0.0296168 -0.780695 0.362769 -0.630924
-1.20583 0.383253 6.1408 -1.62179 0.515458 8.2591 0.474545 -2.56436 5.19194 -10.1092
0.318461 -0.101217 -1.62179 0.514738 -0.163601 -2.62135 -2.09317 0.677249 -1.53511 3.20854
-0.101217 0.0321702 0.515458 -0.163601 0.0519977 0.833151 0.0398333 -2.18309 0.487909 -1.01978
-1.62179 0.515458 8.2591 -2.62135 0.833151 13.3494 0.638242 -3.44895 5.84984 -16.3398
-0.713834 0.0296168 0.474545 -2.09317 0.0398333 0.638242 0.0366717 -0.198167 0.449183 -0.781213
0.503548 -0.780695 -2.56436 0.677249 -2.18309 -3.44895 -0.198167 1.07086 -2.4273 4.22154
-1.14138 0.362769 5.19194 -1.53511 0.487909 5.84984 0.449183 -2.4273 5.50193 -9.56889
1.98508 -0.630924 -10.1092 3.20854 -1.01978 -16.3398 -0.781213 4.22154 -9.56889 20.0
(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
[052768ef] CUDA v5.0.0
[b0b7db55] ComponentArrays v0.15.4
[f6369f11] ForwardDiff v0.10.36
[b2108857] Lux v0.5.8
[d0bbae9a] LuxCUDA v0.3.1
[e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote` |
ToucheSir
left a comment
There was a problem hiding this comment.
I'm a little confused. This looks like the same change as #1270, just with no tests? My comment at #1270 (comment) and @mcabbott's at #1270 (comment) still very much apply, so those need to be addressed.
|
Did this ever reach a conclusion? I'm in need of the ability to take the jacobian with respect to the inputs of a (Lux) model output and then optimize that object using gradient descent updates on the (Lux) model parameters. Something like the following Or should I be looking towards JAX for this sort of thing? The use case is thermodynamics. |
|
That's a better question for the SciML/Lux help channels, not this issue tracker. |
|
This PR changes the implementation used internally for FwdDiff-over-Zygote. It didn't get much attention as it was a little unclear what this solves -- see requests above for tests which fail before the change. Your example wants to do Zygote-over-ForwardDiff, which won't work, and would not be changed by this PR. (Zygote has a rule for |
Pursuant to #1270