diff --git a/src/dist/misc.jl b/src/dist/misc.jl index 938a0da9..33e0b5ee 100644 --- a/src/dist/misc.jl +++ b/src/dist/misc.jl @@ -15,6 +15,12 @@ tobits(x::Tuple) = frombits(x::Tuple, world) = map(v -> frombits(v, world), x) +tobits(x::Matrix) = + mapreduce(tobits, vcat, x) + +frombits(x::Matrix, world) = + map(v -> frombits(v, world), x) + Base.ifelse(cond::Dist{Bool}, then::Tuple, elze::Tuple) = Tuple(ifelse(cond, x, y) for (x, y) in zip(then,elze)) diff --git a/test/tuple_test.jl b/test/tuple_test.jl index 6cc07cb8..f95f0f0d 100644 --- a/test/tuple_test.jl +++ b/test/tuple_test.jl @@ -21,3 +21,18 @@ using Distributions end @test pr(cg)[(false, false, 3)] ≈ 0.5 + 0.5 * 0.8/2^3 end + +@testset "Probabilistic Matrix" begin + + x = [DistUInt{3}([false,false,flip(1.0/(i+j))]) for i=1:2, j=1:2] + @test getindex.(pr.(x), 1) ≈ [0.5 0.3333333333333333; 0.3333333333333333 0.25] + + # TODO next test is too slow, speed up dynamo + # y = @dice begin + # x*x + # end + + # pr(y)[[0 0; 0 0]] ≈ 0.333333 + +end +