Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ function sphering!(x)
Us,Ss = svd(x*x'/N)
S = Us * diagm(vec(1 ./sqrt.(Ss))) * Us'
x .= S*x
return S
end

function bene_sphering(data)
Expand Down
12 changes: 11 additions & 1 deletion src/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ function amica!(myAmica::AbstractAmica,
removed_mean = removeMean!(data)
end
if do_sphering
sphering!(data)
S = sphering!(data)
myAmica.S = S
LLdetS = logabsdet(S)[1]
else
myAmica.S = I
LLdetS = 0
end

dLL = zeros(1, maxiter)
Expand All @@ -56,6 +61,7 @@ function amica!(myAmica::AbstractAmica,
update_sources!(myAmica, data)
calculate_ldet!(myAmica)
initialize_Lt!(myAmica)
myAmica.Lt .+= LLdetS
calculate_y!(myAmica)

# pre-calculate abs(y)^rho
Expand All @@ -69,6 +75,8 @@ function amica!(myAmica::AbstractAmica,

loopiloop!(myAmica, y_rho) #Updates y and Lt. Todo: Rename
calculate_LL!(myAmica)


@debug (:LL,myAmica.LL)
#Calculate difference in loglikelihood between iterations
if iter > 1
Expand Down Expand Up @@ -106,6 +114,8 @@ function amica!(myAmica::AbstractAmica,
end
#If parameters contain NaNs, the algorithm skips the A update and terminates by jumping here
@label escape_from_NaN


#If means were removed, they are added back
if remove_mean
add_means_back!(myAmica, removed_mean)
Expand Down
5 changes: 3 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ mutable struct SingleModelAmica{T} <:AbstractAmica
source_signals::Array{T,2}
learnedParameters::GGParameters{T}
m::Union{Integer, Nothing} #Number of gaussians
A::Array{T,2} #unmixing matrices for each model
A::Array{T,2} # unmixing matrices for each model
S::Array{T,2} # sphering matrix
z::Array{T,3}
y::Array{T,3}
centers::Array{T} #model centers
Expand Down Expand Up @@ -87,7 +88,7 @@ function SingleModelAmica(data::AbstractArray{T}; m=3, maxiter=500, A=nothing, l
ldet = 0.0
source_signals = zeros(n,N)

return SingleModelAmica{T}(source_signals,GGParameters{T}(proportions,scale,location,shape),m,A,z,y,#=Q,=#centers,Lt,LL,ldet,maxiter)
return SingleModelAmica{T}(source_signals,GGParameters{T}(proportions,scale,location,shape),m,A,I(size(A,1)), z,y,#=Q,=#centers,Lt,LL,ldet,maxiter)
end

#Data type for AMICA with multiple ICA models
Expand Down
54 changes: 34 additions & 20 deletions test/compare_amica_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@ using Amica
using SignalAnalysis
using LinearAlgebra
using Revise
#---
using CairoMakie
#---
includet("src/simulate_data.jl")
includet("src/fortran_tools.jl")
includet("../../src/simulate_data.jl")
includet("../../src/fortran_tools.jl")
n_chan=3
n_time=5000
n_time=10000
x,A,s = simulate_data(;T=50,n_chan,n_time,type=:gg6)
i_scale,i_location,i_A = init_params(;n_chan,n_time)

#---
f = Figure()
series(f[1,1],s,axis=(;title="source"))
#xlims!(1,20)
Expand All @@ -20,36 +22,49 @@ f

#---

maxiter = 500
maxiter = 200

# matlab run
mat"""
tic
[mA,mc,mLL,mLt,mgm,malpha,mmu,mbeta,mrho] = amica_a($x,1,3,$maxiter,$i_location,$i_scale,$i_A,0);
toc
"""
mat"""
tic
[mAopt,mWopt,mSopt,mkhindsopt,mcopt,mLLopt,mLtopt,mgmopt,malphaopt,mmuopt,mbetaopt,mrhoopt] = amica_optimized($x,1,3,$maxiter,1,1,$i_location,$i_scale,$i_A);
toc
"""
mA = @mget(mA); # get matlab var
mAopt = @mget(mAopt); # get matlab opt var

# Fortran setup + ran
fortran_setup(x;max_threads=1,max_iter=maxiter)
run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`)
fortran_setup(Float32.(x);max_threads=1,max_iter=maxiter)
@time run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`)
fA = reshape(reinterpret(Float64,(read("amicaout/A"))),n_chan,n_chan)

# Julia run
am = SingleModelAmica(x;maxiter=maxiter,A=i_A,location=i_location,scale=i_scale)
fit!(am,x)
fortran_setup(Float64.(x);max_threads=1,max_iter=maxiter,dble_data=1)
@time run(`/scratch/projects/fapra_amica/fortran/amica15test julia.param`)

#--- Julia run
am32 = SingleModelAmica(Float32.(x);maxiter=maxiter,A=deepcopy(i_A),location=i_location,scale=i_scale)
@time fit!(am32,Float32.(x))

am64 = SingleModelAmica(Float64.(x);maxiter=maxiter,A=deepcopy(i_A),location=i_location,scale=i_scale)
@time fit!(am64,Float64.(x))
#am = SingleModelAmica(Float16.(x);maxiter=maxiter,A=i_A,location=i_location,scale=i_scale)
#@time fit!(am,x)
#vcat(@mget(mLL),am.LL')

#---
f2 = Figure(size=(800,800))
series(f2[1,1],inv(am.A)*x, axis=(;title="unmixed julia"))
series(f2[2,1],inv(mA)*x, axis=(;title="unmixed matlab"))
series(f2[3,1],inv(mAopt)*x, axis=(;title="unmixed matlab_optimizd"))
series(f2[4,1],.-inv(fA')*x, axis=(;title="unmixed fortran"))
series(f2[5,1],s,axis=(;title="original source"))
series(f2[6,1],x, axis=(;title="original mixed"))
series(f2[1,1],inv(am64.A)*x, axis=(;title="unmixed julia64"))
series(f2[2,1],inv(am32.A)*x, axis=(;title="unmixed julia32"))
series(f2[3,1],inv(mA)*x, axis=(;title="unmixed matlab"))
series(f2[4,1],inv(mAopt)*x, axis=(;title="unmixed matlab_optimizd"))
series(f2[5,1],.-inv(fA')*x, axis=(;title="unmixed fortran"))
series(f2[6,1],s,axis=(;title="original source"))
series(f2[7,1],x, axis=(;title="original mixed"))
hidedecorations!.(f2.content)

linkxaxes!(f2.content...)
Expand All @@ -58,18 +73,17 @@ f2


#--- compare LLs
LL = am.LL
mLL = @mget mLL
mLLopt = @mget mLLopt
fLL = reinterpret(Float64,(read("amicaout/LL")))
f = Figure()
f = Figure(size=(1024,1024))
ax = f[1,1] = Axis(f)
labels = ["julia","matlab", "matlab opt","fortran"]
for (ix,d) = enumerate([LL,mLL[1,:],mLLopt[1,:],fLL])
labels = ["julia64","julia32","matlab", "matlab opt","fortran"]
for (ix,d) = enumerate([am64.LL, am32.LL,mLL[1,:],mLLopt[1,:],fLL])
lines!(ax,d;label=labels[ix])
end
axislegend(ax)
ylims!(ax,-1.3,-1)
ylims!(ax,-2.3,-0.5)
f

#--- compare A