From 9e607e6baf3a3e4cd9726c900619751899fd1109 Mon Sep 17 00:00:00 2001 From: "behinger (s-ccs 001)" Date: Mon, 18 Dec 2023 13:40:37 +0000 Subject: [PATCH] fix sphering LL --- src/helper.jl | 1 + src/main.jl | 12 +++++- src/types.jl | 5 ++- test/compare_amica_implementations.jl | 54 +++++++++++++++++---------- 4 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/helper.jl b/src/helper.jl index d5e6610..b9db918 100755 --- a/src/helper.jl +++ b/src/helper.jl @@ -14,6 +14,7 @@ function sphering!(x) Us,Ss = svd(x*x'/N) S = Us * diagm(vec(1 ./sqrt.(Ss))) * Us' x .= S*x + return S end function bene_sphering(data) diff --git a/src/main.jl b/src/main.jl index b767bcc..2adf9aa 100755 --- a/src/main.jl +++ b/src/main.jl @@ -40,7 +40,12 @@ function amica!(myAmica::AbstractAmica, removed_mean = removeMean!(data) end if do_sphering - sphering!(data) + S = sphering!(data) + myAmica.S = S + LLdetS = logabsdet(S)[1] + else + myAmica.S = I + LLdetS = 0 end dLL = zeros(1, maxiter) @@ -56,6 +61,7 @@ function amica!(myAmica::AbstractAmica, update_sources!(myAmica, data) calculate_ldet!(myAmica) initialize_Lt!(myAmica) + myAmica.Lt .+= LLdetS calculate_y!(myAmica) # pre-calculate abs(y)^rho @@ -69,6 +75,8 @@ function amica!(myAmica::AbstractAmica, loopiloop!(myAmica, y_rho) #Updates y and Lt. Todo: Rename calculate_LL!(myAmica) + + @debug (:LL,myAmica.LL) #Calculate difference in loglikelihood between iterations if iter > 1 @@ -106,6 +114,8 @@ function amica!(myAmica::AbstractAmica, end #If parameters contain NaNs, the algorithm skips the A update and terminates by jumping here @label escape_from_NaN + + #If means were removed, they are added back if remove_mean add_means_back!(myAmica, removed_mean) diff --git a/src/types.jl b/src/types.jl index 8edcc55..b616aba 100755 --- a/src/types.jl +++ b/src/types.jl @@ -12,7 +12,8 @@ mutable struct SingleModelAmica{T} <:AbstractAmica source_signals::Array{T,2} learnedParameters::GGParameters{T} m::Union{Integer, Nothing} #Number of gaussians - A::Array{T,2} #unmixing matrices for each model + A::Array{T,2} # unmixing matrices for each model + S::Array{T,2} # sphering matrix z::Array{T,3} y::Array{T,3} centers::Array{T} #model centers @@ -87,7 +88,7 @@ function SingleModelAmica(data::AbstractArray{T}; m=3, maxiter=500, A=nothing, l ldet = 0.0 source_signals = zeros(n,N) - return SingleModelAmica{T}(source_signals,GGParameters{T}(proportions,scale,location,shape),m,A,z,y,#=Q,=#centers,Lt,LL,ldet,maxiter) + return SingleModelAmica{T}(source_signals,GGParameters{T}(proportions,scale,location,shape),m,A,I(size(A,1)), z,y,#=Q,=#centers,Lt,LL,ldet,maxiter) end #Data type for AMICA with multiple ICA models diff --git a/test/compare_amica_implementations.jl b/test/compare_amica_implementations.jl index 4b38a78..5f4c468 100644 --- a/test/compare_amica_implementations.jl +++ b/test/compare_amica_implementations.jl @@ -3,15 +3,17 @@ using Amica using SignalAnalysis using LinearAlgebra using Revise +#--- using CairoMakie #--- -includet("src/simulate_data.jl") -includet("src/fortran_tools.jl") +includet("../../src/simulate_data.jl") +includet("../../src/fortran_tools.jl") n_chan=3 -n_time=5000 +n_time=10000 x,A,s = simulate_data(;T=50,n_chan,n_time,type=:gg6) i_scale,i_location,i_A = init_params(;n_chan,n_time) +#--- f = Figure() series(f[1,1],s,axis=(;title="source")) #xlims!(1,20) @@ -20,36 +22,49 @@ f #--- -maxiter = 500 +maxiter = 200 # matlab run mat""" +tic [mA,mc,mLL,mLt,mgm,malpha,mmu,mbeta,mrho] = amica_a($x,1,3,$maxiter,$i_location,$i_scale,$i_A,0); +toc """ mat""" +tic [mAopt,mWopt,mSopt,mkhindsopt,mcopt,mLLopt,mLtopt,mgmopt,malphaopt,mmuopt,mbetaopt,mrhoopt] = amica_optimized($x,1,3,$maxiter,1,1,$i_location,$i_scale,$i_A); +toc """ mA = @mget(mA); # get matlab var mAopt = @mget(mAopt); # get matlab opt var # Fortran setup + ran -fortran_setup(x;max_threads=1,max_iter=maxiter) -run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`) +fortran_setup(Float32.(x);max_threads=1,max_iter=maxiter) +@time run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`) fA = reshape(reinterpret(Float64,(read("amicaout/A"))),n_chan,n_chan) -# Julia run -am = SingleModelAmica(x;maxiter=maxiter,A=i_A,location=i_location,scale=i_scale) -fit!(am,x) +fortran_setup(Float64.(x);max_threads=1,max_iter=maxiter,dble_data=1) +@time run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`) + +#--- Julia run +am32 = SingleModelAmica(Float32.(x);maxiter=maxiter,A=deepcopy(i_A),location=i_location,scale=i_scale) +@time fit!(am32,Float32.(x)) + +am64 = SingleModelAmica(Float64.(x);maxiter=maxiter,A=deepcopy(i_A),location=i_location,scale=i_scale) +@time fit!(am64,Float64.(x)) +#am = SingleModelAmica(Float16.(x);maxiter=maxiter,A=i_A,location=i_location,scale=i_scale) +#@time fit!(am,x) #vcat(@mget(mLL),am.LL') #--- f2 = Figure(size=(800,800)) -series(f2[1,1],inv(am.A)*x, axis=(;title="unmixed julia")) -series(f2[2,1],inv(mA)*x, axis=(;title="unmixed matlab")) -series(f2[3,1],inv(mAopt)*x, axis=(;title="unmixed matlab_optimizd")) -series(f2[4,1],.-inv(fA')*x, axis=(;title="unmixed fortran")) -series(f2[5,1],s,axis=(;title="original source")) -series(f2[6,1],x, axis=(;title="original mixed")) +series(f2[1,1],inv(am64.A)*x, axis=(;title="unmixed julia64")) +series(f2[2,1],inv(am32.A)*x, axis=(;title="unmixed julia32")) +series(f2[3,1],inv(mA)*x, axis=(;title="unmixed matlab")) +series(f2[4,1],inv(mAopt)*x, axis=(;title="unmixed matlab_optimizd")) +series(f2[5,1],.-inv(fA')*x, axis=(;title="unmixed fortran")) +series(f2[6,1],s,axis=(;title="original source")) +series(f2[7,1],x, axis=(;title="original mixed")) hidedecorations!.(f2.content) linkxaxes!(f2.content...) @@ -58,18 +73,17 @@ f2 #--- compare LLs -LL = am.LL mLL = @mget mLL mLLopt = @mget mLLopt fLL = reinterpret(Float64,(read("amicaout/LL"))) -f = Figure() +f = Figure(size=(1024,1024)) ax = f[1,1] = Axis(f) -labels = ["julia","matlab", "matlab opt","fortran"] -for (ix,d) = enumerate([LL,mLL[1,:],mLLopt[1,:],fLL]) +labels = ["julia64","julia32","matlab", "matlab opt","fortran"] +for (ix,d) = enumerate([am64.LL, am32.LL,mLL[1,:],mLLopt[1,:],fLL]) lines!(ax,d;label=labels[ix]) end axislegend(ax) -ylims!(ax,-1.3,-1) +ylims!(ax,-2.3,-0.5) f #--- compare A