@@ -82,37 +82,24 @@ function amdgpu_array_adapt(::Type{T}, x) where {T}
8282 return Internal. array_adapt(AMDGPU. roc, ROCArray, T, x)
8383end
8484
85- function Adapt. adapt_storage(:: AMDGPUDevice{D,Missing} , x:: AbstractArray ) where {D}
86- MLDataDevices. get_device_type(x) <: AMDGPUDevice && return x
87- return amdgpu_array_adapt(Missing, x)
88- end
89-
90- function Adapt. adapt_storage(:: AMDGPUDevice{D,Nothing} , x:: AbstractArray ) where {D}
91- MLDataDevices. get_device_type(x) <: AMDGPUDevice && return x
92- return amdgpu_array_adapt(Nothing, x)
93- end
94-
95- function Adapt. adapt_storage(
96- :: AMDGPUDevice{D,T} , x:: AbstractArray{ET}
97- ) where {D,T<: AbstractFloat ,ET<: Number }
98- MLDataDevices. get_device_type(x) <: AMDGPUDevice && ET == T && return x
99- return amdgpu_array_adapt(T, x)
100- end
101-
10285function Adapt. adapt_storage(to:: AMDGPUDevice{D,E} , x:: AbstractArray ) where {D,E}
10386 old_dev = AMDGPU. device() # remember the current device
10487 dev = MLDataDevices. get_device(x)
10588 if ! (dev isa AMDGPUDevice)
106- AMDGPU. device!(to. device)
89+ to . device != = nothing && AMDGPU. device!(to. device)
10790 x_new = amdgpu_array_adapt(to, x)
108- AMDGPU. device!(old_dev)
91+ to . device != = nothing && AMDGPU. device!(old_dev)
10992 return x_new
110- elseif AMDGPU. device_id(dev. device) == AMDGPU. device_id(to. device)
93+ elseif (
94+ dev. device === nothing ||
95+ to. device === nothing ||
96+ AMDGPU. device_id(dev. device) == AMDGPU. device_id(to. device)
97+ )
11198 return x
11299 else
113- AMDGPU. device!(to. device)
100+ to . device != = nothing && AMDGPU. device!(to. device)
114101 x_new = copy(x)
115- AMDGPU. device!(old_dev)
102+ to . device != = nothing && AMDGPU. device!(old_dev)
116103 return x_new
117104 end
118105end
0 commit comments