Skip to content

Commit f3fc36c

Browse files
authored
feat: support track numbers via reactant device API (#1533)
1 parent f9fd7bf commit f3fc36c

File tree

4 files changed

+74
-22
lines changed

4 files changed

+74
-22
lines changed

lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.14.0"
4+
version = "1.15.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

lib/MLDataDevices/ext/MLDataDevicesReactantExt.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,19 @@ function device_to_kwargs(dev::ReactantDevice, x)
3434
dev.device === missing || (kwargs = (; kwargs..., device=dev.device))
3535
if dev.sharding !== missing
3636
if dev.sharding isa IdDict
37-
sharding = (
38-
haskey(dev.sharding, x) ? dev.sharding[x] : Reactant.Sharding.NoSharding()
39-
)
37+
if haskey(dev.sharding, x)
38+
sharding = dev.sharding[x]
39+
else
40+
if all(x -> x isa Reactant.Sharding.NoSharding, values(dev.sharding))
41+
sharding = Reactant.Sharding.NoSharding()
42+
else
43+
meshes = unique([
44+
getfield(sharding, :mesh) for sharding in values(dev.sharding)
45+
])
46+
@assert length(meshes) == 1 "Multiple meshes are not supported."
47+
sharding = Reactant.Sharding.Replicated(only(meshes))
48+
end
49+
end
4050
@assert sharding isa Reactant.Sharding.AbstractSharding
4151
kwargs = (; kwargs..., sharding)
4252
elseif dev.sharding isa Reactant.Sharding.AbstractSharding
@@ -140,6 +150,12 @@ Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
140150
)
141151
return rng
142152
end
153+
Profiler.@annotate "Device Transfer (Reactant)" function Adapt.adapt_storage(
154+
dev::ReactantDevice{C,D,S,T,TN}, x::Number
155+
) where {C,D,S,T,TN}
156+
typeof(x) <: TN && return ConcreteRNumber(x; device_to_kwargs(dev, x)...)
157+
return x
158+
end
143159

144160
function Adapt.adapt_storage(
145161
::CPUDevice,

lib/MLDataDevices/src/public.jl

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,28 @@ MetalDevice() = MetalDevice{Missing}()
2121
struct oneAPIDevice{T<:EltypeAdaptorType} <: AbstractGPUDevice end
2222
oneAPIDevice() = oneAPIDevice{Missing}()
2323

24-
struct ReactantDevice{C,D,S,T<:EltypeAdaptorType} <: AbstractAcceleratorDevice
24+
struct ReactantDevice{C,D,S,T<:EltypeAdaptorType,TN} <: AbstractAcceleratorDevice
2525
client::C
2626
device::D
2727
sharding::S
2828
end
2929
function ReactantDevice()
30-
return ReactantDevice{Missing,Missing,Missing,Missing}(missing, missing, missing)
30+
return ReactantDevice{Missing,Missing,Missing,Missing,Union{}}(
31+
missing, missing, missing
32+
)
3133
end
32-
function ReactantDevice(client, device, sharding)
33-
return ReactantDevice{typeof(client),typeof(device),typeof(sharding),Missing}(
34+
function ReactantDevice(client, device, sharding, _::Type{TN}=Union{}) where {TN}
35+
return ReactantDevice{typeof(client),typeof(device),typeof(sharding),Missing,TN}(
3436
client, device, sharding
3537
)
3638
end
3739

40+
function with_track_numbers(
41+
dev::ReactantDevice{C,D,S,T,Union{}}, _::Type{TN}
42+
) where {C,D,S,T,TN}
43+
return ReactantDevice{C,D,S,T,TN}(dev.client, dev.device, dev.sharding)
44+
end
45+
3846
# Helper functions to get the eltype from device types
3947
Base.eltype(::CPUDevice{T}) where {T} = T
4048
Base.eltype(::CUDADevice{D,T}) where {D,T} = T
@@ -70,30 +78,38 @@ function with_eltype(::oneAPIDevice, ::Type{T}) where {T<:AbstractFloat}
7078
return oneAPIDevice{T}()
7179
end
7280

73-
function with_eltype(dev::ReactantDevice{C,D,S}, ::Missing) where {C,D,S}
74-
return ReactantDevice{C,D,S,Missing}(dev.client, dev.device, dev.sharding)
81+
function with_eltype(dev::ReactantDevice{C,D,S,<:Any,TN}, ::Missing) where {C,D,S,TN}
82+
return ReactantDevice{C,D,S,Missing,TN}(dev.client, dev.device, dev.sharding)
7583
end
76-
function with_eltype(dev::ReactantDevice{C,D,S}, ::Nothing) where {C,D,S}
77-
return ReactantDevice{C,D,S,Nothing}(dev.client, dev.device, dev.sharding)
84+
function with_eltype(dev::ReactantDevice{C,D,S,<:Any,TN}, ::Nothing) where {C,D,S,TN}
85+
return ReactantDevice{C,D,S,Nothing,TN}(dev.client, dev.device, dev.sharding)
7886
end
79-
function with_eltype(dev::ReactantDevice{C,D,S}, ::Type{T}) where {C,D,S,T<:AbstractFloat}
80-
return ReactantDevice{C,D,S,T}(dev.client, dev.device, dev.sharding)
87+
function with_eltype(
88+
dev::ReactantDevice{C,D,S,<:Any,TN}, ::Type{T}
89+
) where {C,D,S,TN,T<:AbstractFloat}
90+
return ReactantDevice{C,D,S,T,TN}(dev.client, dev.device, dev.sharding)
8191
end
8292

8393
function Base.:(==)(
84-
x::ReactantDevice{<:Any,<:Any,<:Any,T1}, y::ReactantDevice{<:Any,<:Any,<:Any,T2}
85-
) where {T1,T2}
94+
x::ReactantDevice{<:Any,<:Any,<:Any,T1,TN1}, y::ReactantDevice{<:Any,<:Any,<:Any,T2,TN2}
95+
) where {T1,T2,TN1,TN2}
8696
if x.client !== missing && y.client !== missing && x.client.client != y.client.client
8797
return false
8898
end
8999

90-
if x.device !== missing && y.device !== missing && x.device.device != y.device.device
100+
if (
101+
x.device !== missing &&
102+
x.device !== nothing && # can be nothing if objects are sharded
103+
y.device !== missing &&
104+
y.device !== nothing && # can be nothing if objects are sharded
105+
x.device.device != y.device.device
106+
)
91107
return false
92108
end
93109

94110
T1 === Missing && return T2 === Missing || T2 === Nothing
95111
T2 === Missing && return T1 === Missing || T1 === Nothing
96-
return T1 === T2
112+
return T1 === T2 && TN1 === TN2
97113
end
98114

99115
# XXX: Deprecate in v2
@@ -311,7 +327,8 @@ cpu_device(eltype::T=missing) where {T} = with_eltype(CPUDevice(), eltype)
311327

312328
"""
313329
reactant_device(;
314-
force::Bool=false, client=missing, device=missing, sharding=missing, eltype=missing
330+
force::Bool=false, client=missing, device=missing, sharding=missing, eltype=missing,
331+
track_numbers::Type{TN}=Union{}
315332
) -> Union{ReactantDevice, CPUDevice}
316333
317334
Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is
@@ -324,19 +341,28 @@ specified, then the default client and index are used.
324341
`Reactant.Sharding.AbstractSharding` is specified, then we use it to shard all abstract
325342
arrays. Alternatively, pass in a `IdDict` to specify the sharding for specific leaves.
326343
344+
`track_numbers` can be specified to convert numbers of specified subtypes to be traced.
345+
327346
The `eltype` parameter controls element type conversion:
328347
329348
- `missing/nothing` (default): Preserves the original element type
330349
- `Type{<:AbstractFloat}`: Converts floating-point arrays to the specified type
331350
"""
332351
function reactant_device(
333-
eltype::T=missing; force::Bool=false, client=missing, device=missing, sharding=missing
334-
) where {T}
352+
eltype::T=missing;
353+
force::Bool=false,
354+
client=missing,
355+
device=missing,
356+
sharding=missing,
357+
track_numbers::Type{TN}=Union{},
358+
) where {T,TN}
335359
msg = "`ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
336360
calling this function. Defaulting to CPU."
337361
if loaded(ReactantDevice)
338362
if functional(ReactantDevice)
339-
return with_eltype(ReactantDevice(client, device, sharding), eltype)
363+
return with_track_numbers(
364+
with_eltype(ReactantDevice(client, device, sharding), eltype), track_numbers
365+
)
340366
end
341367
msg = "`ReactantDevice` is loaded but not functional. Defaulting to CPU."
342368
end

lib/MLDataDevices/test/reactant_tests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,13 @@ end
184184
@test dev(rng) isa Reactant.ReactantRNG
185185
end
186186
end
187+
188+
@testset "Track Numbers" begin
189+
if MLDataDevices.functional(ReactantDevice)
190+
dev = reactant_device(; track_numbers=Float32)
191+
x = dev(2.0f0)
192+
@test x isa ConcreteRNumber{Float32}
193+
x = dev(2.0)
194+
@test x isa Float64
195+
end
196+
end

0 commit comments

Comments
 (0)