10 changes: 7 additions & 3 deletions src/device/intrinsics/math.jl
@@ -365,9 +365,13 @@ end
 @static if Base.thismajor(LLVM.version()) <= v"18"
     # LLVM 18 and below generate non-existing instructions for Julia's default methods of
     # fast min/max on fp64: https://github.com/JuliaGPU/CUDA.jl/issues/2886
-    @device_override @inline Base.FastMath.max_fast(x::Float64, y::Float64) = ifelse(y > x, y, x)
-    @device_override @inline Base.FastMath.min_fast(x::Float64, y::Float64) = ifelse(y > x, x, y)
-    @device_override @inline Base.FastMath.minmax_fast(x::Float64, y::Float64) = ifelse(y > x, (x, y), (y, x))
+    for T in (Float16, Float32, Float64)
+        @eval begin
+            @device_override @inline Base.FastMath.max_fast(x::$T, y::$T) = ifelse(y > x, y, x)
+            @device_override @inline Base.FastMath.min_fast(x::$T, y::$T) = ifelse(y > x, x, y)
+            @device_override @inline Base.FastMath.minmax_fast(x::$T, y::$T) = ifelse(y > x, (x, y), (y, x))
+        end
Comment on lines +369 to +373
Contributor

Can you do something like

@device_override @inline Base.FastMath.max_fast(x::T, y::T) where {T<:Union{Float16, Float32, Float64}} = ifelse(y > x, y, x)

just to avoid the loop?

Member Author

I'm always wary of doing so, because the Base method may then end up being more specific than the override (and we really want these overrides to apply). In this case, Base doesn't use metaprogramming, so I guess it could work.
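
The specificity concern can be sketched in plain Julia, without CUDA.jl or overlay method tables. This only illustrates ordinary dispatch; the function name `pick` is hypothetical and stands in for any function that has both a concretely typed method and a `Union`-bounded parametric one. How this interacts with `@device_override` is not shown here.

```julia
# Hypothetical stand-in for a method defined with a concrete signature.
pick(x::Float64, y::Float64) = :concrete

# Parametric method covering the same types through a Union bound.
pick(x::T, y::T) where {T<:Union{Float16, Float32, Float64}} = :parametric

# Dispatch selects the most specific applicable method.
pick(1.0, 2.0)    # :concrete, the Float64 signature is more specific
pick(1f0, 2f0)    # :parametric, only the Union-bounded method applies
```

When both methods live in the same method table, the concrete signature wins for `Float64` arguments, which is why a per-type loop gives overrides that cannot be outranked by a more specific definition.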

+    end
 end

@device_function saturate(x::Float32) = ccall("extern __nv_saturatef", llvmcall, Cfloat, (Cfloat,), x)