Skip to content

Commit 32a72e8

Browse files
Merge pull request #341 from AstitvaAggarwal/dev
ComponentVector handling
2 parents 37f685e + 546b509 commit 32a72e8

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1616
[weakdeps]
1717
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1818
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
19+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1920
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2021
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2122
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
@@ -27,6 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2728
[extensions]
2829
ComponentArraysGPUArraysExt = "GPUArrays"
2930
ComponentArraysKernelAbstractionsExt = "KernelAbstractions"
31+
ComponentArraysMooncakeExt = "Mooncake"
3032
ComponentArraysOptimisersExt = "Optimisers"
3133
ComponentArraysReactantExt = "Reactant"
3234
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
@@ -45,6 +47,7 @@ GPUArrays = "10.3.1, 11"
4547
KernelAbstractions = "0.9.29"
4648
LinearAlgebra = "1.10"
4749
Optimisers = "0.3, 0.4"
50+
Mooncake = "0.5"
4851
Reactant = "0.2.15"
4952
RecursiveArrayTools = "3.8"
5053
ReverseDiff = "1.15"

ext/ComponentArraysMooncakeExt.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module ComponentArraysMooncakeExt
2+
3+
using ComponentArrays, Mooncake
4+
5+
# ComponentVector handling in @from_rrule
6+
function Mooncake.increment_and_get_rdata!(
7+
f::Mooncake.FData{@NamedTuple{data::A, axes::Mooncake.NoFData}},
8+
r::Mooncake.NoRData,
9+
t::A,
10+
) where {P <: Union{Base.IEEEFloat, Complex{<:Base.IEEEFloat}}, A <: Array{P}}
11+
return Mooncake.increment_and_get_rdata!(f.data[:data], r, t)
12+
end
13+
14+
end

0 commit comments

Comments
 (0)