Commit d116e6c

fix: dispatches
1 parent a0248e6 commit d116e6c

12 files changed: 74 additions & 125 deletions


lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLDataDevices"
 uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.17.3"
+version = "1.17.4"

 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/MLDataDevices/ext/AMDGPUExt.jl

Lines changed: 9 additions & 22 deletions
@@ -82,37 +82,24 @@ function amdgpu_array_adapt(::Type{T}, x) where {T}
     return Internal.array_adapt(AMDGPU.roc, ROCArray, T, x)
 end

-function Adapt.adapt_storage(::AMDGPUDevice{D,Missing}, x::AbstractArray) where {D}
-    MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x
-    return amdgpu_array_adapt(Missing, x)
-end
-
-function Adapt.adapt_storage(::AMDGPUDevice{D,Nothing}, x::AbstractArray) where {D}
-    MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x
-    return amdgpu_array_adapt(Nothing, x)
-end
-
-function Adapt.adapt_storage(
-    ::AMDGPUDevice{D,T}, x::AbstractArray{ET}
-) where {D,T<:AbstractFloat,ET<:Number}
-    MLDataDevices.get_device_type(x) <: AMDGPUDevice && ET == T && return x
-    return amdgpu_array_adapt(T, x)
-end
-
 function Adapt.adapt_storage(to::AMDGPUDevice{D,E}, x::AbstractArray) where {D,E}
     old_dev = AMDGPU.device() # remember the current device
     dev = MLDataDevices.get_device(x)
     if !(dev isa AMDGPUDevice)
-        AMDGPU.device!(to.device)
+        to.device !== nothing && AMDGPU.device!(to.device)
         x_new = amdgpu_array_adapt(to, x)
-        AMDGPU.device!(old_dev)
+        to.device !== nothing && AMDGPU.device!(old_dev)
         return x_new
-    elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)
+    elseif (
+        dev.device === nothing ||
+        to.device === nothing ||
+        AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)
+    )
         return x
     else
-        AMDGPU.device!(to.device)
+        to.device !== nothing && AMDGPU.device!(to.device)
         x_new = copy(x)
-        AMDGPU.device!(old_dev)
+        to.device !== nothing && AMDGPU.device!(old_dev)
        return x_new
     end
 end
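
The `to.device !== nothing` guards all serve one purpose: a device object that carries no concrete GPU handle should neither switch the active device nor be compared against one. A minimal usage sketch, assuming an AMDGPU-capable session in which `gpu_device()` resolves to such an `AMDGPUDevice` (the exact instance it returns is not part of this diff):

```julia
using MLDataDevices, AMDGPU   # loading AMDGPU activates this extension

dev = gpu_device()            # AMDGPUDevice; its `device` field may be `nothing`
x   = rand(Float32, 4, 4)     # host array

x_roc = dev(x)                # routed through Adapt.adapt_storage above; with
                              # `device === nothing`, no AMDGPU.device! switch is made
dev(x_roc)                    # already an AMDGPU array: the relaxed `elseif`
                              # returns it without a cross-device copy
```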

lib/MLDataDevices/ext/CUDAExt.jl

Lines changed: 7 additions & 20 deletions
@@ -56,37 +56,24 @@ function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray)
 end

 # Device Transfer
-cuda_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CUDA.cu, CuArray, T, x)
-
-function Adapt.adapt_storage(::CUDADevice{D,Missing}, x::AbstractArray) where {D}
-    MLDataDevices.get_device_type(x) <: CUDADevice && return x
-    return cuda_array_adapt(Missing, x)
-end
-
-function Adapt.adapt_storage(::CUDADevice{D,Nothing}, x::AbstractArray) where {D}
-    MLDataDevices.get_device_type(x) <: CUDADevice && return x
-    return cuda_array_adapt(Nothing, x)
-end
-
-function Adapt.adapt_storage(::CUDADevice{D,T}, x::AbstractArray) where {D,T<:AbstractFloat}
-    MLDataDevices.get_device_type(x) <: CUDADevice && eltype(x) == T && return x
-    return cuda_array_adapt(T, x)
+function cuda_array_adapt(::CUDADevice{D,E}, x) where {D,E}
+    return Internal.array_adapt(CUDA.cu, CuArray, E, x)
 end

 function Adapt.adapt_storage(to::CUDADevice{D,E}, x::AbstractArray) where {D,E}
     old_dev = CUDA.device() # remember the current device
     dev = MLDataDevices.get_device(x)
     if !(dev isa CUDADevice)
-        CUDA.device!(to.device)
+        to.device !== nothing && CUDA.device!(to.device)
         x_new = cuda_array_adapt(to, x)
-        CUDA.device!(old_dev)
+        to.device !== nothing && CUDA.device!(old_dev)
         return x_new
-    elseif dev.device == to.device
+    elseif dev.device === nothing || to.device === nothing || dev.device == to.device
         return x
     else
-        CUDA.device!(to.device)
+        to.device !== nothing && CUDA.device!(to.device)
         x_new = copy(x)
-        CUDA.device!(old_dev)
+        to.device !== nothing && CUDA.device!(old_dev)
         return x_new
     end
 end
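
`cuda_array_adapt` now dispatches on the device itself, so the eltype parameter `E` is read off the device's type instead of being passed separately, and the `elseif` treats an unset device on either side as "same device". A minimal sketch of the CPU-to-GPU path, assuming a CUDA-capable session:

```julia
using MLDataDevices, CUDA     # loading CUDA activates this extension

dev = gpu_device()            # CUDADevice; `device` may be `nothing` ("use the current one")
x   = rand(Float32, 16)

x_cu = dev(x)                 # CPU -> GPU: cuda_array_adapt(dev, x) forwards the device's
                              # eltype parameter E to Internal.array_adapt
dev(x_cu)                     # already a CuArray: with either device unset, the new
                              # `elseif` returns it as-is rather than forcing a copy
```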

lib/MLDataDevices/ext/GPUArraysSparseArraysExt.jl

Lines changed: 7 additions & 3 deletions
@@ -14,9 +14,13 @@ Internal.get_device_type(rng::GPUArrays.RNG) = Internal.get_device_type(rng.stat
 for (T1, T2) in
     ((AbstractGPUSparseMatrixCSC, SparseMatrixCSC), (AbstractGPUSparseVector, SparseVector))
     @eval begin
-        Adapt.adapt_storage(::CPUDevice{Missing}, x::$(T1)) = $(T2)(x)
-        Adapt.adapt_storage(::CPUDevice{Nothing}, x::$(T1)) = $(T2)(x)
-        Adapt.adapt_storage(::CPUDevice{T}, x::$(T1)) where {T<:AbstractFloat} = $(T2){T}(x)
+        function Adapt.adapt_storage(::CPUDevice{T}, x::$(T1)) where {T}
+            if T <: AbstractFloat
+                eltype(x) <: Complex && return $(T2){Complex{T}}(x)
+                return $(T2){T}(x)
+            end
+            return $(T2)(x)
+        end
     end
 end
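
The single generated method also covers complex eltypes: a `Float32` request on a complex GPU sparse array now yields `Complex{Float32}` on the CPU. A logic-only sketch of the generated body, using a plain `SparseMatrixCSC` in place of the GPU sparse type so it runs without a GPU (the function name is made up for illustration):

```julia
using SparseArrays

# `T` plays the role of the CPUDevice eltype parameter; SparseMatrixCSC stands in
# for both the GPU source type (T1) and the CPU target type (T2).
function gpu_sparse_to_cpu_sketch(::Type{T}, x::SparseMatrixCSC) where {T}
    if T <: AbstractFloat
        eltype(x) <: Complex && return SparseMatrixCSC{Complex{T}}(x)
        return SparseMatrixCSC{T}(x)
    end
    return x   # Nothing/Missing: the real method converts the container, keeps the eltype
end

gpu_sparse_to_cpu_sketch(Float32, sparse(rand(ComplexF64, 4, 4)))  # eltype ComplexF32
gpu_sparse_to_cpu_sketch(Nothing, sprand(4, 4, 0.5))               # eltype Float64 kept
```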

lib/MLDataDevices/ext/MetalExt.jl

Lines changed: 8 additions & 16 deletions
@@ -2,7 +2,8 @@ module MetalExt

 using Adapt: Adapt
 using GPUArrays: GPUArrays
-using MLDataDevices: MLDataDevices, Internal, MetalDevice, reset_gpu_device!
+using MLDataDevices:
+    MLDataDevices, Internal, MetalDevice, reset_gpu_device!, get_device_type
 using Metal: Metal, MtlArray

 __init__() = reset_gpu_device!()

@@ -29,21 +30,12 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
 end

 # Device Transfer
-metal_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(Metal.mtl, MtlArray, T, x)
-
-function Adapt.adapt_storage(::MetalDevice{Missing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: MetalDevice && return x
-    return metal_array_adapt(Missing, x)
-end
-
-function Adapt.adapt_storage(::MetalDevice{Nothing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: MetalDevice && return x
-    return metal_array_adapt(Nothing, x)
-end
-
-function Adapt.adapt_storage(::MetalDevice{T}, x::AbstractArray) where {T<:AbstractFloat}
-    MLDataDevices.get_device_type(x) <: MetalDevice && eltype(x) == T && return x
-    return metal_array_adapt(T, x)
+function Adapt.adapt_storage(::MetalDevice{T}, x::AbstractArray) where {T}
+    # Metal is single-device, so we only need to check the device type
+    if get_device_type(x) <: MetalDevice
+        Internal.return_without_conversion(T, x) && return x
+    end
+    return Internal.array_adapt(Metal.mtl, MtlArray, T, x)
 end

 end
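
Since Metal exposes a single GPU, the consolidated method skips the device-switching bookkeeping the CUDA/AMDGPU paths need and only asks whether an eltype conversion is due (via the new `Internal.return_without_conversion`, added in `src/internal.jl` below). A minimal usage sketch, assuming an Apple-silicon session:

```julia
using MLDataDevices, Metal

dev = gpu_device()     # MetalDevice when Metal is functional
x   = rand(Float32, 4, 4)

x_mtl = dev(x)         # MtlArray{Float32}, via the adapt_storage method above
dev(x_mtl)             # already on Metal with no eltype change requested, so
                       # return_without_conversion short-circuits and the array
                       # is returned untouched
```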

lib/MLDataDevices/ext/OpenCLExt.jl

Lines changed: 5 additions & 15 deletions
@@ -76,22 +76,12 @@ end

 opencl_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CLArray, CLArray, T, x)

-function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
-    return opencl_array_adapt(Missing, x)
-end
-
-function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
-    return opencl_array_adapt(Nothing, x)
-end
-
-function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T<:AbstractFloat}
-    MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x
-    if T === Float64 && !SUPPORTS_FP64[cl.device()]
-        throw(ArgumentError("FP64 is not supported on this device"))
+function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T}
+    if MLDataDevices.get_device_type(x) <: OpenCLDevice
+        Internal.return_without_conversion(T, x) && return x
     end
-    return opencl_array_adapt(T, x)
+
+    return Internal.array_adapt(CLArray, CLArray, T, x)
 end

 end

lib/MLDataDevices/ext/ReactantExt.jl

Lines changed: 6 additions & 12 deletions
@@ -103,20 +103,14 @@ Internal.unsafe_free_internal!(::Type{ReactantDevice}, x::AbstractArray) = nothi

 # Device Transfer
 Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
-    dev::ReactantDevice{C,D,S,Missing}, x::AbstractArray
-) where {C,D,S}
-    return ConcreteRArray(x; device_to_kwargs(dev, x)...) # Preserves eltype
-end
+    dev::ReactantDevice{C,D,S,T}, x::AbstractArray{ET}
+) where {C,D,S,T,ET}
+    if T === Nothing || T === Missing
+        return ConcreteRArray(x; device_to_kwargs(dev, x)...) # Preserves eltype
+    end

-Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
-    dev::ReactantDevice{C,D,S,Nothing}, x::AbstractArray
-) where {C,D,S}
-    return ConcreteRArray(x; device_to_kwargs(dev, x)...) # Preserves eltype
-end
+    @assert T <: AbstractFloat

-Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
-    dev::ReactantDevice{C,D,S,T}, x::AbstractArray{ET}
-) where {C,D,S,T<:AbstractFloat,ET}
     # Convert eltype first, then move to device
     if ET <: AbstractFloat
         x_converted = convert(AbstractArray{T}, x)
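
The three annotated methods collapse into one, with the branch on `T` reproducing the old behaviour: `Nothing`/`Missing` preserve the eltype, anything else must be an `AbstractFloat` and is converted before the device transfer. A logic-only sketch (the `ConcreteRArray` construction and the rest of the function are not shown in this diff, so they are elided here):

```julia
# Hypothetical stand-in for the consolidated dispatch; ordinary Arrays replace the
# Reactant transfer so the branch structure can be exercised without Reactant.
function reactant_adapt_sketch(::Type{T}, x::AbstractArray{ET}) where {T,ET}
    if T === Nothing || T === Missing
        return x                              # eltype preserved, as before
    end
    @assert T <: AbstractFloat
    # Convert the eltype first; the real method then builds the ConcreteRArray
    # from the converted data.
    ET <: AbstractFloat && (x = convert(AbstractArray{T}, x))
    return x
end

reactant_adapt_sketch(Nothing, rand(3))   # Float64 kept
reactant_adapt_sketch(Float32, rand(3))   # converted to Float32
```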

lib/MLDataDevices/ext/SparseArraysExt.jl

Lines changed: 6 additions & 6 deletions
@@ -7,12 +7,12 @@ using SparseArrays: AbstractSparseArray, nonzeros
 Internal.get_device(x::AbstractSparseArray) = Internal.get_device(nonzeros(x))
 Internal.get_device_type(x::AbstractSparseArray) = Internal.get_device_type(nonzeros(x))

-Adapt.adapt_storage(::CPUDevice{Missing}, x::AbstractSparseArray) = x
-Adapt.adapt_storage(::CPUDevice{Nothing}, x::AbstractSparseArray) = x
-function Adapt.adapt_storage(
-    ::CPUDevice{T}, x::AbstractSparseArray
-) where {T<:AbstractFloat}
-    return convert(AbstractSparseArray{T}, x)
+function Adapt.adapt_storage(::CPUDevice{T}, x::AbstractSparseArray) where {T}
+    if T <: AbstractFloat
+        eltype(x) <: Complex && return convert(AbstractSparseArray{Complex{T}}, x)
+        return convert(AbstractSparseArray{T}, x)
+    end
+    return x
 end

 end
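
On the CPU side the same collapse happens for plain sparse arrays, with the added `Complex` branch. A small usage sketch through the public API (constructing an eltype-requesting `CPUDevice` is not shown in this diff, so that case is only described in the comment):

```julia
using SparseArrays, MLDataDevices

A = sparse(rand(ComplexF64, 4, 4))

cdev = cpu_device()   # CPUDevice without an eltype request
B = cdev(A)           # falls through to `return x`: storage and eltype unchanged

# A hypothetical Float32-requesting CPUDevice would instead hit the Complex branch
# above and hand back a sparse array with eltype ComplexF32.
```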

lib/MLDataDevices/ext/oneAPIExt.jl

Lines changed: 4 additions & 17 deletions
@@ -76,24 +76,11 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
     end
 end

-oneapi_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(oneArray, oneArray, T, x)
-
-function Adapt.adapt_storage(::oneAPIDevice{Missing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: oneAPIDevice && return x
-    return oneapi_array_adapt(Missing, x)
-end
-
-function Adapt.adapt_storage(::oneAPIDevice{Nothing}, x::AbstractArray)
-    MLDataDevices.get_device_type(x) <: oneAPIDevice && return x
-    return oneapi_array_adapt(Nothing, x)
-end
-
-function Adapt.adapt_storage(::oneAPIDevice{T}, x::AbstractArray) where {T<:AbstractFloat}
-    MLDataDevices.get_device_type(x) <: oneAPIDevice && eltype(x) == T && return x
-    if T === Float64 && !SUPPORTS_FP64[oneAPI.device()]
-        throw(ArgumentError("FP64 is not supported on this device"))
+function Adapt.adapt_storage(::oneAPIDevice{T}, x::AbstractArray) where {T}
+    if MLDataDevices.get_device_type(x) <: oneAPIDevice
+        Internal.return_without_conversion(T, x) && return x
     end
-    return oneapi_array_adapt(T, x)
+    return Internal.array_adapt(oneArray, oneArray, T, x)
 end

 end

lib/MLDataDevices/src/internal.jl

Lines changed: 13 additions & 2 deletions
@@ -289,8 +289,9 @@ end
 function to_rarray_internal end

 # Utility function to facilitate data transfer
-# For AbstractFloat and Complex{<:AbstractFloat} arrays, we provide specialized methods to avoid
-# ambiguity with the general fallback and to enable efficient type conversion when needed.
+# For AbstractFloat and Complex{<:AbstractFloat} arrays, we provide specialized methods to
+# avoid ambiguity with the general fallback and to enable efficient type conversion when
+# needed.
 function array_adapt(
     f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:AbstractFloat}
 ) where {F,aType}

@@ -357,4 +358,14 @@ function array_adapt(::F, ::Type{aType}, ::Type{E}, x::AbstractArray{T}) where {
     return aType(x)
 end

+return_without_conversion(::Type{Nothing}, ::AbstractArray) = true
+return_without_conversion(::Type{Missing}, ::AbstractArray) = true
+return_without_conversion(::Type{T}, ::AbstractArray{T}) where {T<:AbstractFloat} = true
+function return_without_conversion(
+    ::Type{T}, ::AbstractArray{Complex{T}}
+) where {T<:AbstractFloat}
+    return true
+end
+return_without_conversion(::Type{T}, ::AbstractArray) where {T} = false
+
 end
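
Taken together, the new helper answers "is the requested eltype already satisfied?" for the device extensions above. A quick check of the dispatch table, with plain `Array`s standing in for device arrays:

```julia
using MLDataDevices: Internal

x  = rand(Float32, 4)       # stands in for a device array with eltype Float32
xc = rand(ComplexF32, 4)    # stands in for one with eltype Complex{Float32}

Internal.return_without_conversion(Nothing, x)    # true:  no eltype requested
Internal.return_without_conversion(Missing, x)    # true:  no eltype requested
Internal.return_without_conversion(Float32, x)    # true:  already Float32
Internal.return_without_conversion(Float32, xc)   # true:  Complex{Float32} also counts
Internal.return_without_conversion(Float64, x)    # false: a conversion is needed
```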

0 commit comments
