Skip to content

Commit 77fceb1

Browse files
authored
Support Mooncake 0.5 (#2653)
* Update friendly_tangents parameter in withgradient * Update Mooncake version to 0.5.1 * Fix argument naming in prepare_gradient_cache call
1 parent d15c7dc commit 77fceb1

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ MLDataDevices = "1.4.2"
6363
MLUtils = "0.4"
6464
MPI = "0.20.19"
6565
MacroTools = "0.5"
66-
Mooncake = "0.4"
66+
Mooncake = "0.5.1"
6767
NCCL = "0.1.1"
6868
NNlib = "0.9.22"
6969
OneHotArrays = "0.2.4"

ext/FluxMooncakeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function Flux.gradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F
99
end
1010

1111
function Flux.withgradient(f::F, adtype::AutoMooncake, args::Vararg{Any,N}) where {F,N}
12-
cache = Mooncake.prepare_gradient_cache(f, args...; friendly_tangents=true)
12+
cache = Mooncake.prepare_gradient_cache(f, args...; config=Mooncake.Config(friendly_tangents=true))
1313
val, grads = Mooncake.value_and_gradient!!(cache, f, args...)
1414
return (val=val, grad=grads[2:end])
1515
end

0 commit comments

Comments
 (0)