diff --git a/Project.toml b/Project.toml index ec8ed4fd84..cc771163c6 100644 --- a/Project.toml +++ b/Project.toml @@ -63,7 +63,7 @@ MLDataDevices = "1.4.2" MLUtils = "0.4" MPI = "0.20.19" MacroTools = "0.5" -Mooncake = "0.4" +Mooncake = "0.5.1" NCCL = "0.1.1" NNlib = "0.9.22" OneHotArrays = "0.2.4" diff --git a/ext/FluxMooncakeExt.jl b/ext/FluxMooncakeExt.jl index d371bfb952..a30d601968 100644 --- a/ext/FluxMooncakeExt.jl +++ b/ext/FluxMooncakeExt.jl @@ -9,7 +9,7 @@ function Flux.gradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F end function Flux.withgradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N} - cache = Mooncake.prepare_gradient_cache(f, args...; friendly_tangents=true) + cache = Mooncake.prepare_gradient_cache(f, args...; config=Mooncake.Config(friendly_tangents=true)) val, grads = Mooncake.value_and_gradient!!(cache, f, args...) return (val=val, grad=grads[2:end]) end