Skip to content

Commit 0ce5b4e

Browse files
committed
reml for eigen
1 parent edf6493 commit 0ce5b4e

File tree

2 files changed

+104
-60
lines changed

2 files changed

+104
-60
lines changed

src/MultiResponseVarianceComponentModels.jl

+30-26
Original file line numberDiff line numberDiff line change
@@ -325,20 +325,23 @@ struct MRTVCModel{T <: BlasReal} <: VCModel
325325
Σcov :: Union{Nothing, Matrix{T}} # for fisher_Σ!
326326
# original data for reml
327327
Y_reml :: Union{Nothing, Matrix{T}}
328+
Ỹ_reml :: Union{Nothing, Matrix{T}}
328329
X_reml :: Union{Nothing, Matrix{T}}
330+
X̃_reml :: Union{Nothing, Matrix{T}}
329331
V_reml :: Union{Nothing, Vector{Matrix{T}}}
332+
U_reml :: Union{Nothing, Matrix{T}}
333+
D_reml :: Union{Nothing, Vector{T}}
334+
logdetV2_reml :: Union{Nothing, T}
330335
# fixed effects parameters for reml
331336
B_reml :: Union{Nothing, Matrix{T}}
332337
# working arrays for reml
333-
Ω_reml :: Union{Nothing, Matrix{T}}
334-
R_reml :: Union{Nothing, Matrix{T}}
335-
storage_nd_nd_reml :: Union{Nothing, Matrix{T}}
336-
storage_pd_pd_reml :: Union{Nothing, Matrix{T}}
337-
storage_n_p_reml :: Union{Nothing, Matrix{T}}
338+
ỸΦ_reml :: Union{Nothing, Matrix{T}}
339+
R̃_reml :: Union{Nothing, Matrix{T}}
340+
R̃Φ_reml :: Union{Nothing, Matrix{T}}
341+
storage_nd_pd_reml :: Union{Nothing, Matrix{T}}
338342
storage_nd_1_reml :: Union{Nothing, Vector{T}}
339343
storage_nd_2_reml :: Union{Nothing, Vector{T}}
340-
storage_n_d_reml :: Union{Nothing, Matrix{T}}
341-
storage_p_d_reml :: Union{Nothing, Matrix{T}}
344+
storage_pd_pd_reml :: Union{Nothing, Matrix{T}}
342345
storage_pd_reml :: Union{Nothing, Vector{T}}
343346
logl_reml :: Union{Nothing, Vector{T}}
344347
Bcov_reml :: Union{Nothing, Matrix{T}}
@@ -376,25 +379,26 @@ function MRTVCModel(
376379
n, p = size(Y, 1), 0
377380
nd, pd = n * d, p * d
378381
nd_reml, pd_reml = n_reml * d, p_reml * d
382+
D_reml, U_reml = eigen(Symmetric(V_reml[1]), Symmetric(V_reml[2]))
383+
logdetV2_reml = logdet(V_reml[2])
384+
Ỹ_reml = transpose(U_reml) * Y_reml
385+
X̃_reml = transpose(U_reml) * X_reml
379386
B_reml = Matrix{T}(undef, p_reml, d)
380-
Ω_reml = Matrix{T}(undef, nd_reml, nd_reml)
381-
R_reml = Matrix{T}(undef, n_reml, d)
382-
storage_nd_nd_reml = Matrix{T}(undef, nd_reml, nd_reml)
383-
storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml)
384-
storage_n_p_reml = Matrix{T}(undef, n_reml, p_reml)
387+
ỸΦ_reml = Matrix{T}(undef, n_reml, d)
388+
R̃_reml = Matrix{T}(undef, n_reml, d)
389+
R̃Φ_reml = Matrix{T}(undef, n_reml, d)
390+
storage_nd_pd_reml = Matrix{T}(undef, nd_reml, pd_reml)
385391
storage_nd_1_reml = Vector{T}(undef, nd_reml)
386392
storage_nd_2_reml = Vector{T}(undef, nd_reml)
387-
storage_n_d_reml = Matrix{T}(undef, n_reml, d)
388-
storage_p_d_reml = Matrix{T}(undef, p_reml, d)
393+
storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml)
389394
storage_pd_reml = Vector{T}(undef, pd_reml)
390-
logl_reml = zeros(T, 1)
395+
logl_reml = zeros(T, 1)
391396
else
392-
Y_reml = X_reml = V_reml = B_reml = Ω_reml = R_reml =
393-
storage_nd_nd_reml = storage_pd_pd_reml =
394-
storage_n_p_reml = storage_nd_1_reml =
395-
storage_nd_2_reml = storage_n_d_reml =
396-
storage_p_d_reml = storage_pd_reml =
397-
logl_reml = Bcov_reml = nothing
397+
Y_reml = Ỹ_reml = X_reml = X̃_reml = V_reml = U_reml = D_reml =
398+
logdetV2_reml = B_reml = ỸΦ_reml = R̃_reml = R̃Φ_reml =
399+
storage_nd_pd_reml = storage_nd_1_reml =
400+
storage_nd_2_reml = storage_pd_pd_reml = storage_pd_reml =
401+
logl_reml = Bcov_reml = nothing
398402
end
399403
if se
400404
Bcov = Matrix{T}(undef, pd, pd)
@@ -439,11 +443,11 @@ function MRTVCModel(
439443
storage_d_1, storage_d_2, storage_d_d_1, storage_d_d_2,
440444
storage_p_p, storage_pd, storage_pd_pd,
441445
storage_nd_1, storage_nd_2, storage_nd_pd, logl, Bcov, Σcov,
442-
Y_reml, X_reml, V_reml, B_reml, Ω_reml, R_reml,
443-
storage_nd_nd_reml, storage_pd_pd_reml, storage_n_p_reml,
444-
storage_nd_1_reml, storage_nd_2_reml, storage_n_d_reml,
445-
storage_p_d_reml, storage_pd_reml, logl_reml, Bcov_reml,
446-
se, reml
446+
Y_reml, Ỹ_reml, X_reml, X̃_reml, V_reml, U_reml, D_reml,
447+
logdetV2_reml, B_reml, ỸΦ_reml, R̃_reml, R̃Φ_reml,
448+
storage_nd_pd_reml, storage_nd_1_reml,
449+
storage_nd_2_reml, storage_pd_pd_reml, storage_pd_reml,
450+
logl_reml, Bcov_reml, se, reml
447451
)
448452
end
449453

src/eigen.jl

+74-34
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ function fit!(
9595
break
9696
end
9797
end
98-
# if model.reml
99-
# update_B_reml!(model)
100-
# update_res_reml!(model)
101-
# mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ)
102-
# copyto!(model.logl_reml, loglikelihood_reml!(model))
103-
# model.se ? fisher_B_reml!(model) : nothing
104-
# end
98+
if model.reml
99+
update_B_reml!(model)
100+
update_res_reml!(model)
101+
mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ)
102+
copyto!(model.logl_reml, loglikelihood_reml!(model))
103+
model.se ? fisher_B_reml!(model) : nothing
104+
end
105105
log && IterativeSolvers.shrink!(history)
106106
history
107107
end
@@ -183,7 +183,6 @@ function update_Φ!(
183183
copy!(model.Λ, Λ)
184184
copy!(model.Φ, Φ)
185185
copyto!(model.logdetΣ2, logdet(model.Σ[2]))
186-
mul!(model.ỸΦ, model.Ỹ, model.Φ)
187186
end
188187

189188
function update_res!(
@@ -194,12 +193,20 @@ function update_res!(
194193
model.
195194
end
196195

196+
function update_res_reml!(
197+
model :: MRTVCModel{T}
198+
) where T <: BlasReal
199+
# update R̃ = Ỹ - X̃B
200+
BLAS.gemm!('N', 'N', -one(T), model.X̃_reml, model.B_reml, one(T), copyto!(model.R̃_reml, model.Ỹ_reml))
201+
model.
202+
end
203+
197204
function loglikelihood!(
198205
model :: MRTVCModel{T}
199206
) where T <: BlasReal
200207
n, d = size(model.Ỹ, 1), size(model.Ỹ, 2)
201208
# assemble pieces for log-likelihood
202-
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2[1]
209+
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2
203210
@inbounds for j in 1:d
204211
λj = model.Λ[j]
205212
@simd for i in 1:n
@@ -210,9 +217,26 @@ function loglikelihood!(
210217
logl /= -2
211218
end
212219

220+
function loglikelihood_reml!(
221+
model :: MRTVCModel{T}
222+
) where T <: BlasReal
223+
n, d = size(model.Ỹ_reml, 1), size(model.Ỹ_reml, 2)
224+
# assemble pieces for log-likelihood
225+
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2_reml
226+
@inbounds for j in 1:d
227+
λj = model.Λ[j]
228+
@simd for i in 1:n
229+
tmp = model.D_reml[i] * λj + one(T)
230+
logl += log(tmp) + inv(tmp) * model.R̃Φ_reml[i, j]^2
231+
end
232+
end
233+
logl /= -2
234+
end
235+
213236
function update_B!(
214237
model :: MRTVCModel{T}
215238
) where T <: BlasReal
239+
mul!(model.ỸΦ, model.Ỹ, model.Φ)
216240
# Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
217241
G = model.storage_pd_pd
218242
fill!(model.storage_nd_pd, zero(T))
@@ -236,31 +260,32 @@ function update_B!(
236260
model.B
237261
end
238262

239-
# function update_B_reml!(
240-
# model :: MRTVCModel{T}
241-
# ) where T <: BlasReal
242-
# # Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
243-
# G = model.storage_pd_pd
244-
# fill!(model.storage_nd_pd_reml, zero(T))
245-
# kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
246-
# fill!(model.storage_nd_1_reml, zero(T))
247-
# kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
248-
# @inbounds @simd for i in eachindex(model.storage_nd_1_reml)
249-
# model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
250-
# end
251-
# lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
252-
# mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
253-
# # (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ)
254-
# copyto!(model.storage_nd_2_reml, model.ỸΦ_reml)
255-
# model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml
256-
# mul!(model.storage_pd, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml)
257-
# # Cholesky solve
258-
# _, info = LAPACK.potrf!('U', G)
259-
# info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular")
260-
# LAPACK.potrs!('U', G, model.storage_pd)
261-
# copyto!(model.B, model.storage_pd)
262-
# model.B
263-
# end
263+
function update_B_reml!(
264+
model :: MRTVCModel{T}
265+
) where T <: BlasReal
266+
mul!(model.ỸΦ_reml, model.Ỹ_reml, model.Φ)
267+
# Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
268+
G = model.storage_pd_pd_reml
269+
fill!(model.storage_nd_pd_reml, zero(T))
270+
kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
271+
fill!(model.storage_nd_1_reml, zero(T))
272+
kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
273+
@inbounds @simd for i in eachindex(model.storage_nd_1_reml)
274+
model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
275+
end
276+
lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
277+
mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
278+
# (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ)
279+
copyto!(model.storage_nd_2_reml, model.ỸΦ_reml)
280+
model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml
281+
mul!(model.storage_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml)
282+
# Cholesky solve
283+
_, info = LAPACK.potrf!('U', G)
284+
info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular")
285+
LAPACK.potrs!('U', G, model.storage_pd_reml)
286+
copyto!(model.B_reml, model.storage_pd_reml)
287+
model.B_reml
288+
end
264289

265290
function fisher_B!(
266291
model :: MRTVCModel{T}
@@ -277,6 +302,21 @@ function fisher_B!(
277302
copyto!(model.Bcov, pinv(model.storage_pd_pd))
278303
end
279304

305+
function fisher_B_reml!(
306+
model :: MRTVCModel{T}
307+
) where T <: BlasReal
308+
fill!(model.storage_nd_pd_reml, zero(T))
309+
kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
310+
fill!(model.storage_nd_1_reml, zero(T))
311+
kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
312+
@inbounds @simd for i in eachindex(model.storage_nd_1_reml)
313+
model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
314+
end
315+
lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
316+
mul!(model.storage_pd_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
317+
copyto!(model.Bcov_reml, pinv(model.storage_pd_pd_reml))
318+
end
319+
280320
function fisher_Σ!(
281321
model :: MRTVCModel{T}
282322
) where T <: BlasReal

0 commit comments

Comments
 (0)