Skip to content

Commit bf110d6

Browse files
committed
eigen working
1 parent d97deaf commit bf110d6

File tree

2 files changed

+72
-74
lines changed

2 files changed

+72
-74
lines changed

src/MultiResponseVarianceComponentModels.jl

+27-25
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ export fit!,
2222
rg,
2323
permute
2424

25-
struct MRVCModel{T <: BlasReal}
25+
abstract type VCModel end
26+
27+
struct MRVCModel{T <: BlasReal} <: VCModel
2628
# data
2729
Y :: Matrix{T}
2830
X :: Matrix{T}
@@ -284,30 +286,7 @@ MRVCModel(Y, V::AbstractMatrix; kwargs...) = MRVCModel(Y, [V]; kwargs...)
284286

285287
const MultiResponseVarianceComponentModel = MRVCModel
286288

287-
function Base.show(io::IO, model::MRVCModel)
288-
if model.reml
289-
n, d, p, m = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2), length(model.V_reml)
290-
else
291-
n, d, p, m = size(model.Y, 1), size(model.Y, 2), size(model.X, 2), length(model.V)
292-
end
293-
if d == 1
294-
printstyled(io, "A univariate response variance component model\n"; underline = true)
295-
elseif d == 2
296-
printstyled(io, "A bivariate response variance component model\n"; underline = true)
297-
else
298-
printstyled(io, "A multivariate response variance component model\n"; underline = true)
299-
end
300-
print(io, " * number of responses: ")
301-
printstyled(io, "$d\n"; color = :yellow)
302-
print(io, " * number of observations: ")
303-
printstyled(io, "$n\n"; color = :yellow)
304-
print(io, " * number of fixed effects: ")
305-
printstyled(io, "$p\n"; color = :yellow)
306-
print(io, " * number of variance components: ")
307-
printstyled(io, "$m"; color = :yellow)
308-
end
309-
310-
struct MRTVCModel{T <: BlasReal}
289+
struct MRTVCModel{T <: BlasReal} <: VCModel
311290
# data
312291
Y :: Matrix{T}
313292
:: Matrix{T}
@@ -469,6 +448,29 @@ function MRTVCModel(
469448
)
470449
end
471450

451+
function Base.show(io::IO, model::VCModel)
452+
if model.reml
453+
n, d, p, m = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2), length(model.V_reml)
454+
else
455+
n, d, p, m = size(model.Y, 1), size(model.Y, 2), size(model.X, 2), length(model.V)
456+
end
457+
if d == 1
458+
printstyled(io, "A univariate response variance component model\n"; underline = true)
459+
elseif d == 2
460+
printstyled(io, "A bivariate response variance component model\n"; underline = true)
461+
else
462+
printstyled(io, "A multivariate response variance component model\n"; underline = true)
463+
end
464+
print(io, " * number of responses: ")
465+
printstyled(io, "$d\n"; color = :yellow)
466+
print(io, " * number of observations: ")
467+
printstyled(io, "$n\n"; color = :yellow)
468+
print(io, " * number of fixed effects: ")
469+
printstyled(io, "$p\n"; color = :yellow)
470+
print(io, " * number of variance components: ")
471+
printstyled(io, "$m"; color = :yellow)
472+
end
473+
472474
include("multivariate_calculus.jl")
473475
include("reml.jl")
474476
include("fit.jl")

src/eigen.jl

+45-49
Original file line numberDiff line numberDiff line change
@@ -41,37 +41,30 @@ function fit!(
4141
for k in 1:2
4242
model.Σ[k] .= inv(tr(model.V[k])) .* model.storage_d_d_1
4343
end
44-
copy!(model.storage_d_d_1, model.Σ[1])
45-
copy!(model.storage_d_d_2, model.Σ[2])
46-
Λ, Φ = eigen!(Symmetric(model.storage_d_d_1), Symmetric(model.storage_d_d_2))
47-
copy!(model.Λ, Λ)
48-
copy!(model.Φ, Φ)
49-
copyto!(model.logdetΣ2, logdet(model.Σ[2]))
50-
mul!(model.ỸΦ, model.Ỹ, model.Φ)
51-
mul!(model.BΦ, model.B, model.Φ)
52-
update_res!(model) # update R̃Φ
44+
update_Φ!(model)
45+
update_res!(model)
5346
elseif init == :user
54-
copy!(model.storage_d_d_1, model.Σ[1])
55-
copy!(model.storage_d_d_2, model.Σ[2])
56-
Λ, Φ = eigen!(Symmetric(model.storage_d_d_1), Symmetric(model.storage_d_d_2))
57-
copy!(model.Λ, Λ)
58-
copy!(model.Φ, Φ)
59-
copyto!(model.logdetΣ2, logdet(model.Σ[2]))
60-
mul!(model.ỸΦ, model.Ỹ, model.Φ)
61-
mul!(model.BΦ, model.B, model.Φ)
62-
update_res!(model) # update R̃Φ
47+
update_Φ!(model)
48+
update_res!(model)
6349
else
6450
throw("Cannot recognize initialization method $init")
6551
end
6652
logl = loglikelihood!(model)
6753
toc = time()
54+
verbose && println("iter = 0, logl = $logl")
55+
IterativeSolvers.nextiter!(history)
56+
push!(history, :iter , 0)
57+
push!(history, :logl , logl)
58+
push!(history, :itertime, toc - tic)
6859
# MM loop
6960
for iter in 1:maxiter
7061
IterativeSolvers.nextiter!(history)
7162
tic = time()
7263
# initial estiamte of Σ[i] can be lousy, so we update Σ[i] first in the 1st round
7364
p > 0 && iter > 1 && update_B!(model)
7465
update_Σ!(model)
66+
update_Φ!(model)
67+
update_res!(model)
7568
logl_prev = logl
7669
logl = loglikelihood!(model)
7770
toc = time()
@@ -157,7 +150,7 @@ function update_Σ!(
157150
end
158151
lmul!(Diagonal(one(T) ./ model.storage_d_1), vecs)
159152
mul!(model.storage_d_d_1, transpose(Φinv), vecs)
160-
mul!(model.Σ[1], transpose(model.storage_d_d_1), model.storage_d_d_1)
153+
mul!(model.Σ[1], model.storage_d_d_1, transpose(model.storage_d_d_1))
161154
# update Σ[2]
162155
lmul!(Diagonal(model.storage_d_2), model.N2tN2)
163156
rmul!(model.N2tN2, Diagonal(model.storage_d_2))
@@ -176,8 +169,13 @@ function update_Σ!(
176169
end
177170
lmul!(Diagonal(one(T) ./ model.storage_d_2), vecs)
178171
mul!(model.storage_d_d_1, transpose(Φinv), vecs)
179-
mul!(model.Σ[2], transpose(model.storage_d_d_1), model.storage_d_d_1)
180-
# update parameters
172+
mul!(model.Σ[2], model.storage_d_d_1, transpose(model.storage_d_d_1))
173+
model.Σ
174+
end
175+
176+
function update_Φ!(
177+
model :: MRTVCModel{T};
178+
) where T <: BlasReal
181179
copy!(model.storage_d_d_1, model.Σ[1])
182180
copy!(model.storage_d_d_2, model.Σ[2])
183181
Λ, Φ = eigen!(Symmetric(model.storage_d_d_1), Symmetric(model.storage_d_d_2))
@@ -186,8 +184,30 @@ function update_Σ!(
186184
copyto!(model.logdetΣ2, logdet(model.Σ[2]))
187185
mul!(model.ỸΦ, model.Ỹ, model.Φ)
188186
mul!(model.BΦ, model.B, model.Φ)
189-
update_res!(model) # update R̃Φ
190-
model.Σ
187+
end
188+
189+
function update_res!(
190+
model :: MRTVCModel{T}
191+
) where T <: BlasReal
192+
# update R̃Φ = (Ỹ - X̃B)Φ
193+
BLAS.gemm!('N', 'N', -one(T), model.X̃, model.BΦ, one(T), copyto!(model.R̃Φ, model.ỸΦ))
194+
model.R̃Φ
195+
end
196+
197+
function loglikelihood!(
198+
model :: MRTVCModel{T}
199+
) where T <: BlasReal
200+
n, d = size(model.Ỹ, 1), size(model.Ỹ, 2)
201+
# assemble pieces for log-likelihood
202+
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2[1]
203+
@inbounds for j in 1:d
204+
λj = model.Λ[j]
205+
@simd for i in 1:n
206+
tmp = model.D[i] * λj + one(T)
207+
logl += log(tmp) + inv(tmp) * model.R̃Φ[i, j]^2
208+
end
209+
end
210+
logl /= -2
191211
end
192212

193213
function update_B!(
@@ -216,30 +236,6 @@ function update_B!(
216236
model.B
217237
end
218238

219-
function loglikelihood!(
220-
model :: MRTVCModel{T}
221-
) where T <: BlasReal
222-
n, d = size(model.Ỹ, 1), size(model.Ỹ, 2)
223-
# assemble pieces for log-likelihood
224-
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2[1]
225-
@inbounds for j in 1:d
226-
λj = model.Λ[j]
227-
@simd for i in 1:n
228-
tmp = model.D[i] * λj + one(T)
229-
logl += log(tmp) + inv(tmp) * model.R̃Φ[i, j]^2
230-
end
231-
end
232-
logl /= -2
233-
end
234-
235-
function update_res!(
236-
model :: MRTVCModel{T}
237-
) where T <: BlasReal
238-
# update R̃Φ = (Ỹ - X̃B)Φ
239-
BLAS.gemm!('N', 'N', -one(T), model.X̃, model.BΦ, one(T), copyto!(model.R̃Φ, model.ỸΦ))
240-
model.R̃Φ
241-
end
242-
243239
function fisher_B!(
244240
model :: MRTVCModel{T}
245241
) where T <: BlasReal
@@ -252,7 +248,7 @@ function fisher_B!(
252248
end
253249
lmul!(Diagonal(model.storage_nd_1), model.storage_nd_pd)
254250
mul!(model.storage_pd_pd, transpose(model.storage_nd_pd), model.storage_nd_pd)
255-
copyto!(model.Bcov, pinv(storage_pd_pd))
251+
copyto!(model.Bcov, pinv(model.storage_pd_pd))
256252
end
257253

258254
function fisher_Σ!(
@@ -301,7 +297,7 @@ function fisher_Σ!(
301297
idx1 = Int(d * (d + 1) / 2 * (i - 1) + 1)
302298
idx2 = Int(d * (d + 1) / 2 * i)
303299
idx5, idx6 = d^2 * (i - 1) + 1, d^2 * i
304-
for j in i:m
300+
for j in i:2
305301
idx3 = Int(d * (d + 1) / 2 * (j - 1) + 1)
306302
idx4 = Int(d * (d + 1) / 2 * j)
307303
idx7, idx8 = d^2 * (j - 1) + 1, d^2 * j

0 commit comments

Comments
 (0)