From a61571b3ceb303395208c22fd03ef0347be02881 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 6 Feb 2025 16:39:21 -0500 Subject: [PATCH 01/14] Wrap and test some more Float16 intrinsics --- src/device/intrinsics/math.jl | 182 +++++++++++++++++++++++++++- test/core/device/intrinsics/math.jl | 31 ++++- 2 files changed, 202 insertions(+), 11 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index a1d589721d..19da168c7e 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -2,7 +2,6 @@ using Base: FastMath - ## helpers within(lower, upper) = (val) -> lower <= val <= upper @@ -103,10 +102,71 @@ end @device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.log(x::Float16) + log_x = @asmcall("""{.reg.b32 f, C; + .reg.b16 r,h; + mov.b16 h,\$1; + cvt.f32.f16 f,h; + lg2.approx.ftz.f32 f,f; + mov.b32 C, 0x3f317218U; + mul.f32 f,f,C; + cvt.rn.f16.f32 r,f; + .reg.b16 spc, ulp, p; + mov.b16 spc, 0X160DU; + mov.b16 ulp, 0x9C00U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0X3BFEU; + mov.b16 ulp, 0x8010U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0X3C0BU; + mov.b16 ulp, 0x8080U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0X6051U; + mov.b16 ulp, 0x1C00U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 \$0,r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return log_x +end + @device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.log10(x::Float16) + log_x = 
@asmcall("""{.reg.b16 h, r; + .reg.b32 f, C; + mov.b16 h, \$1; + cvt.f32.f16 f, h; + lg2.approx.ftz.f32 f, f; + mov.b32 C, 0x3E9A209BU; + mul.f32 f,f,C; + cvt.rn.f16.f32 r, f; + .reg.b16 spc, ulp, p; + mov.b16 spc, 0x338FU; + mov.b16 ulp, 0x1000U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0x33F8U; + mov.b16 ulp, 0x9000U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0x57E1U; + mov.b16 ulp, 0x9800U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0x719DU; + mov.b16 ulp, 0x9C00U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 \$0, r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return log_x +end @device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x) @@ -114,6 +174,26 @@ end @device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.log2(x::Float16) + log_x = @asmcall("""{.reg.b16 h, r; + .reg.b32 f; + mov.b16 h, \$1; + cvt.f32.f16 f, h; + lg2.approx.ftz.f32 f, f; + cvt.rn.f16.f32 r, f; + .reg.b16 spc, ulp, p; + mov.b16 spc, 0xA2E2U; + mov.b16 ulp, 0x8080U; + set.eq.f16.f16 p, r, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc, 0xBF46U; + mov.b16 ulp, 0x9400U; + set.eq.f16.f16 p, r, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 \$0, r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return log_x +end @device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x) @device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x) @@ -127,16 +207,95 @@ end @device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp(x::Float32) = 
ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.exp(x::Float16) + exp_x = @asmcall("""{ + .reg.b32 f, C, nZ; + .reg.b16 h,r; + mov.b16 h,\$1; + cvt.f32.f16 f,h; + mov.b32 C, 0x3fb8aa3bU; + mov.b32 nZ, 0x80000000U; + fma.rn.f32 f,f,C,nZ; + ex2.approx.ftz.f32 f,f; + cvt.rn.f16.f32 r,f; + .reg.b16 spc, ulp, p; + mov.b16 spc,0X1F79U; + mov.b16 ulp,0x9400U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0X25CFU; + mov.b16 ulp,0x9400U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0XC13BU; + mov.b16 ulp,0x0400U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0XC1EFU; + mov.b16 ulp,0x0200U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 \$0,r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return exp_x +end @device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.exp2(x::Float16) + exp_x = @asmcall("""{.reg.b32 f, ULP; + .reg.b16 r; + mov.b16 r,\$1; + cvt.f32.f16 f,r; + ex2.approx.ftz.f32 f,f; + mov.b32 ULP, 0x33800000U; + fma.rn.f32 f,f,ULP,f; + cvt.rn.f16.f32 r,f; + mov.b16 \$0,r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return exp_x +end @device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x) -# TODO: enable once PTX > 7.0 is supported -# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) @device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.exp10(x::Float16) + + exp_x = @asmcall("""{.reg.b16 h,r; + .reg.b32 f, 
C, nZ; + mov.b16 h, \$1; + cvt.f32.f16 f, h; + mov.b32 C, 0x40549A78U; + mov.b32 nZ, 0x80000000U; + fma.rn.f32 f,f,C,nZ; + ex2.approx.ftz.f32 f, f; + cvt.rn.f16.f32 r, f; + .reg.b16 spc, ulp, p; + mov.b16 spc,0x34DEU; + mov.b16 ulp,0x9800U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0x9766U; + mov.b16 ulp,0x9000U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0x9972U; + mov.b16 ulp,0x1000U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0xA5C4U; + mov.b16 ulp,0x1000U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 spc,0xBF0AU; + mov.b16 ulp,0x8100U; + set.eq.f16.f16 p, h, spc; + fma.rn.f16 r,p,ulp,r; + mov.b16 \$0, r; + }""", "=h,h", Float16, Tuple{Float16}, x) + return exp_x +end @device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x) @@ -204,6 +363,13 @@ end @device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0 @device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 +@device_override function Base.isnan(x::Float16) + if compute_capability() >= sv"8.0" + return (ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)) != 0 + else + return isnan(Float32(x)) + end +end @device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x) @device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x) @@ -223,14 +389,20 @@ end @device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x) @device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f) @device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f) -# TODO: enable once PTX > 7.0 is 
supported -# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) +@device_override Base.abs(f::Float16) = Float16(abs(Float32(f))) @device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x) ## roots and powers @device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) +@device_override function Base.sqrt(x::Float16) + if compute_capability() >= sv"8.0" + ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x) + else + return Float16(sqrt(Float32(x))) + end +end @device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x) @device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index d9f868a132..474570a4f8 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -2,7 +2,9 @@ using SpecialFunctions @testset "math" begin @testset "log10" begin - @test testf(a->log10.(a), Float32[100]) + for T in (Float32, Float64) + @test testf(a->log10.(a), T[100]) + end end @testset "pow" begin @@ -12,28 +14,34 @@ using SpecialFunctions @test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1)) end end + + @testset "min/max" begin + for T in (Float32, Float64) + @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1)) + @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1)) + end + end @testset "isinf" begin - for x in (Inf32, Inf, NaN32, NaN) + for x in (Inf32, Inf, NaN16, NaN32, NaN) @test testf(x->isinf.(x), [x]) end end @testset "isnan" begin - for x in (Inf32, Inf, NaN32, NaN) + for x in (Inf32, Inf, NaN16, NaN32, NaN) @test testf(x->isnan.(x), [x]) end end for op in (exp, angle, exp2, exp10,) @testset "$op" begin - for T in (Float16, Float32, Float64) + for T 
in (Float32, Float64) @test testf(x->op.(x), rand(T, 1)) @test testf(x->op.(x), -rand(T, 1)) end end end - for op in (expm1,) @testset "$op" begin # FIXME: add expm1(::Float16) to Base @@ -50,7 +58,6 @@ using SpecialFunctions @test testf(x->op.(x), rand(T, 1)) @test testf(x->op.(x), -rand(T, 1)) end - end end @testset "mod and rem" begin @@ -97,6 +104,18 @@ using SpecialFunctions # JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load @test testf(x->exp.(x), [1e7im]) end + + @testset "Real - $op" for op in (exp, abs, abs2, exp10, log10) + @testset "$T" for T in (Float16, Float32, Float64) + @test testf(x->op.(x), rand(T, 1)) + end + end + + @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) + @testset "$T" for T in (Float16, ) + @test testf(x->op.(x), rand(T, 1)) + end + end @testset "fastmath" begin # libdevice provides some fast math functions From f50e87bbc2599af5e40d612b501fa24326e630e8 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 10 Feb 2025 14:15:46 -0500 Subject: [PATCH 02/14] Fix bad sqrt override --- src/device/intrinsics/math.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 19da168c7e..484e8affbe 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -396,13 +396,7 @@ end @device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.sqrt(x::Float16) - if compute_capability() >= sv"8.0" - ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x) - else - return Float16(sqrt(Float32(x))) - end -end +@device_override function Base.sqrt(x::Float16) = Float16(sqrt(Float32(x))) @device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x) @device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, 
Cdouble, (Cdouble,), x) From 4608a2c3bec9bfd7b62ee36dd383483d845aeb49 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 10 Feb 2025 15:18:20 -0500 Subject: [PATCH 03/14] Actually fix --- src/device/intrinsics/math.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 484e8affbe..bc0fa7c2d2 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -363,13 +363,7 @@ end @device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0 @device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0 -@device_override function Base.isnan(x::Float16) - if compute_capability() >= sv"8.0" - return (ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)) != 0 - else - return isnan(Float32(x)) - end -end +@device_override Base.isnan(x::Float16) = isnan(Float32(x)) @device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x) @device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x) @@ -396,7 +390,7 @@ end @device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.sqrt(x::Float16) = Float16(sqrt(Float32(x))) +@device_override Base.sqrt(x::Float16) = Float16(sqrt(Float32(x))) @device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x) @device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x) From c6393f450eaa1cb2d0197f4184063c27ee7b32d6 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 11 Feb 2025 12:26:14 +0100 Subject: [PATCH 04/14] Replace inline assembly with native code. 
--- src/device/intrinsics/math.jl | 231 +++++++++++----------------------- 1 file changed, 76 insertions(+), 155 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index bc0fa7c2d2..15e8239777 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -102,70 +102,38 @@ end @device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.log(x::Float16) - log_x = @asmcall("""{.reg.b32 f, C; - .reg.b16 r,h; - mov.b16 h,\$1; - cvt.f32.f16 f,h; - lg2.approx.ftz.f32 f,f; - mov.b32 C, 0x3f317218U; - mul.f32 f,f,C; - cvt.rn.f16.f32 r,f; - .reg.b16 spc, ulp, p; - mov.b16 spc, 0X160DU; - mov.b16 ulp, 0x9C00U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0X3BFEU; - mov.b16 ulp, 0x8010U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0X3C0BU; - mov.b16 ulp, 0x8080U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0X6051U; - mov.b16 ulp, 0x1C00U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 \$0,r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return log_x +@device_override function Base.log(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = @fastmath log(f) + r = Float16(f) + + # handle degenerate cases + r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r) + r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r) + r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r) + r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r) + + return r end @device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall,
Cdouble, (Cdouble,), x) @device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.log10(x::Float16) - log_x = @asmcall("""{.reg.b16 h, r; - .reg.b32 f, C; - mov.b16 h, \$1; - cvt.f32.f16 f, h; - lg2.approx.ftz.f32 f, f; - mov.b32 C, 0x3E9A209BU; - mul.f32 f,f,C; - cvt.rn.f16.f32 r, f; - .reg.b16 spc, ulp, p; - mov.b16 spc, 0x338FU; - mov.b16 ulp, 0x1000U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0x33F8U; - mov.b16 ulp, 0x9000U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0x57E1U; - mov.b16 ulp, 0x9800U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0x719DU; - mov.b16 ulp, 0x9C00U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 \$0, r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return log_x +@device_override function Base.log10(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = @fastmath log10(f) + r = Float16(f) + + # handle degenerate cases + r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r) + r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r) + r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r) + r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r) + + return r end @device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x) @@ -174,25 +142,17 @@ end @device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.log2(x::Float16) - log_x = @asmcall("""{.reg.b16 h, r; - .reg.b32 f; - mov.b16 h, \$1; - cvt.f32.f16 f, h; - lg2.approx.ftz.f32 f, f; - cvt.rn.f16.f32 r, f; - .reg.b16 spc, ulp, p; - mov.b16 spc, 0xA2E2U; - mov.b16 
ulp, 0x8080U; - set.eq.f16.f16 p, r, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc, 0xBF46U; - mov.b16 ulp, 0x9400U; - set.eq.f16.f16 p, r, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 \$0, r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return log_x +@device_override function Base.log2(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = @fastmath log2(f) + r = Float16(f) + + # handle degenerate cases + r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r) + r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r) + + return r end @device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x) @@ -207,94 +167,55 @@ end @device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.exp(x::Float16) - exp_x = @asmcall("""{ - .reg.b32 f, C, nZ; - .reg.b16 h,r; - mov.b16 h,\$1; - cvt.f32.f16 f,h; - mov.b32 C, 0x3fb8aa3bU; - mov.b32 nZ, 0x80000000U; - fma.rn.f32 f,f,C,nZ; - ex2.approx.ftz.f32 f,f; - cvt.rn.f16.f32 r,f; - .reg.b16 spc, ulp, p; - mov.b16 spc,0X1F79U; - mov.b16 ulp,0x9400U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0X25CFU; - mov.b16 ulp,0x9400U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0XC13BU; - mov.b16 ulp,0x0400U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0XC1EFU; - mov.b16 ulp,0x0200U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 \$0,r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return exp_x +@device_override function Base.exp(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32))) + f = @fastmath exp2(f) + r = Float16(f) + + # handle degenerate cases + r = 
fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r) + r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r) + r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r) + r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r) + + return r end @device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.exp2(x::Float16) - exp_x = @asmcall("""{.reg.b32 f, ULP; - .reg.b16 r; - mov.b16 r,\$1; - cvt.f32.f16 f,r; - ex2.approx.ftz.f32 f,f; - mov.b32 ULP, 0x33800000U; - fma.rn.f32 f,f,ULP,f; - cvt.rn.f16.f32 r,f; - mov.b16 \$0,r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return exp_x +@device_override function Base.exp2(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = @fastmath exp2(f) + + # one ULP adjustment + f = muladd(f, reinterpret(Float32, 0x33800000), f) + r = Float16(f) + + return r end @device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x) @device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x) -@device_override function Base.exp10(x::Float16) - - exp_x = @asmcall("""{.reg.b16 h,r; - .reg.b32 f, C, nZ; - mov.b16 h, \$1; - cvt.f32.f16 f, h; - mov.b32 C, 0x40549A78U; - mov.b32 nZ, 0x80000000U; - fma.rn.f32 f,f,C,nZ; - ex2.approx.ftz.f32 f, f; - cvt.rn.f16.f32 r, f; - .reg.b16 spc, ulp, p; - mov.b16 spc,0x34DEU; - mov.b16 ulp,0x9800U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0x9766U; - mov.b16 ulp,0x9000U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16
r,p,ulp,r; - mov.b16 spc,0x9972U; - mov.b16 ulp,0x1000U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0xA5C4U; - mov.b16 ulp,0x1000U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 spc,0xBF0AU; - mov.b16 ulp,0x8100U; - set.eq.f16.f16 p, h, spc; - fma.rn.f16 r,p,ulp,r; - mov.b16 \$0, r; - }""", "=h,h", Float16, Tuple{Float16}, x) - return exp_x +@device_override function Base.exp10(h::Float16) + # perform computation in Float32 domain + f = Float32(h) + f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32))) + f = @fastmath exp2(f) + r = Float16(f) + + # handle degenerate cases + r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r) + r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r) + r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r) + r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r) + r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r) + + return r end @device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x) From 3b68fc5ea33e5ae9e57a592fdbcf9e61eea658ab Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 11 Feb 2025 15:08:35 +0100 Subject: [PATCH 05/14] Try enabling some more intrinsics. 
--- src/device/intrinsics/math.jl | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 15e8239777..607e03d462 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -82,9 +82,13 @@ end @device_override Base.tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x) -# TODO: enable once PTX > 7.0 is supported -# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x) - +@device_override function Base.tanh(x::Float16) + if compute_capability() >= sv"7.5" + @asmcall("tanh.approx.f16 \$0, \$1;", "=r,r", Float16, Tuple{Float16}, x) + else + Float16(tanh(Float32(x))) + end +end ## inverse hyperbolic @@ -197,7 +201,16 @@ end return r end -@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x) +@device_override FastMath.exp2_fast(x::Float64) = exp2(x) +@device_override FastMath.exp2_fast(x::Float32) = + @asmcall("ex2.approx.f32 \$0, \$1;", "=r,r", Float32, Tuple{Float32}, x) +@device_override function FastMath.exp2_fast(x::Float16) + if compute_capability() >= sv"7.5" + ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x) + else + exp2(x) + end +end @device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x) @device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x) From 78d2b09fc68a504d11378d7e0d9e0a55cc6f05ee Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 26 Feb 2025 15:22:34 -0500 Subject: [PATCH 06/14] Replace magic constants and test all floats --- src/device/intrinsics/math.jl | 4 ++-- test/core/device/intrinsics/math.jl | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/device/intrinsics/math.jl 
b/src/device/intrinsics/math.jl index 607e03d462..360c8335de 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -174,7 +174,7 @@ end @device_override function Base.exp(h::Float16) # perform computation in Float32 domain f = Float32(h) - f = fma(f, reinterpret(Float32, 0x3fb8aa3b), reinterpret(Float32, Base.sign_mask(Float32))) + f = fma(f, reinterpret(Float32, reinterpret(UInt32, log2(Float32(ℯ)))), -0.0f0) f = @fastmath exp2(f) r = Float16(f) @@ -217,7 +217,7 @@ end @device_override function Base.exp10(h::Float16) # perform computation in Float32 domain f = Float32(h) - f = fma(f, reinterpret(Float32, 0x40549A78), reinterpret(Float32, Base.sign_mask(Float32))) + f = fma(f, reinterpret(Float32, reinterpret(UInt32, log2(10.f0))), reinterpret(Float32, Base.sign_mask(Float32))) f = @fastmath exp2(f) r = Float16(f) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index 474570a4f8..32b6626007 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -112,8 +112,9 @@ using SpecialFunctions end @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) - @testset "$T" for T in (Float16, ) - @test testf(x->op.(x), rand(T, 1)) + all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) + for each_float in all_float_16 + @test testf(x->op.(x), Float16[each_float]) end end From c724c334a8dd56ebaae57a38cbe171fd7bcf842e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 27 Feb 2025 11:36:31 -0500 Subject: [PATCH 07/14] Update test/core/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- test/core/device/intrinsics/math.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index 32b6626007..d0bbd42ab5 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -113,6 +113,7 @@ using SpecialFunctions 
@testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) + all_float_16 = filter(!isnan, all_float_16) for each_float in all_float_16 @test testf(x->op.(x), Float16[each_float]) end From 1b5296996e43f95965b40008e1421e288f91b4b6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 27 Feb 2025 11:40:57 -0500 Subject: [PATCH 08/14] Update src/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- src/device/intrinsics/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 360c8335de..8614ecdf11 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -217,7 +217,7 @@ end @device_override function Base.exp10(h::Float16) # perform computation in Float32 domain f = Float32(h) - f = fma(f, reinterpret(Float32, reinterpret(UInt32, log2(10.f0))), reinterpret(Float32, Base.sign_mask(Float32))) + f = fma(f, log2(10.f0), reinterpret(Float32, Base.sign_mask(Float32))) f = @fastmath exp2(f) r = Float16(f) From 552f947242ec93c831f2851b99c86e706501a07e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 27 Feb 2025 11:41:58 -0500 Subject: [PATCH 09/14] Use map to avoid spawning lots of kernels --- test/core/device/intrinsics/math.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index d0bbd42ab5..cb5008d246 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -114,9 +114,7 @@ using SpecialFunctions @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) all_float_16 = filter(!isnan, all_float_16) - for each_float in all_float_16 - @test testf(x->op.(x), Float16[each_float]) - end + @test
testf(x->map(op, x), all_float_16) end @testset "fastmath" begin From 15c9deefc435fa481616fae6ce817778a3989a5a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 27 Feb 2025 14:43:57 -0500 Subject: [PATCH 10/14] Update src/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- src/device/intrinsics/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 8614ecdf11..4d68e320cc 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -174,7 +174,7 @@ end @device_override function Base.exp(h::Float16) # perform computation in Float32 domain f = Float32(h) - f = fma(f, reinterpret(Float32, reinterpret(UInt32, log2(Float32(ℯ)))), -0.0f0) + f = fma(f, log2(Float32(ℯ)), -0.0f0) f = @fastmath exp2(f) r = Float16(f) From 1a1bdfb5113575efacd9fa70fb36a2eb6863b93e Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 1 Mar 2025 08:01:22 +0100 Subject: [PATCH 11/14] Update test/core/device/intrinsics/math.jl --- test/core/device/intrinsics/math.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index cb5008d246..ae1d6c376d 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -114,6 +114,9 @@ using SpecialFunctions @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) all_float_16 = filter(!isnan, all_float_16) + if op in (log, log2, log10) + all_float_16 = filter(>=(0), all_float_16) + end @test testf(x->map(op, x), all_float_16) end From 541cbf2218d25e409e1dc9e7ce8001d1d921d77c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 3 Mar 2025 09:05:31 -0500 Subject: [PATCH 12/14] Update test/core/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- test/core/device/intrinsics/math.jl | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index ae1d6c376d..e5f943cede 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -105,7 +105,7 @@ using SpecialFunctions @test testf(x->exp.(x), [1e7im]) end - @testset "Real - $op" for op in (exp, abs, abs2, exp10, log10) + @testset "Real - $op" for op in (abs, abs2, exp, exp10, log, log10) @testset "$T" for T in (Float16, Float32, Float64) @test testf(x->op.(x), rand(T, 1)) end From 6669b83602243dfe2887d64bd029a71466e7499f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 3 Mar 2025 09:05:42 -0500 Subject: [PATCH 13/14] Update test/core/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- test/core/device/intrinsics/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/core/device/intrinsics/math.jl b/test/core/device/intrinsics/math.jl index e5f943cede..6fa5768df9 100644 --- a/test/core/device/intrinsics/math.jl +++ b/test/core/device/intrinsics/math.jl @@ -111,7 +111,7 @@ using SpecialFunctions end end - @testset "Float16 - $op" for op in (log,exp,exp2,exp10,log2,log10) + @testset "Float16 - $op" for op in (exp,exp2,exp10,log,log2,log10) all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16)) all_float_16 = filter(!isnan, all_float_16) if op in (log, log2, log10) From e0354d0427c804c20cdc466e27e2585d280e5c73 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 3 Mar 2025 09:05:55 -0500 Subject: [PATCH 14/14] Update src/device/intrinsics/math.jl Co-authored-by: Valentin Churavy --- src/device/intrinsics/math.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index 4d68e320cc..742058006f 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -217,7 +217,7 @@ end @device_override function 
Base.exp10(h::Float16) # perform computation in Float32 domain f = Float32(h) - f = fma(f, log2(10.f0), reinterpret(Float32, Base.sign_mask(Float32))) + f = fma(f, log2(10.f0), -0.0f0) f = @fastmath exp2(f) r = Float16(f)