Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests #157

Open
wants to merge 1 commit into
base: cs-dev
Choose a base branch
from
Open

tests #157

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/combinators/conditional.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# test/combinators/conditional.jl
using Test
using MeasureBase
using Random: MersenneTwister

@testset "Conditional" begin
# Create a simple conditional measure
base_measure = StdNormal()
condition(x) = abs(x) <= 2 # Only accept values in [-2, 2]

cond_measure = @inferred Conditional(base_measure, condition)

# Test basic properties
@test basemeasure(cond_measure) === base_measure

# Test sampling with rejection sampling
rng = MersenneTwister(123)
samples = [rand(rng, cond_measure) for _ in 1:100]
@test all(condition, samples)

# Test density
x = 1.0
@test logdensityof(cond_measure, x) ≈ logdensityof(base_measure, x)
@test logdensityof(cond_measure, 3.0) == -Inf # Outside condition

# Test support
@test insupport(cond_measure, 1.0)
@test !insupport(cond_measure, 3.0)
end
74 changes: 74 additions & 0 deletions test/combinators/half.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using Test
using MeasureBase
using Random: MersenneTwister
using LogExpFunctions: loghalf
using IrrationalConstants: log2π

@testset "Half" begin
rng = MersenneTwister(42)

@testset "Basic properties" begin
μ = Half(StdNormal())

# Test show method
@test sprint(show, μ) == "Half(StdNormal())"

# Test unhalf
@test unhalf(μ) === StdNormal()

# Test basemeasure
@test basemeasure(μ) isa WeightedMeasure
@test _logweight(basemeasure(μ)) ≈ log(2)
end

@testset "Sampling and density" begin
μ = Half(StdNormal())
n_samples = 1000
samples = [rand(rng, μ) for _ in 1:n_samples]

# All samples should be non-negative
@test all(x -> x ≥ 0, samples)

# Test density at specific points
x = 1.0
expected_log_density = -0.5 * (x^2 + log2π) - loghalf
@test logdensityof(μ, x) ≈ expected_log_density

# Test density at negative points
@test logdensityof(μ, -1.0) == -Inf

# Test density at zero
@test isfinite(logdensityof(μ, 0.0))
end

@testset "Transport" begin
μ = Half(StdNormal())

# Test transport to/from uniform
u = 0.7 # arbitrary point in (0,1)
x = transport_to(μ, StdUniform(), u)
@test x ≥ 0
@test transport_to(StdUniform(), μ, x) ≈ u

# Test edge cases
@test transport_to(μ, StdUniform(), 0.0) == 0.0
@test transport_to(μ, StdUniform(), 1.0) > 0
end

@testset "SMF (Standardized Measure Function)" begin
μ = Half(StdUniform())

# Test SMF properties
@test smf(μ, -1.0) == -1.0 # Below support
@test smf(μ, 0.0) == -1.0 # At lower bound
@test smf(μ, 0.5) == 0.0 # Midpoint
@test smf(μ, 1.0) == 1.0 # At upper bound

# Test inverse SMF
for p in [0.0, 0.25, 0.5, 0.75, 1.0]
x = invsmf(μ, p)
@test smf(μ, x) ≈ p
@test 0 ≤ x ≤ 1
end
end
end
54 changes: 54 additions & 0 deletions test/combinators/implicitlymapped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,57 @@ using AffineMaps, PropertyFunctions
obs,
)
end

using Test
using MeasureBase
using Static: static
using Random: MersenneTwister

@testset "TakeAny" begin
rng = MersenneTwister(42)

@testset "Basic properties" begin
take2 = TakeAny(2)
take_static2 = TakeAny(static(2))

# Test with various collection types
arr = [1,2,3,4,5]
@test length(take2(arr)) == 2
@test length(take_static2(arr)) == 2

# Test consistency
@test take2(arr) == take2(arr) # Same elements when called multiple times

# Test with different sized inputs
@test length(take2(1:10)) == 2
@test length(take2(1:1)) == 1 # Should handle cases where input is smaller than n
end

@testset "Implicit mapping with TakeAny" begin
# Create a kernel that produces a product measure
kernel = par -> StdNormal()^3

# Create mapped version that only looks at first two components
mapped_kernel = ImplicitlyMapped(kernel, TakeAny(2))

# Test with some parameter value
par = 1.0
full_measure = kernel(par)
mapped_measure = explicit_kernel(mapped_kernel, rand(rng, full_measure))(par)

# Check dimensions
@test getdof(mapped_measure) == 2
@test getdof(full_measure) == 3

# Test consistency of mapping
obs1 = rand(rng, full_measure)
obs2 = rand(rng, full_measure)

mapped1 = mapped_kernel.mapfunc(obs1)
mapped2 = mapped_kernel.mapfunc(obs2)

# Same elements should be selected consistently
@test length(mapped1) == 2
@test mapped_kernel.mapfunc(obs1) == mapped1 # Consistent mapping
end
end
48 changes: 48 additions & 0 deletions test/domains.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# test/domains.jl
using Test
using MeasureBase
using Static: static
using Random: MersenneTwister

@testset "Domains" begin
@testset "BoundedInts" begin
bounded = ℤ[1:5]
@test 3 ∈ bounded
@test -1 ∉ bounded
@test 6 ∉ bounded
@test 1.5 ∉ bounded
@test minimum(bounded) == 1
@test maximum(bounded) == 5
@test testvalue(bounded) == 0

# Test show method
@test sprint(show, bounded) == "ℤ[1:5]"
end

@testset "ZeroSet" begin
# Simple quadratic function and its gradient
f(x) = sum(x.^2)
∇f(x) = 2x
zs = ZeroSet(f, ∇f)

# Test points
@test zeros(3) ∈ zs
@test [1e-8, -1e-8, 1e-8] ∈ zs
@test [0.1, 0.1, 0.1] ∉ zs

# Test with different floating point types
@test zeros(Float32, 2) ∈ zs
@test zeros(Float64, 2) ∈ zs
end

@testset "IntegerNumbers" begin
@test minimum(ℤ) == static(-Inf)
@test maximum(ℤ) == static(Inf)

# Test membership
@test 42 ∈ ℤ
@test -42 ∈ ℤ
@test 3.14 ∉ ℤ
@test 2.0 ∈ ℤ # Integer-valued floats should be in ℤ
end
end
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using JET
# include("test_aqua.jl")

include("static.jl")
include("domains.jl")

include("test_primitive.jl")
include("test_standard.jl")
Expand All @@ -33,5 +34,6 @@ include("smf.jl")
include("combinators/weighted.jl")
include("combinators/transformedmeasure.jl")
include("combinators/implicitlymapped.jl")

include("combinators/conditional.jl")
include("combinators/half.jl")
include("test_docs.jl")
18 changes: 18 additions & 0 deletions test/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test
import MeasureBase

import Static
using StaticArrays
using Static: static
import FillArrays

Expand Down Expand Up @@ -32,3 +33,20 @@ import FillArrays
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt
@test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7)
end

@testset "maybestatic_size" begin
# Test regular array
arr = rand(MersenneTwister(123), 3, 4)
@test MeasureBase.maybestatic_size(arr) == (3, 4)

# Test static array
static_arr = SMatrix{2,2}([1 2; 3 4])
@test MeasureBase.maybestatic_size(static_arr) == (static(2), static(2))

# Test mixed static/dynamic array
mixed = zeros(static(2), 3) # Create a matrix with static first dimension
size_result = MeasureBase.maybestatic_size(mixed)
@test size_result[1] isa Static.StaticInt
@test size_result[2] isa Int
@test size_result == (static(2), 3)
end
Loading