-
Notifications
You must be signed in to change notification settings - Fork 2
Description
We wrote the code below and ran it under Julia 1.12.1. We hope it helps further studies, since the default code snippets do not work out of the box. It takes 30 minutes to process 10k samples on our PC-level GPU.
# --- Environment setup -----------------------------------------------------
# Activate and instantiate the project in the current directory (presumably a
# checkout of KSVD.jl, since `ext/KSVDCudaExt.jl` is included below — confirm).
import Pkg
Pkg.activate(".")
Pkg.instantiate()
# NPZ provides `npzwrite`, used at the end of the script to save Y, X and D.
# NOTE(review): `Pkg.add` runs on every execution and edits Project.toml;
# consider adding NPZ to the project once instead.
using Pkg; Pkg.add("NPZ")
using NPZ
using Dates
# NOTE(review): `timestamp` is not referenced anywhere else in this script.
timestamp = Dates.format(now(), "yyyy_mm_dd_HH_MM_SS")
using KSVD
using Random, StatsBase, SparseArrays, LinearAlgebra
# OhMyThreads supplies the scheduler type used by the dictionary-update method;
# CUDA is presumably needed by the extension module included next.
using KSVD.OhMyThreads, CUDA
# Load the CUDA extension manually from its source file rather than relying on
# automatic extension loading (the post says the default snippets did not work).
include("ext/KSVDCudaExt.jl")
using .KSVDCudaExt
# --- Synthetic problem ------------------------------------------------------
# Build a random ground-truth dictionary D (emb_size × m) and codes X
# (m × nsamples), where every column of X has `nnzpercol` stored entries at
# distinct random rows, then form noisy observations Y = D*X + 0.05*noise.
nsamples = 10_000   # number of signals (columns of X and Y)
nnzpercol = 20      # stored entries per ground-truth code column
m = 4096            # number of dictionary atoms (columns of D)
k = 20              # `max_nnz` handed to the sparse-coding method
emb_size = 2381     # signal dimension (rows of D and Y)
T = Float32         # element type used for all generated data
D = rand(T, emb_size, m)
X = stack(
    (SparseVector(m,
                  sample(1:m, nnzpercol; replace=false),
                  rand(T, nnzpercol))
     for _ in 1:nsamples);
    dims=2)
# Compute the clean product once, then add Gaussian noise of the same size in
# place. An explicit `*` is required: `T(0.05)randn(...)` is not valid
# juxtaposition because only numeric literals or parenthesized expressions can
# act as coefficients.
Y = D * X
Y .+= T(0.05) .* randn(T, size(Y))
# Output files written after the fit. They are bare names, so they resolve
# against the current working directory.
Y_filename, X_filename, D_filename = "Y.npy", "X.npy", "D.npy"

# Echo the run configuration before the long K-SVD fit starts.
for line in ("use synthetic features", Y_filename, X_filename, D_filename)
    println(line)
end
println("Y size: ", size(Y))
# --- K-SVD configuration and fit -----------------------------------------------
# Sparse-coding step: KSVD.jl's CUDAAcceleratedMatchingPursuit, limited to at
# most `k` nonzeros per code column via `max_nnz`.
sparse_coding_method = KSVD.CUDAAcceleratedMatchingPursuit(max_nnz=k)

# Dictionary-update step: BatchedParallelKSVD, parameterized by
# (false, Float32, scheduler, SVD solver). The meaning of the leading `false`
# is defined by KSVD.jl (TODO confirm).
ksvd_update_method = let
    scheduler = KSVD.OhMyThreads.DynamicScheduler
    svd_solver = KSVD.CUDAAcceleratedArnoldiSVDSolver{Float32}
    KSVD.BatchedParallelKSVD{false, Float32, scheduler, svd_solver}(;
        shuffle_indices=true, batch_size_per_thread=1)
end

println("ksvd is starting")
# Fit an m-atom dictionary to Y. The learned dictionary and codes replace the
# synthetic ground-truth `D` and `X` globals.
D, X = let fit = ksvd(Y, m;
                      ksvd_update_method=ksvd_update_method,
                      sparse_coding_method=sparse_coding_method,
                      maxiters=100,
                      abstol=1e-6,
                      reltol=1e-6,
                      show_trace=true)
    (fit.D, fit.X)
end
# --- Evaluation and export ---------------------------------------------------
# Mean, over samples, of the Euclidean norm of each column's reconstruction
# residual Y - D*X.
err = mean(norm.(eachcol(Y - D * X)))
println("Mean reconstruction error = $err")
# .npy files hold dense arrays. The learned codes X (and possibly D) may come
# back from `ksvd` as sparse or non-`Array` types (TODO confirm against
# KSVD.jl), so materialize plain dense arrays before writing.
npzwrite(Y_filename, Y)
npzwrite(X_filename, Array(X))
npzwrite(D_filename, Array(D))
println("Done!")