diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl
index a1d589721d..742058006f 100644
--- a/src/device/intrinsics/math.jl
+++ b/src/device/intrinsics/math.jl
@@ -2,7 +2,6 @@
 using Base: FastMath
 
-
 ## helpers
 
 within(lower, upper) = (val) -> lower <= val <= upper
 
@@ -83,9 +82,13 @@ end
 
 @device_override Base.tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x)
-# TODO: enable once PTX > 7.0 is supported
-# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
-
+@device_override function Base.tanh(x::Float16)
+    if compute_capability() >= sv"7.5"
+        @asmcall("tanh.approx.f16 \$0, \$1;", "=h,h", Float16, Tuple{Float16}, x)
+    else
+        Float16(tanh(Float32(x)))
+    end
+end
 
 ## inverse hyperbolic
 
@@ -103,10 +106,39 @@ end
 
 @device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
+@device_override function Base.log(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = @fastmath log(f)
+    r = Float16(f)
+
+    # handle degenerate cases
+    r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)
+
+    return r
+end
+
 @device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)
 
 @device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
+@device_override function Base.log10(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = @fastmath log10(f)
+    r = Float16(f)
+
+    # handle degenerate cases
+    r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)
+
+    return r
+end
 @device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
 
 @device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
@@ -114,6 +146,18 @@ end
 
 @device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
+@device_override function Base.log2(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = @fastmath log2(f)
+    r = Float16(f)
+
+    # handle degenerate cases
+    r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
+    r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)
+
+    return r
+end
 @device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)
 
 @device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
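Note on the fixup pattern used throughout these overrides: `Float16(h == pattern)` evaluates to `1.0` on a bit-exact match and `0.0` otherwise, so each `fma` applies its correction to exactly one input, without branching. A minimal CPU-side sketch using the first `log` fixup from this hunk (the interpretation of the constants is mine, not stated in the patch):

```julia
h = reinterpret(Float16, 0x160D)   # ≈ 0.0014772, an input whose round-tripped log is off
r = Float16(log(Float32(h)))       # compute in the Float32 domain, round back to Float16
# Float16(h == x) is one(Float16) on a match and zero(Float16) otherwise, so the
# fma adds the correction (0x9C00 == Float16(-0.00390625), one ulp at this magnitude)
# only for this particular input:
r = fma(Float16(h == reinterpret(Float16, 0x160D)),
        reinterpret(Float16, 0x9C00), r)
```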
@@ -127,16 +171,65 @@ end
 
 @device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
+@device_override function Base.exp(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = fma(f, log2(Float32(ℯ)), -0.0f0)
+    f = @fastmath exp2(f)
+    r = Float16(f)
+
+    # handle degenerate cases
+    r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
+    r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
+    r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)
+
+    return r
+end
 @device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
 
 @device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
-@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
-# TODO: enable once PTX > 7.0 is supported
-# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
+@device_override function Base.exp2(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = @fastmath exp2(f)
+
+    # one-ULP adjustment
+    f = muladd(f, reinterpret(Float32, 0x33800000), f)
+    r = Float16(f)
+
+    return r
+end
+@device_override FastMath.exp2_fast(x::Float64) = exp2(x)
+@device_override FastMath.exp2_fast(x::Float32) =
+    @asmcall("ex2.approx.f32 \$0, \$1;", "=f,f", Float32, Tuple{Float32}, x)
+@device_override function FastMath.exp2_fast(x::Float16)
+    if compute_capability() >= sv"7.5"
+        ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x)
+    else
+        exp2(x)
+    end
+end
 
 @device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
 @device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
+@device_override function Base.exp10(h::Float16)
+    # perform computation in Float32 domain
+    f = Float32(h)
+    f = fma(f, log2(10.0f0), -0.0f0)
+    f = @fastmath exp2(f)
+    r = Float16(f)
+
+    # handle degenerate cases
+    r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
+    r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
+    r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
+    r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)
+
+    return r
+end
 @device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)
 
 @device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +297,7 @@ end
 
 @device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
 @device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
+@device_override Base.isnan(x::Float16) = isnan(Float32(x))
 
 @device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
 @device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
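Callers are unaffected by the `compute_capability()` branch: `@fastmath` rewrites `exp2` to `Base.FastMath.exp2_fast`, which the overrides above lower to a single `ex2.approx` instruction where the hardware supports it. A hypothetical kernel illustrating the dispatch (not part of this patch):

```julia
using CUDA

# hypothetical example: @fastmath turns exp2 into Base.FastMath.exp2_fast,
# which on supported devices compiles to ex2.approx rather than a libdevice call
function exp2_kernel(y, x)
    i = threadIdx().x
    @inbounds y[i] = @fastmath exp2(x[i])
    return
end

x = CUDA.rand(Float32, 32)
y = similar(x)
@cuda threads=32 exp2_kernel(y, x)
```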
__nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x) @@ -223,14 +317,14 @@ end @device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x) @device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f) @device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f) -# TODO: enable once PTX > 7.0 is supported -# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) +@device_override Base.abs(f::Float16) = Float16(abs(Float32(f))) @device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x) ## roots and powers @device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_override Base.sqrt(x::Float16) = Float16(sqrt(Float32(x))) @device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x) @device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index d9f868a132..6fa5768df9 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -2,7 +2,9 @@ using SpecialFunctions @testset "math" begin @testset "log10" begin - @test testf(a->log10.(a), Float32[100]) + for T in (Float32, Float64) + @test testf(a->log10.(a), T[100]) + end end @testset "pow" begin @@ -12,28 +14,34 @@ using SpecialFunctions @test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1)) end end + + @testset "min/max" begin + for T in (Float32, Float64) + @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1)) + @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1)) + end + end @testset "isinf" begin - for x in (Inf32, Inf, NaN32, NaN) + for x in (Inf32, Inf, NaN16, NaN32, NaN) @test testf(x->isinf.(x), [x]) end end @testset "isnan" begin - for x in (Inf32, Inf, NaN32, NaN) + for x in (Inf32, Inf, NaN16, NaN32, NaN) @test testf(x->isnan.(x), [x]) end end for op in (exp, angle, exp2, exp10,) @testset "$op" begin - for T in (Float16, Float32, Float64) + for T in (Float32, Float64) @test testf(x->op.(x), rand(T, 1)) @test testf(x->op.(x), -rand(T, 1)) end end end - for op in (expm1,) @testset "$op" begin # FIXME: add expm1(::Float16) to Base @@ -50,7 +58,6 @@ using SpecialFunctions @test testf(x->op.(x), rand(T, 1)) @test testf(x->op.(x), -rand(T, 1)) end - end end @testset "mod and rem" begin @@ -97,6 +104,21 @@ using SpecialFunctions # JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load @test testf(x->exp.(x), [1e7im]) end + + @testset "Real - $op" for op in (abs, abs2, exp, exp10, log, log10) + @testset "$T" for T in (Float16, Float32, Float64) + @test testf(x->op.(x), rand(T, 1)) + end + end + + @testset "Float16 - $op" for op in (exp,exp2,exp10,log,log2,log10) + all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) + all_float_16 = filter(!isnan, all_float_16) + if op in (log, log2, log10) + all_float_16 = filter(>=(0), all_float_16) + end + @test testf(x->map(op, x), all_float_16) + end @testset "fastmath" begin # libdevice provides some fast math functions