1- import Base . Broadcast : Broadcasted, Extruded, BroadcastStyle, ArrayStyle
1+ # broadcasting
22
3- BroadcastStyle( :: Type{<:CuArray} ) = ArrayStyle{CuArray}()
3+ using Base . Broadcast : BroadcastStyle, Broadcasted
44
5- function Base. similar(bc:: Broadcasted{ArrayStyle{CuArray}} , :: Type{T} ) where T
5+ struct CuArrayStyle{N} <: AbstractGPUArrayStyle{N} end
6+ CuArrayStyle(:: Val{N} ) where N = CuArrayStyle{N}()
7+ CuArrayStyle{M}(:: Val{N} ) where {N,M} = CuArrayStyle{N}()
8+
9+ BroadcastStyle(:: Type{<:CuArray{T,N}} ) where {T,N} = CuArrayStyle{N}()
10+
11+ Base. similar(bc:: Broadcasted{CuArrayStyle{N}} , :: Type{T} ) where {N,T} =
612 similar(CuArray{T}, axes(bc))
7- end
813
9- function Base. similar(bc:: Broadcasted{ArrayStyle{CuArray }} , :: Type{T} , dims... ) where {T}
10- similar( CuArray{T}, dims... )
11- end
14+ Base. similar(bc:: Broadcasted{CuArrayStyle{N }} , :: Type{T} , dims... ) where {N,T} =
15+ CuArray{T}(undef , dims... )
16+
1217
13- # replace base functions with libdevice alternatives
14- # TODO : do this with Cassette.jl
18+ # # replace base functions with libdevice alternatives
1519
1620cufunc(f) = f
1721cufunc(:: Type{T} ) where T = (x... ) -> T(x... ) # broadcasting type ctors isn't GPU compatible
1822
19- Broadcast. broadcasted(:: ArrayStyle{CuArray } , f, args... ) =
20- Broadcasted{ArrayStyle{CuArray }}(cufunc(f), args, nothing )
23+ Broadcast. broadcasted(:: CuArrayStyle{N } , f, args... ) where {N} =
24+ Broadcasted{CuArrayStyle{N }}(cufunc(f), args, nothing )
2125
22- libdevice = :[
26+ const libdevice = :[
2327 cos, cospi, sin, sinpi, tan, acos, asin, atan,
2428 cosh, sinh, tanh, acosh, asinh, atanh,
2529 log, log10, log1p, log2, logb, ilogb,
@@ -40,7 +44,8 @@ for f in libdevice
4044 @eval cufunc(:: typeof (Base.$ f)) = CUDAnative.$ f
4145end
4246
43- # broadcast ^
47+ # broadcast ^
48+
4449culiteral_pow(:: typeof (^ ), x:: T , :: Val{0} ) where {T<: Real } = one(x)
4550culiteral_pow(:: typeof (^ ), x:: T , :: Val{1} ) where {T<: Real } = x
4651culiteral_pow(:: typeof (^ ), x:: T , :: Val{2} ) where {T<: Real } = x * x
0 commit comments