-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathadapt.jl
More file actions
56 lines (41 loc) · 1.53 KB
/
adapt.jl
File metadata and controls
56 lines (41 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import Adapt
"""
    Adapt.adapt_structure(::Type{<:AbstractArray}, context::AbstractCommsContext)

Adapt a given context to a context whose device is associated with the given
array type.

Only the device is adapted: `Adapt.adapt(to, device(context))` is computed and a
new context is then constructed from that device with `context(...)`. Whatever
`context(device)` returns therefore determines the kind of context produced.

# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.context(ClimaComms.CUDADevice()))
# -> ClimaComms.context(ClimaComms.CPUSingleThreaded())
```

!!! note
    By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
    there is currently no way to convert to a CPUMultiThreaded device.
"""
function Adapt.adapt_structure(to::Type{<:AbstractArray}, ctx::AbstractCommsContext)
    # Adapt the underlying device first, then rebuild a context around it.
    adapted_device = Adapt.adapt(to, device(ctx))
    return context(adapted_device)
end
"""
    Adapt.adapt_structure(::Type{<:AbstractArray}, device::AbstractDevice)

Return the device associated with the given array type.

This method does not inspect either argument: for any `AbstractArray` subtype
and any `AbstractDevice` it yields a fresh `CPUSingleThreaded()`. Methods with a
more specific array type, if defined elsewhere, take precedence through dispatch.

# Example
```julia
Adapt.adapt_structure(Array, ClimaComms.CUDADevice()) -> ClimaComms.CPUSingleThreaded()
```

!!! note
    By default, adapting to `Array` creates a `CPUSingleThreaded` device, and
    there is currently no way to convert to a CPUMultiThreaded device.
"""
function Adapt.adapt_structure(::Type{<:AbstractArray}, ::AbstractDevice)
    # The incoming device is irrelevant here; every array type maps to the
    # single-threaded CPU device.
    return CPUSingleThreaded()
end
"""
    adapt(device::AbstractDevice, x)

Adapt an object `x` to be compatible with the specified `device`.

`x` is converted with `Adapt.adapt` to the array type reported by
`array_type(device)`.
"""
adapt(device::AbstractDevice, x) = Adapt.adapt(array_type(device), x)
"""
    adapt(context::AbstractCommsContext, x)

Adapt an object `x` to be compatible with the device of the specified `context`.

The device is obtained with `device(context)`, and `x` is converted with
`Adapt.adapt` to the array type reported by `array_type` for that device.
"""
function adapt(context::AbstractCommsContext, x)
    # Only the context's device matters for choosing the target array type.
    ctx_device = device(context)
    return Adapt.adapt(array_type(ctx_device), x)
end