Skip to content

Commit 2eb1019

Browse files
committed
fix: preserve object when device is same
1 parent 59c0c69 commit 2eb1019

File tree

6 files changed

+23
-6
lines changed

6 files changed

+23
-6
lines changed

lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.6.3"
4+
version = "1.6.4"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/MLDataDevices/ext/MLDataDevicesAMDGPUExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,11 @@ function Internal.unsafe_free_internal!(::Type{AMDGPUDevice}, x::AbstractArray)
7777
end
7878

7979
# Device Transfer
80-
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
80+
function Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray)
81+
MLDataDevices.get_device_type(x) <: AMDGPUDevice && return x
82+
return AMDGPU.roc(x)
83+
end
84+
8185
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
8286
old_dev = AMDGPU.device() # remember the current device
8387
dev = MLDataDevices.get_device(x)

lib/MLDataDevices/ext/MLDataDevicesCUDAExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ function Internal.unsafe_free_internal!(::Type{CUDADevice}, x::AbstractArray)
5757
end
5858

5959
# Device Transfer
60-
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
60+
function Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray)
61+
MLDataDevices.get_device_type(x) <: CUDADevice && return x
62+
return CUDA.cu(x)
63+
end
6164

6265
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
6366
old_dev = CUDA.device() # remember the current device

lib/MLDataDevices/ext/MLDataDevicesMetalExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ function Internal.unsafe_free_internal!(::Type{MetalDevice}, x::AbstractArray)
2929
end
3030

3131
# Device Transfer
32-
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)
32+
function Adapt.adapt_storage(::MetalDevice, x::AbstractArray)
33+
MLDataDevices.get_device_type(x) <: MetalDevice && return x
34+
return Metal.mtl(x)
35+
end
3336

3437
end

lib/MLDataDevices/ext/MLDataDevicesoneAPIExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ end
4242
# Device Transfer
4343
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
4444
@eval function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray{$(T1)})
45+
MLDataDevices.get_device_type(x) <: oneAPIDevice && return x
4546
if !SUPPORTS_FP64[oneAPI.device()]
4647
@warn LazyString(
4748
"Double type is not supported on this device. Using `", $(T2), "` instead.")
@@ -50,6 +51,9 @@ for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
5051
return oneArray(x)
5152
end
5253
end
53-
Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray) = oneArray(x)
54+
function Adapt.adapt_storage(::oneAPIDevice, x::AbstractArray)
55+
MLDataDevices.get_device_type(x) <: oneAPIDevice && return x
56+
return oneArray(x)
57+
end
5458

5559
end

lib/MLDataDevices/src/public.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,10 @@ for op in (:get_device, :get_device_type)
375375
end
376376

377377
# Adapt Interface
378-
Adapt.adapt_storage(::CPUDevice, x::AbstractArray) = Array(x)
378+
function Adapt.adapt_storage(::CPUDevice, x::AbstractArray)
379+
get_device_type(x) <: CPUDevice && return x
380+
return Array(x)
381+
end
379382
Adapt.adapt_storage(to::AbstractDevice, ::Random.TaskLocalRNG) = default_device_rng(to)
380383
Adapt.adapt_storage(::AbstractDevice, rng::AbstractRNG) = rng
381384

0 commit comments

Comments
 (0)