Skip to content

Commit 128a204

Browse files
committed
Allow passing X_init
1 parent 49f2ec9 commit 128a204

File tree

1 file changed

+18
-17
lines changed

1 file changed

+18
-17
lines changed

src/KSVD.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,24 @@ A named tuple containing:
9494
- To enable timing outputs, run `TimerOutputs.enable_debug_timings(KSVD)`.
9595
- To set the number of nonzeros, specify e.g. `sparse_coding_method=ParallelMatchingPursuit(; max_nnz=..., rtol=5e-2)`.
9696
"""
97-
function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=max(3, n_atoms÷100);
98-
ksvd_update_method = BatchedParallelKSVD{false, T}(; shuffle_indices=true, batch_size_per_thread=1),
99-
sparse_coding_method = ParallelMatchingPursuit(; max_nnz, rtol=5e-2),
100-
minibatch_size=nothing,
101-
D_init::Union{Nothing, <:AbstractMatrix{T}} = nothing,
102-
# termination conditions
103-
maxiters::Int=100,
104-
maxtime::Union{Nothing, <:Real}=nothing,
105-
abstol::Union{Nothing, <:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5),
106-
reltol::Union{Nothing, <:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5),
107-
nnz_per_col_target::Number=0.0,
108-
# tracing options
109-
show_trace::Bool=false,
110-
callback_fn::Union{Nothing, Function}=nothing,
111-
verbose=false,
112-
) where T
113-
timer = TimerOutput()
97+
function ksvd(Y::AbstractMatrix{T}, n_atoms::Int, max_nnz=max(3, n_atoms ÷ 100);
98+
ksvd_update_method=BatchedParallelKSVD{false,T}(; shuffle_indices=true, batch_size_per_thread=1),
99+
sparse_coding_method=ParallelMatchingPursuit(; max_nnz, rtol=5e-2),
100+
minibatch_size=nothing,
101+
D_init::Union{Nothing,<:AbstractMatrix{T}}=nothing,
102+
X_init::Union{Nothing,<:AbstractSparseMatrix}=nothing,
103+
# termination conditions
104+
maxiters::Int=100,
105+
maxtime::Union{Nothing,<:Real}=nothing,
106+
abstol::Union{Nothing,<:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5),
107+
reltol::Union{Nothing,<:Real}=real(oneunit(T)) * (eps(real(one(T))))^(4 // 5),
108+
nnz_per_col_target::Number=0.0,
109+
# tracing options
110+
show_trace::Bool=false,
111+
callback_fn::Union{Nothing,Function}=nothing,
112+
verbose=false,
113+
timer::TimerOutput=TimerOutput()
114+
) where {T}
114115
emb_dim, n_samples = size(Y)
115116

116117
# D is a dictionary matrix that contains atoms for columns.

0 commit comments

Comments
 (0)