Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions src/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ function clusterdepth(
rng,
data::AbstractArray;
τ=2.3,
stat_type = :onesample_ttest,
perm_type = :sign,
side_type = :abs,
nperm = 5000,
pval_type = :troendle,
(statfun!) = nothing,
statfun = nothing,
permfun = nothing,
stat_type=:onesample_ttest,
perm_type=:sign,
side_type=:abs,
nperm=5000,
pval_type=:troendle,
(statfun!)=nothing,
statfun=nothing,
permfun=nothing,
)
if stat_type == :onesample_ttest && isnothing(statfun!) && isnothing(statfun)
statfun! = studentt!
statfun = studentt
end
if perm_type == :sign
if isnothing(permfun)
permfun = sign_permute!
end
if isnothing(permfun)
permfun = sign_permute!
end
end
if side_type == :abs
sidefun = abs
Expand All @@ -51,11 +51,15 @@ function clusterdepth(
elseif side_type == :negative
sidefun = x -> -x
elseif side_type == :positive
sidefun = nothing # the default :)
sidefun = x -> x # the default :)
else
@assert isnothing(side_type) "unknown side_type ($side_type) specified. Check your spelling and ?clusterdepth"
end
data_obs = sidefun.(statfun(data))

if any(data_obs[:, 1] .> τ) || any(data_obs[:, end] .> τ)
@warn "Your data shows a cluster that starts before the first sample, or ends after the last sample. There exists a fundamental limit in the ClusterDepth method, that the clusterdepth for such a cluster cannot be determined. Maybe you can extend the epoch to include more samples?"
end
cdmTuple = perm_clusterdepths_both(
rng,
data,
Expand All @@ -67,7 +71,7 @@ function clusterdepth(
sidefun=sidefun,
)

return pvals(statfun(data), cdmTuple, τ; type=pval_type)
return pvals(data_obs, cdmTuple, τ; type=pval_type)
end


Expand Down Expand Up @@ -109,9 +113,9 @@ function perm_clusterdepths_both(
# inplace!
statfun!(d0, d_perm)
end
if !isnothing(sidefun)
d0 .= sidefun.(d0)
end

d0 .= sidefun.(d0)

# get clusterdepth
(fromTo, head, tail) = calc_clusterdepth(d0, τ)

Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ my_statfun = x->studentt_unpaired(x,grp)
```
"""
function studentt_unpaired(x::AbstractArray, group)
function studentt_unpaired(x, group)
x_reshaped = reshape(x, :, size(x, ndims(x)))
x₁ = x_reshaped[:, group]
x₂ = x_reshaped[:, .!group]
Expand Down
42 changes: 40 additions & 2 deletions test/cluster.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,52 @@
)
@test s == [3, 5]
@test l == [0, 1]

s, l = ClusterDepth.cluster(
[4.0, 0.0, 10.0, 0.0, 3.0, 4.0, 0, 4.0, 4.0] .> 0.9,
)
end

@testset "Tests for 2D data" begin
data = randn(StableRNG(1), 4, 5)
@show ClusterDepth.calc_clusterdepth(data, 0)
res = ClusterDepth.clusterdepth(data; τ=0.4)
@test size(res) == (4,)
end

@testset "Tests for 3D data" begin
data = randn(StableRNG(1), 3, 20, 5)
@show ClusterDepth.clusterdepth(data; τ=0.4, nperm=5)
res = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5)
@test size(res) == (3, 20)
end
@testset "Test sidefun" begin
data = randn(StableRNG(1), 23, 20)
data[3:8, :] .+= 3
data[12:17, :] .-= 3
res = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5)
res_negated = ClusterDepth.clusterdepth(.-data; τ=0.4, nperm=5)
@test res res_negated # should be same if side_type=:abs


# testing the default is abs
res_abs = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:abs)
@test res res_abs


res_pos = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:positive)
@test all(res_pos[3:8] .< 0.8)
res_neg = ClusterDepth.clusterdepth(data; τ=0.4, nperm=5, side_type=:negative)
@test all(res_neg[12:17] .< 0.8)


end
@testset "Test warning clusterbegin/end" begin

# test the warning that clusters must not begin/end with a potentially significant cluster
data = randn(StableRNG(1), 23, 20)
data[1:5, :] .+= 3
@test_warn x -> occursin("Your data shows a cluster", x) Warning ClusterDepth.clusterdepth(data; τ=0.4, nperm=5)
data = randn(StableRNG(1), 23, 20)
data[23, :] .-= 3
@test_warn x -> occursin("Your data shows a cluster", x) Warning ClusterDepth.clusterdepth(data; τ=0.4, nperm=5)

end
6 changes: 3 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ end
x2 = _x[:, :, 2]
x = hcat(x1, x2)
group = repeat([false, true], inner=size(x1, 2))
t = ClusterDepth.studentt_unpaired(x, group .== 1)
@benchmark t = ClusterDepth.studentt_unpaired(x, group)

t_true = [HypothesisTests.UnequalVarianceTTest(r[group], r[.!group]).t for r in eachrow(x)]
@benchmark t_true = [HypothesisTests.UnequalVarianceTTest(r[group], r[.!group]).t for r in eachrow(x)]
@test all(t .≈ t_true)
@test length(t) == 10000

Expand All @@ -80,5 +80,5 @@ end
x = cat(x1, x2, dims=3)
group = repeat([false, true], inner=size(x1, 3))
t = ClusterDepth.studentt_unpaired(x, group .== 1)
@test size(t) == (1000, 50)
@test size(t) == (10000, 50)
end
Loading