@@ -94,23 +94,24 @@ A named tuple containing:
9494- To enable timing outputs, run `TimerOutputs.enable_debug_timings(KSVD)`.
9595- To set the number of nonzeros, specify e.g. `sparse_coding_method=ParallelMatchingPursuit(; max_nnz=..., rtol=5e-2)`.
9696"""
97- function ksvd (Y:: AbstractMatrix{T} , n_atoms:: Int , max_nnz= max (3 , n_atoms÷ 100 );
98- ksvd_update_method = BatchedParallelKSVD {false, T} (; shuffle_indices= true , batch_size_per_thread= 1 ),
99- sparse_coding_method = ParallelMatchingPursuit (; max_nnz, rtol= 5e-2 ),
100- minibatch_size= nothing ,
101- D_init:: Union{Nothing, <:AbstractMatrix{T}} = nothing ,
102- # termination conditions
103- maxiters:: Int = 100 ,
104- maxtime:: Union{Nothing, <:Real} = nothing ,
105- abstol:: Union{Nothing, <:Real} = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ),
106- reltol:: Union{Nothing, <:Real} = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ),
107- nnz_per_col_target:: Number = 0.0 ,
108- # tracing options
109- show_trace:: Bool = false ,
110- callback_fn:: Union{Nothing, Function} = nothing ,
111- verbose= false ,
112- ) where T
113- timer = TimerOutput ()
97+ function ksvd (Y:: AbstractMatrix{T} , n_atoms:: Int , max_nnz= max (3 , n_atoms ÷ 100 );
98+ ksvd_update_method= BatchedParallelKSVD {false,T} (; shuffle_indices= true , batch_size_per_thread= 1 ),
99+ sparse_coding_method= ParallelMatchingPursuit (; max_nnz, rtol= 5e-2 ),
100+ minibatch_size= nothing ,
101+ D_init:: Union{Nothing,<:AbstractMatrix{T}} = nothing ,
102+ X_init:: Union{Nothing,<:AbstractSparseMatrix} = nothing ,
103+ # termination conditions
104+ maxiters:: Int = 100 ,
105+ maxtime:: Union{Nothing,<:Real} = nothing ,
106+ abstol:: Union{Nothing,<:Real} = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ),
107+ reltol:: Union{Nothing,<:Real} = real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ),
108+ nnz_per_col_target:: Number = 0.0 ,
109+ # tracing options
110+ show_trace:: Bool = false ,
111+ callback_fn:: Union{Nothing,Function} = nothing ,
112+ verbose= false ,
113+ timer:: TimerOutput = TimerOutput ()
114+ ) where {T}
114115 emb_dim, n_samples = size (Y)
115116
116117 # D is a dictionary matrix that contains atoms for columns.
0 commit comments