@@ -21,20 +21,28 @@ MetalDevice() = MetalDevice{Missing}()
2121struct oneAPIDevice{T<: EltypeAdaptorType } <: AbstractGPUDevice end
2222oneAPIDevice() = oneAPIDevice{Missing}()
2323
24- struct ReactantDevice{C,D,S,T<: EltypeAdaptorType } <: AbstractAcceleratorDevice
24+ struct ReactantDevice{C,D,S,T<: EltypeAdaptorType ,TN } <: AbstractAcceleratorDevice
2525 client:: C
2626 device:: D
2727 sharding:: S
2828end
2929function ReactantDevice()
30- return ReactantDevice{Missing,Missing,Missing,Missing}(missing , missing , missing )
30+ return ReactantDevice{Missing,Missing,Missing,Missing,Union{}}(
31+ missing , missing , missing
32+ )
3133end
32- function ReactantDevice(client, device, sharding)
33- return ReactantDevice{typeof(client),typeof(device),typeof(sharding),Missing}(
34+ function ReactantDevice(client, device, sharding, _ :: Type{TN} = Union{}) where {TN}
35+ return ReactantDevice{typeof(client),typeof(device),typeof(sharding),Missing,TN }(
3436 client, device, sharding
3537 )
3638end
3739
40+ function with_track_numbers(
41+ dev:: ReactantDevice{C,D,S,T,Union{}} , _:: Type{TN}
42+ ) where {C,D,S,T,TN}
43+ return ReactantDevice{C,D,S,T,TN}(dev. client, dev. device, dev. sharding)
44+ end
45+
3846# Helper functions to get the eltype from device types
3947Base. eltype(:: CPUDevice{T} ) where {T} = T
4048Base. eltype(:: CUDADevice{D,T} ) where {D,T} = T
@@ -70,30 +78,38 @@ function with_eltype(::oneAPIDevice, ::Type{T}) where {T<:AbstractFloat}
7078 return oneAPIDevice{T}()
7179end
7280
73- function with_eltype(dev:: ReactantDevice{C,D,S} , :: Missing ) where {C,D,S}
74- return ReactantDevice{C,D,S,Missing}(dev. client, dev. device, dev. sharding)
81+ function with_eltype(dev:: ReactantDevice{C,D,S,<:Any,TN } , :: Missing ) where {C,D,S,TN }
82+ return ReactantDevice{C,D,S,Missing,TN }(dev. client, dev. device, dev. sharding)
7583end
76- function with_eltype(dev:: ReactantDevice{C,D,S} , :: Nothing ) where {C,D,S}
77- return ReactantDevice{C,D,S,Nothing}(dev. client, dev. device, dev. sharding)
84+ function with_eltype(dev:: ReactantDevice{C,D,S,<:Any,TN } , :: Nothing ) where {C,D,S,TN }
85+ return ReactantDevice{C,D,S,Nothing,TN }(dev. client, dev. device, dev. sharding)
7886end
79- function with_eltype(dev:: ReactantDevice{C,D,S} , :: Type{T} ) where {C,D,S,T<: AbstractFloat }
80- return ReactantDevice{C,D,S,T}(dev. client, dev. device, dev. sharding)
87+ function with_eltype(
88+ dev:: ReactantDevice{C,D,S,<:Any,TN} , :: Type{T}
89+ ) where {C,D,S,TN,T<: AbstractFloat }
90+ return ReactantDevice{C,D,S,T,TN}(dev. client, dev. device, dev. sharding)
8191end
8292
8393function Base.:(== )(
84- x:: ReactantDevice{<:Any,<:Any,<:Any,T1} , y:: ReactantDevice{<:Any,<:Any,<:Any,T2}
85- ) where {T1,T2}
94+ x:: ReactantDevice{<:Any,<:Any,<:Any,T1,TN1 } , y:: ReactantDevice{<:Any,<:Any,<:Any,T2,TN2 }
95+ ) where {T1,T2,TN1,TN2 }
8696 if x. client != = missing && y. client != = missing && x. client. client != y. client. client
8797 return false
8898 end
8999
90- if x. device != = missing && y. device != = missing && x. device. device != y. device. device
100+ if (
101+ x. device != = missing &&
102+ x. device != = nothing && # can be nothing if objects are sharded
103+ y. device != = missing &&
104+ y. device != = nothing && # can be nothing if objects are sharded
105+ x. device. device != y. device. device
106+ )
91107 return false
92108 end
93109
94110 T1 === Missing && return T2 === Missing || T2 === Nothing
95111 T2 === Missing && return T1 === Missing || T1 === Nothing
96- return T1 === T2
112+ return T1 === T2 && TN1 === TN2
97113end
98114
99115# XXX : Deprecate in v2
@@ -311,7 +327,8 @@ cpu_device(eltype::T=missing) where {T} = with_eltype(CPUDevice(), eltype)
311327
312328"""
313329 reactant_device(;
314- force::Bool=false, client=missing, device=missing, sharding=missing, eltype=missing
330+ force::Bool=false, client=missing, device=missing, sharding=missing, eltype=missing,
331+ track_numbers::Type{TN}=Union{}
315332 ) -> Union{ReactantDevice, CPUDevice}
316333
317334Return a `ReactantDevice` object if functional. Otherwise, throw an error if `force` is
@@ -324,19 +341,28 @@ specified, then the default client and index are used.
324341`Reactant.Sharding.AbstractSharding` is specified, then we use it to shard all abstract
325342arrays. Alternatively, pass in a `IdDict` to specify the sharding for specific leaves.
326343
344+ `track_numbers` can be specified to convert numbers of specified subtypes to be traced.
345+
327346The `eltype` parameter controls element type conversion:
328347
329348 - `missing/nothing` (default): Preserves the original element type
330349 - `Type{<:AbstractFloat}`: Converts floating-point arrays to the specified type
331350"""
332351function reactant_device(
333- eltype:: T = missing ; force:: Bool = false , client= missing , device= missing , sharding= missing
334- ) where {T}
352+ eltype:: T = missing ;
353+ force:: Bool = false ,
354+ client= missing ,
355+ device= missing ,
356+ sharding= missing ,
357+ track_numbers:: Type{TN} = Union{},
358+ ) where {T,TN}
335359 msg = " `ReactantDevice` is not loaded or not functional. Load `Reactant.jl` before \
336360 calling this function. Defaulting to CPU."
337361 if loaded(ReactantDevice)
338362 if functional(ReactantDevice)
339- return with_eltype(ReactantDevice(client, device, sharding), eltype)
363+ return with_track_numbers(
364+ with_eltype(ReactantDevice(client, device, sharding), eltype), track_numbers
365+ )
340366 end
341367 msg = " `ReactantDevice` is loaded but not functional. Defaulting to CPU."
342368 end
0 commit comments