Skip to content

Commit bcedc64

Browse files
committed
relocate reml functions
1 parent bc74b2b commit bcedc64

File tree

5 files changed

+105
-104
lines changed

5 files changed

+105
-104
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ jobs:
1818
- "1.9"
1919
- "1.10"
2020
- "1.11"
21-
- "nightly"
2221
os:
2322
- ubuntu-latest
2423
- macos-latest

src/MultiResponseVarianceComponentModels.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ maximum likelihood (REML) estimation and inference.
55
"""
66
module MultiResponseVarianceComponentModels
77

8-
using IterativeSolvers, LinearAlgebra, Manopt, Manifolds, Distributions, SweepOperator, InvertedIndices
8+
using IterativeSolvers, LinearAlgebra, Distributions, SweepOperator, InvertedIndices
9+
# using Manopt, Manifolds
910
import LinearAlgebra: BlasReal, copytri!
1011
export VCModel,
1112
MultiResponseVarianceComponentModel,
@@ -530,12 +531,12 @@ function Base.show(io::IO, model::VCModel)
530531
printstyled(io, "$m"; color = :yellow)
531532
end
532533

533-
include("mvcalculus.jl")
534-
include("reml.jl")
535-
include("fit.jl")
536534
include("eigen.jl")
535+
include("fit.jl")
537536
# include("manopt.jl")
538-
include("parse.jl")
539537
include("missing.jl")
538+
include("mvcalculus.jl")
539+
include("parse.jl")
540+
include("reml.jl")
540541

541542
end

src/fit.jl

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -308,42 +308,6 @@ function update_B!(
308308
model.B
309309
end
310310

311-
function update_B_reml!(
312-
model :: MRVCModel{T}
313-
) where T <: BlasReal
314-
Ω⁻¹ = model.storage_nd_nd_reml
315-
G = model.storage_pd_pd_reml
316-
# Gram matrix G = (Id⊗X')Ω⁻¹(Id⊗X) = (X'Ω⁻¹ᵢⱼX)ᵢⱼ, 1 ≤ i,j ≤ d
317-
n, d, p = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2)
318-
for j in 1:d
319-
Ωcidx = ((j - 1) * n + 1):(j * n)
320-
Gcidx = ((j - 1) * p + 1):(j * p)
321-
for i in 1:j
322-
Ωridx = ((i - 1) * n + 1):(i * n)
323-
Gridx = ((i - 1) * p + 1):(i * p)
324-
Ω⁻¹ᵢⱼ = view(Ω⁻¹, Ωridx, Ωcidx)
325-
Gᵢⱼ = view(G , Gridx, Gcidx)
326-
mul!(model.storage_n_p_reml, Ω⁻¹ᵢⱼ, model.X_reml)
327-
mul!(Gᵢⱼ, transpose(model.X_reml), model.storage_n_p_reml)
328-
end
329-
end
330-
copytri!(G, 'U')
331-
# (Id⊗X')Ω⁻¹vec(Y) = vec(X' * reshape(Ω⁻¹vec(Y), n, d))
332-
copyto!(model.storage_nd_1_reml, model.Y_reml)
333-
mul!(model.storage_nd_2_reml, model.storage_nd_nd_reml, model.storage_nd_1_reml)
334-
copyto!(model.storage_n_d_reml, model.storage_nd_2_reml)
335-
mul!(model.storage_p_d_reml, transpose(model.X_reml), model.storage_n_d_reml)
336-
# Cholesky solve
337-
_, info = LAPACK.potrf!('U', G)
338-
info > 0 && throw("Gram matrix (Id⊗X')Ω⁻¹(Id⊗X) is singular")
339-
copyto!(model.storage_pd_reml, model.storage_p_d_reml)
340-
LAPACK.potrs!('U', G, model.storage_pd_reml)
341-
copyto!(model.B_reml, model.storage_pd_reml)
342-
# update residuals R
343-
update_res_reml!(model)
344-
model.B_reml
345-
end
346-
347311
"""
348312
loglikelihood!(model::MRVCModel)
349313
@@ -373,27 +337,6 @@ function loglikelihood!(
373337
logl /= -2
374338
end
375339

376-
function loglikelihood_reml!(
377-
model :: MRVCModel{T}
378-
) where T <: BlasReal
379-
copyto!(model.storage_nd_nd_reml, model.Ω_reml)
380-
# Cholesky of covariance Ω = U'U
381-
_, info = LAPACK.potrf!('U', model.storage_nd_nd_reml)
382-
info > 0 && throw("Covariance matrix Ω is singular")
383-
# storage_nd = U' \ vec(R)
384-
copyto!(model.storage_nd_1_reml, model.R_reml)
385-
BLAS.trsv!('U', 'T', 'N', model.storage_nd_nd_reml, model.storage_nd_1_reml)
386-
# assemble pieces for log-likelihood
387-
logl = sum(abs2, model.storage_nd_1_reml) + length(model.storage_nd_1_reml) * log(2π)
388-
@inbounds for i in 1:length(model.storage_nd_1_reml)
389-
logl += 2log(model.storage_nd_nd_reml[i, i])
390-
end
391-
# Ω⁻¹ from upper Cholesky factor
392-
LAPACK.potri!('U', model.storage_nd_nd_reml)
393-
copytri!(model.storage_nd_nd_reml, 'U')
394-
logl /= -2
395-
end
396-
397340
function update_res!(
398341
model :: MRVCModel{T}
399342
) where T <: BlasReal
@@ -402,14 +345,6 @@ function update_res!(
402345
model.R
403346
end
404347

405-
function update_res_reml!(
406-
model :: MRVCModel{T}
407-
) where T <: BlasReal
408-
# update R = Y - XB
409-
BLAS.gemm!('N', 'N', -one(T), model.X_reml, model.B_reml, one(T), copyto!(model.R_reml, model.Y_reml))
410-
model.R
411-
end
412-
413348
function update_Ω!(
414349
model :: MRVCModel{T}
415350
) where T <: BlasReal
@@ -420,16 +355,6 @@ function update_Ω!(
420355
model.Ω
421356
end
422357

423-
function update_Ω_reml!(
424-
model :: MRVCModel{T}
425-
) where T <: BlasReal
426-
fill!(model.Ω_reml, zero(T))
427-
@inbounds for k in 1:length(model.V_reml)
428-
kron_axpy!(model.Σ[k], model.V_reml[k], model.Ω_reml)
429-
end
430-
model.Ω_reml
431-
end
432-
433358
"""
434359
fisher_B!(model::MRVCModel)
435360
@@ -459,29 +384,6 @@ function fisher_B!(
459384
copyto!(model.Bcov, pinv(G))
460385
end
461386

462-
function fisher_B_reml!(
463-
model :: MRVCModel{T}
464-
) where T <: BlasReal
465-
Ω⁻¹ = model.storage_nd_nd_reml
466-
G = model.storage_pd_pd_reml
467-
# Gram matrix G = (Id⊗X')Ω⁻¹(Id⊗X) = (X'Ω⁻¹ᵢⱼX)ᵢⱼ, 1 ≤ i,j ≤ d
468-
n, d, p = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2)
469-
for j in 1:d
470-
Ωcidx = ((j - 1) * n + 1):(j * n)
471-
Gcidx = ((j - 1) * p + 1):(j * p)
472-
for i in 1:j
473-
Ωridx = ((i - 1) * n + 1):(i * n)
474-
Gridx = ((i - 1) * p + 1):(i * p)
475-
Ω⁻¹ᵢⱼ = view(Ω⁻¹, Ωridx, Ωcidx)
476-
Gᵢⱼ = view(G , Gridx, Gcidx)
477-
mul!(model.storage_n_p_reml, Ω⁻¹ᵢⱼ, model.X_reml)
478-
mul!(Gᵢⱼ, transpose(model.X_reml), model.storage_n_p_reml)
479-
end
480-
end
481-
copytri!(G, 'U')
482-
copyto!(model.Bcov_reml, pinv(G))
483-
end
484-
485387
"""
486388
fisher_Σ!(model::MRVCModel)
487389

src/reml.jl

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,101 @@
1+
function update_B_reml!(
2+
model :: MRVCModel{T}
3+
) where T <: BlasReal
4+
Ω⁻¹ = model.storage_nd_nd_reml
5+
G = model.storage_pd_pd_reml
6+
# Gram matrix G = (Id⊗X')Ω⁻¹(Id⊗X) = (X'Ω⁻¹ᵢⱼX)ᵢⱼ, 1 ≤ i,j ≤ d
7+
n, d, p = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2)
8+
for j in 1:d
9+
Ωcidx = ((j - 1) * n + 1):(j * n)
10+
Gcidx = ((j - 1) * p + 1):(j * p)
11+
for i in 1:j
12+
Ωridx = ((i - 1) * n + 1):(i * n)
13+
Gridx = ((i - 1) * p + 1):(i * p)
14+
Ω⁻¹ᵢⱼ = view(Ω⁻¹, Ωridx, Ωcidx)
15+
Gᵢⱼ = view(G , Gridx, Gcidx)
16+
mul!(model.storage_n_p_reml, Ω⁻¹ᵢⱼ, model.X_reml)
17+
mul!(Gᵢⱼ, transpose(model.X_reml), model.storage_n_p_reml)
18+
end
19+
end
20+
copytri!(G, 'U')
21+
# (Id⊗X')Ω⁻¹vec(Y) = vec(X' * reshape(Ω⁻¹vec(Y), n, d))
22+
copyto!(model.storage_nd_1_reml, model.Y_reml)
23+
mul!(model.storage_nd_2_reml, model.storage_nd_nd_reml, model.storage_nd_1_reml)
24+
copyto!(model.storage_n_d_reml, model.storage_nd_2_reml)
25+
mul!(model.storage_p_d_reml, transpose(model.X_reml), model.storage_n_d_reml)
26+
# Cholesky solve
27+
_, info = LAPACK.potrf!('U', G)
28+
info > 0 && throw("Gram matrix (Id⊗X')Ω⁻¹(Id⊗X) is singular")
29+
copyto!(model.storage_pd_reml, model.storage_p_d_reml)
30+
LAPACK.potrs!('U', G, model.storage_pd_reml)
31+
copyto!(model.B_reml, model.storage_pd_reml)
32+
# update residuals R
33+
update_res_reml!(model)
34+
model.B_reml
35+
end
36+
37+
function loglikelihood_reml!(
38+
model :: MRVCModel{T}
39+
) where T <: BlasReal
40+
copyto!(model.storage_nd_nd_reml, model.Ω_reml)
41+
# Cholesky of covariance Ω = U'U
42+
_, info = LAPACK.potrf!('U', model.storage_nd_nd_reml)
43+
info > 0 && throw("Covariance matrix Ω is singular")
44+
# storage_nd = U' \ vec(R)
45+
copyto!(model.storage_nd_1_reml, model.R_reml)
46+
BLAS.trsv!('U', 'T', 'N', model.storage_nd_nd_reml, model.storage_nd_1_reml)
47+
# assemble pieces for log-likelihood
48+
logl = sum(abs2, model.storage_nd_1_reml) + length(model.storage_nd_1_reml) * log(2π)
49+
@inbounds for i in 1:length(model.storage_nd_1_reml)
50+
logl += 2log(model.storage_nd_nd_reml[i, i])
51+
end
52+
# Ω⁻¹ from upper Cholesky factor
53+
LAPACK.potri!('U', model.storage_nd_nd_reml)
54+
copytri!(model.storage_nd_nd_reml, 'U')
55+
logl /= -2
56+
end
57+
58+
function update_res_reml!(
59+
model :: MRVCModel{T}
60+
) where T <: BlasReal
61+
# update R = Y - XB
62+
BLAS.gemm!('N', 'N', -one(T), model.X_reml, model.B_reml, one(T), copyto!(model.R_reml, model.Y_reml))
63+
model.R
64+
end
65+
66+
function update_Ω_reml!(
67+
model :: MRVCModel{T}
68+
) where T <: BlasReal
69+
fill!(model.Ω_reml, zero(T))
70+
@inbounds for k in 1:length(model.V_reml)
71+
kron_axpy!(model.Σ[k], model.V_reml[k], model.Ω_reml)
72+
end
73+
model.Ω_reml
74+
end
75+
76+
function fisher_B_reml!(
77+
model :: MRVCModel{T}
78+
) where T <: BlasReal
79+
Ω⁻¹ = model.storage_nd_nd_reml
80+
G = model.storage_pd_pd_reml
81+
# Gram matrix G = (Id⊗X')Ω⁻¹(Id⊗X) = (X'Ω⁻¹ᵢⱼX)ᵢⱼ, 1 ≤ i,j ≤ d
82+
n, d, p = size(model.Y_reml, 1), size(model.Y_reml, 2), size(model.X_reml, 2)
83+
for j in 1:d
84+
Ωcidx = ((j - 1) * n + 1):(j * n)
85+
Gcidx = ((j - 1) * p + 1):(j * p)
86+
for i in 1:j
87+
Ωridx = ((i - 1) * n + 1):(i * n)
88+
Gridx = ((i - 1) * p + 1):(i * p)
89+
Ω⁻¹ᵢⱼ = view(Ω⁻¹, Ωridx, Ωcidx)
90+
Gᵢⱼ = view(G , Gridx, Gcidx)
91+
mul!(model.storage_n_p_reml, Ω⁻¹ᵢⱼ, model.X_reml)
92+
mul!(Gᵢⱼ, transpose(model.X_reml), model.storage_n_p_reml)
93+
end
94+
end
95+
copytri!(G, 'U')
96+
copyto!(model.Bcov_reml, pinv(G))
97+
end
98+
199
function project_null(
2100
Y :: AbstractVecOrMat{T},
3101
X :: AbstractVecOrMat{T},

test/missing_test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ end
5050
@testset "fit! missing response with MM" begin
5151
model = MRVCModel(Y_miss, X, V; se = false)
5252
# @timev MultiResponseVarianceComponentModels.fit!(model)
53+
# @test model.logl[1] ≈ -4435.064121104977
5354
end
5455

5556
end

0 commit comments

Comments
 (0)