diff --git a/src/cluster.jl b/src/cluster.jl index e3689b9..127c77d 100755 --- a/src/cluster.jl +++ b/src/cluster.jl @@ -26,23 +26,23 @@ function clusterdepth( rng, data::AbstractArray; τ=2.3, - stat_type = :onesample_ttest, - perm_type = :sign, - side_type = :abs, - nperm = 5000, - pval_type = :troendle, - (statfun!) = nothing, - statfun = nothing, - permfun = nothing, + stat_type=:onesample_ttest, + perm_type=:sign, + side_type=:abs, + nperm=5000, + pval_type=:troendle, + (statfun!)=nothing, + statfun=nothing, + permfun=nothing, ) if stat_type == :onesample_ttest && isnothing(statfun!) && isnothing(statfun) statfun! = studentt! statfun = studentt end if perm_type == :sign - if isnothing(permfun) - permfun = sign_permute! - end + if isnothing(permfun) + permfun = sign_permute! + end end if side_type == :abs sidefun = abs @@ -51,11 +51,15 @@ function clusterdepth( elseif side_type == :negative sidefun = x -> -x elseif side_type == :positive - sidefun = nothing # the default :) + sidefun = x -> x # the default :) else @assert isnothing(side_type) "unknown side_type ($side_type) specified. Check your spelling and ?clusterdepth" end + data_obs = sidefun.(statfun(data)) + if any(data_obs[:, 1] .> τ) || any(data_obs[:, end] .> τ) + @warn "Your data shows a cluster that starts before the first sample, or ends after the last sample. There exists a fundamental limit in the ClusterDepth method, that the clusterdepth for such a cluster cannot be determined. Maybe you can extend the epoch to include more samples?" + end cdmTuple = perm_clusterdepths_both( rng, data, @@ -67,7 +71,7 @@ function clusterdepth( sidefun=sidefun, ) - return pvals(statfun(data), cdmTuple, τ; type=pval_type) + return pvals(data_obs, cdmTuple, τ; type=pval_type) end @@ -109,9 +113,9 @@ function perm_clusterdepths_both( # inplace! statfun!(d0, d_perm) end - if !isnothing(sidefun) - d0 .= sidefun.(d0) - end + + d0 .= sidefun.(d0) + # get clusterdepth (fromTo, head, tail) = calc_clusterdepth(d0, τ) diff --git a/src/utils.jl b/src/utils.jl index 16afa0a..27410b4 100755 --- a/src/utils.jl +++ b/src/utils.jl @@ -90,7 +90,7 @@ my_statfun = x->studentt_unpaired(x,grp) ``` """ -function studentt_unpaired(x::AbstractArray, group) +function studentt_unpaired(x, group) x_reshaped = reshape(x, :, size(x, ndims(x))) x₁ = x_reshaped[:, group] x₂ = x_reshaped[:, .!group] diff --git a/test/cluster.jl b/test/cluster.jl index 3dae67d..ad54fbe 100755 --- a/test/cluster.jl +++ b/test/cluster.jl @@ -16,14 +16,52 @@ ) @test s == [3, 5] @test l == [0, 1] + + s, l = ClusterDepth.cluster( + [4.0, 0.0, 10.0, 0.0, 3.0, 4.0, 0, 4.0, 4.0] .> 0.9, + ) end @testset "Tests for 2D data" begin data = randn(StableRNG(1), 4, 5) - @show ClusterDepth.calc_clusterdepth(data, 0) + res = ClusterDepth.clusterdepth(data; τ=0.4) + @test size(res) == (4,) end @testset "Tests for 3D data" begin data = randn(StableRNG(1), 3, 20, 5) - @show ClusterDepth.clusterdepth(data; τ=0.4, nperm=5) + res = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5) + @test size(res) == (3, 20) end +@testset "Test sidefun" begin + data = randn(StableRNG(1), 23, 20) + data[3:8, :] .+= 3 + data[12:17, :] .-= 3 + res = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5) + res_negated = ClusterDepth.clusterdepth(.-data; τ=0.4, nperm=5) + @test res ≈ res_negated # should be same if side_type=:abs + + + # testing the default is abs + res_abs = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:abs) + @test res ≈ res_abs + + + res_pos = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:positive) + @test all(res_pos[3:8] .< 0.8) + res_neg = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:negative) + @test all(res_neg[12:17] .< 0.8) + + +end +@testset "Test warning clusterbegin/end" begin + + # test the warning that clusters must not begin/end with a potentially significant cluster + data = randn(StableRNG(1), 23, 20) + data[1:5, :] .+= 3 + @test_warn x -> occursin("Your data shows a cluster", x) Warning ClusterDepth.clusterdepth(data; τ=0.4, nperm=5) + data = randn(StableRNG(1), 23, 20) + data[23, :] .-= 3 + @test_warn x -> occursin("Your data shows a cluster", x) Warning ClusterDepth.clusterdepth(data; τ=0.4, nperm=5) + +end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 8351f86..8d75d31 100755 --- a/test/utils.jl +++ b/test/utils.jl @@ -67,9 +67,9 @@ end x2 = _x[:, :, 2] x = hcat(x1, x2) group = repeat([false, true], inner=size(x1, 2)) - t = ClusterDepth.studentt_unpaired(x, group .== 1) + @benchmark t = ClusterDepth.studentt_unpaired(x, group) - t_true = [HypothesisTests.UnequalVarianceTTest(r[group], r[.!group]).t for r in eachrow(x)] + @benchmark t_true = [HypothesisTests.UnequalVarianceTTest(r[group], r[.!group]).t for r in eachrow(x)] @test all(t .≈ t_true) @test length(t) == 10000 @@ -80,5 +80,5 @@ end x = cat(x1, x2, dims=3) group = repeat([false, true], inner=size(x1, 3)) t = ClusterDepth.studentt_unpaired(x, group .== 1) - @test size(t) == (1000, 50) + @test size(t) == (10000, 50) end \ No newline at end of file