Skip to content

Commit 835eef4

Browse files
authored
Merge pull request #14 from s-ccs/fix-studentt
fix studentt
2 parents 8ed0ce8 + 983577d commit 835eef4

File tree

4 files changed

+304
-29
lines changed

4 files changed

+304
-29
lines changed

src/utils.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function studentt!(out::AbstractMatrix, x::AbstractArray{<:Real,3}; kwargs...)
22

3-
for (x_ch, o_ch) in zip(eachslice(x, dims = 1), eachslice(out, dims = 1))
3+
for (x_ch, o_ch) in zip(eachslice(x, dims=1), eachslice(out, dims=1))
44
#@debug size(x_ch),size(o_ch)
55
studentt!(o_ch, x_ch; kwargs...)
66
end
@@ -44,17 +44,17 @@ end
4444
function studentt!(out, x)
4545
#@debug size(out),size(x)
4646
mean!(out, x)
47-
out .= out ./ (std(x, mean = out, dims = 2)[:, 1] ./ sqrt(size(x, 2) - 1))
47+
out .= out ./ (std(x, mean=out, dims=2, corrected=true)[:, 1] ./ sqrt(size(x, 2)))
4848
end
4949
function studentt(x::AbstractMatrix)
5050
# more efficient than this one liner
5151
# studentt(x::AbstractMatrix) = (mean(x,dims=2)[:,1])./(std(x,dims=2)[:,1]./sqrt(size(x,2)-1))
52-
μ = mean(x, dims = 2)[:, 1]
53-
μ .= μ ./ (std(x, mean = μ, dims = 2)[:, 1] ./ sqrt(size(x, 2) - 1))
52+
μ = mean(x, dims=2)[:, 1]
53+
μ .= μ ./ (std(x, mean=μ, dims=2, corrected=true)[:, 1] ./ sqrt(size(x, 2)))
5454
end
5555

5656
studentt(x::AbstractArray{<:Real,3}) =
57-
dropdims(mapslices(studentt, x, dims = (2, 3)), dims = 3)
57+
dropdims(mapslices(studentt, x, dims=(2, 3)), dims=3)
5858

5959
"""
6060
Permutation via random sign-flip
@@ -66,9 +66,26 @@ function sign_permute!(rng, x::AbstractArray)
6666

6767
fl = rand(rng, [-1, 1], size(x, n))
6868

69-
for (flip, xslice) in zip(fl, eachslice(x; dims = n))
69+
for (flip, xslice) in zip(fl, eachslice(x; dims=n))
7070
xslice .= xslice .* flip
7171
end
7272

7373
return x
7474
end
75+
76+
function batch_unpaired_ttest_unequal_var(x::AbstractArray, group)
77+
x_reshaped = reshape(x, :, size(x, ndims(x)))
78+
x₁ = x_reshaped[:, group]
79+
x₂ = x_reshaped[:, .!group]
80+
81+
n₁, n₂ = size(x₁, 2), size(x₂, 2)
82+
μ₁, μ₂ = mean(x₁, dims=2), mean(x₂, dims=2)
83+
var₁, var₂ = var(x₁, dims=2, corrected=true), var(x₂, dims=2, corrected=true)
84+
85+
se = sqrt.(var₁ ./ n₁ .+ var₂ ./ n₂)
86+
t_stat = (μ₁ .- μ₂) ./ se
87+
88+
# Reshape t_stat to match input dimensions (excluding last dim)
89+
t_stat_reshaped = reshape(t_stat, size(x)[1:end-1])
90+
return t_stat_reshaped
91+
end

0 commit comments

Comments
 (0)