Skip to content

Commit d5e96cd

Browse files
authored
fix: use functors for testing wrapped arrays (#1134)
1 parent fdb0170 commit d5e96cd

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

lib/MLDataDevices/Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.6.4"
4+
version = "1.6.5"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -15,6 +15,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1515
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1616
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
1717
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
18+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1819
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1920
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
2021
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -32,8 +33,9 @@ oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
3233
[extensions]
3334
MLDataDevicesAMDGPUExt = "AMDGPU"
3435
MLDataDevicesCUDAExt = "CUDA"
35-
MLDataDevicesChainRulesExt = "ChainRules"
3636
MLDataDevicesChainRulesCoreExt = "ChainRulesCore"
37+
MLDataDevicesChainRulesExt = "ChainRules"
38+
MLDataDevicesComponentArraysExt = "ComponentArrays"
3739
MLDataDevicesFillArraysExt = "FillArrays"
3840
MLDataDevicesGPUArraysExt = "GPUArrays"
3941
MLDataDevicesMLUtilsExt = "MLUtils"
@@ -55,6 +57,7 @@ CUDA = "5.2"
5557
ChainRules = "1.51"
5658
ChainRulesCore = "1.23"
5759
Compat = "4.16"
60+
ComponentArrays = "0.15.18"
5861
FillArrays = "1"
5962
Functors = "0.5"
6063
GPUArrays = "10, 11"
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module MLDataDevicesComponentArraysExt
2+
3+
using ComponentArrays: ComponentArrays
4+
using MLDataDevices: MLDataDevices
5+
6+
MLDataDevices.isleaf(::ComponentArrays.ComponentArray) = true
7+
8+
end

lib/MLDataDevices/src/public.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,5 +399,7 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
399399
"""
400400
isleaf(x) = Functors.isleaf(x)
401401

402-
isleaf(::AbstractArray{T}) where {T} = isbitstype(T) || T <: Number # BigFloat and such are not bitstype
403-
isleaf(::Adapt.WrappedArray) = false
402+
function isleaf(x::AbstractArray{T}) where {T}
403+
parent(x) !== x && return Functors.isleaf(x)
404+
return isbitstype(T) || T <: Number # BigFloat and such are not bitstype
405+
end

0 commit comments

Comments
 (0)