diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 0b3e8e0..ef95afe 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -357,14 +357,14 @@ Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler -"""Sample a BFloat16 from [0,1) by setting random mantissa -bits for one(BFloat16) to obtain [1,2) (where floats are uniformly -distributed) then subtract 1 for [0,1).""" +"""Sample a BFloat16 from [0,1) by creating a random integer +in 0, ... 255 and then scaling it into [0,1). This samples +from every BFloat16 in [1/2, 1) but only from every other +in [1/4, 1/2), and every forth in [1/8, 1/4), etc. for a +uniform distribution.""" function rand(rng::AbstractRNG, ::Sampler{BFloat16}) - u = reinterpret(UInt16, one(BFloat16)) - # shift random bits into BFloat16 mantissa (1 sign + 8 exp bits = 9) - u |= rand(rng, UInt16) >> 9 # u in [1,2) - return reinterpret(BFloat16, u) - one(BFloat16) # -1 for [0,1) + # 0x1p-8 is 2^-8 = eps(BFloat16)/2 + return rand(rng, UInt8) * BFloat16(0x1p-8) end randn(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randn(rng)) diff --git a/test/runtests.jl b/test/runtests.jl index 02e4a1a..58a9ef7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -184,10 +184,8 @@ end # zero should be the lowest BFloat16 sampled @test mi === zero(BFloat16) - # prevfloat(one(BFloat16)) cannot be sampled bc - # prevfloat(BFloat16(2)) - 1 is _two_ before one(BFloat16) - # (a statistical flaw of the [1,2)-1 sampling) - @test ma === prevfloat(one(BFloat16), 2) + # prevfloat(one(BFloat16)) should be maximum + @test ma === prevfloat(one(BFloat16), 1) end include("structure.jl")