Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 59 additions & 66 deletions ext/NonparametricVecchiaCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using NonparametricVecchia
using CUDA
using CUDA.CUSPARSE
using KernelAbstractions
using SparseArrays

function NonparametricVecchia.VecchiaModel(I::Vector{Int}, J::Vector{Int}, samples::CuMatrix{T};
lvar_diag::Union{Nothing,CuVector{T}}=nothing,
Expand Down Expand Up @@ -44,7 +45,7 @@ function NonparametricVecchia.VecchiaModel(I::Vector{Int}, J::Vector{Int}, sampl
nvar,
ncon = ncon,
x0 = x0,
name = "Vecchia_manual",
name = "nonparametric_vecchia_gpu",
nnzj = 2*cache.n,
nnzh = cache.nnzh_tri_lag,
y0 = y0,
Expand Down Expand Up @@ -140,54 +141,21 @@ function NonparametricVecchia.vecchia_mul!(y::CuVector{T}, B::Vector{<:CuMatrix{
return y
end

function NonparametricVecchia.vecchia_build_B!(B::Vector{<:CuMatrix{T}}, samples::CuMatrix{T}, lambda::T,
rowsL::CuVector{Int}, colptrL::CuVector{Int}, hess_obj_vals::CuVector{T},
n::Int, m::CuVector{Int}) where T <: AbstractFloat
# Launch the kernel
backend = KernelAbstractions.get_backend(samples)
r = size(samples, 1)
kernel = vecchia_build_B_kernel!(backend)
kernel(hess_obj_vals, samples, lambda, rowsL, colptrL, m, r, ndrange=n)
KernelAbstractions.synchronize(backend)
return nothing
end

function NonparametricVecchia.vecchia_generate_hess_tri_structure!(nnzh::Int, n::Int, colptr_diff::CuVector{Int},
hrows::CuVector{Int}, hcols::CuVector{Int})
# reset hrows, hcols
fill!(hrows, one(Int))
fill!(hcols, one(Int))

# launch the kernel
backend = KernelAbstractions.get_backend(hrows)
kernel = vecchia_generate_hess_tri_structure_kernel!(backend)

f(x) = (x * (x+1)) ÷ 2

# NOTE: Might be a race condition here. Solution is to store them in the first indices of each thread.
view(hrows, 2:n) .+= cumsum(f.(view(colptr_diff, 1:n-1)))
view(hcols, 2:n) .+= cumsum(view(colptr_diff, 1:n-1))

kernel(nnzh, n, colptr_diff, view(hrows, 1:n), view(hcols, 1:n), hrows, hcols, ndrange = n)
KernelAbstractions.synchronize(backend)
return nothing
end

@kernel function vecchia_mul_kernel!(y, @Const(hess_obj_vals), @Const(x), @Const(m), @Const(offsets))
index = @index(Global)
offset = offsets[index]
mj = m[index]
pos2 = 0
pos = 0
for i = 1:index-1
pos2 += m[i] * (m[i] + 1) ÷ 2
pos += m[i] * (m[i] + 1) ÷ 2
end

# Perform the matrix-vector multiplication for the current symmetric block
for j in 1:mj
idx1 = (j - 1) * (mj + 1) - j * (j-1) ÷ 2
for i in j:mj
idx2 = idx1 + (i - j + 1)
val = hess_obj_vals[pos2+idx2]
val = hess_obj_vals[pos+idx2]

# Diagonal element contributes only once
if i == j
Expand All @@ -201,54 +169,79 @@ end
nothing
end

function NonparametricVecchia.vecchia_build_B!(B::Vector{<:CuMatrix{T}}, samples::CuMatrix{T}, lambda::T,
rowsL::CuVector{Int}, colptrL::CuVector{Int}, hess_obj_vals::CuVector{T},
n::Int, m::CuVector{Int}) where T <: AbstractFloat
# Launch the kernel
backend = KernelAbstractions.get_backend(samples)
r = size(samples, 1)
kernel = vecchia_build_B_kernel!(backend)
kernel(hess_obj_vals, samples, lambda, rowsL, colptrL, m, r, ndrange=n)
KernelAbstractions.synchronize(backend)
return nothing
end

@kernel function vecchia_build_B_kernel!(hess_obj_vals, @Const(samples), @Const(lambda), @Const(rowsL), @Const(colptrL), @Const(m), @Const(r))
index = @index(Global)
pos = colptrL[index]
col = colptrL[index]
mj = m[index]
pos2 = 0

pos = 0
for i = 1:index-1
pos2 += m[i] * (m[i] + 1) ÷ 2
pos += m[i] * (m[i] + 1) ÷ 2
end

k = 0
for s in 1:mj
for t in s:mj
acc = 0.0
for i = 1:r
acc += samples[i, rowsL[pos+t-1]] * samples[i, rowsL[pos+s-1]]
end
k = k + 1
hess_obj_vals[pos2+k] = acc
if (lambda != 0) && (s == t) && (s != 1)
hess_obj_vals[pos2+k] += lambda
if s ≤ t
pos = pos + 1
acc = 0.0
for i = 1:r
acc += samples[i, rowsL[col+t-1]] * samples[i, rowsL[col+s-1]]
end
if (lambda != 0) && (s == t)
acc += lambda
end
hess_obj_vals[pos] = acc
end
end
end
nothing
end

function NonparametricVecchia.vecchia_generate_hess_tri_structure!(n::Int, m::CuVector{Int}, nnzL::Int, nnzh_tri_obj::Int,
offsets::CuVector{Int}, hrows::CuVector{Int}, hcols::CuVector{Int})
# launch the kernel
backend = KernelAbstractions.get_backend(hrows)
kernel = vecchia_generate_hess_tri_structure_kernel!(backend)
kernel(n, m, nnzL, nnzh_tri_obj, offsets, hrows, hcols, ndrange=n)
KernelAbstractions.synchronize(backend)
return nothing
end

@kernel function vecchia_generate_hess_tri_structure_kernel!(@Const(n), @Const(m), @Const(nnzL), @Const(nnzh_tri_obj), @Const(offsets), hrows, hcols)
index = @index(Global)
mj = m[index]
offset = offsets[index]

@kernel function vecchia_generate_hess_tri_structure_kernel!(
@Const(nnzh), @Const(n), @Const(colptr_diff), @Const(carry_offsets), @Const(idx_offsets),
hrows, hcols
)
thread_idx = @index(Global) # in 1:n
m = colptr_diff[thread_idx]
carry = carry_offsets[thread_idx]
idx = idx_offsets[thread_idx]
pos = 0
for i = 1:index-1
pos += m[i] * (m[i] + 1) ÷ 2
end

for j in 1:m
for k in carry:m - j + carry
hrows[k] = (idx + j - 1) + (k - carry)
hcols[k] = idx + j - 1
for s in 1:mj
for t in 1:mj
if s ≤ t
pos = pos + 1
hrows[pos] = offset + t
hcols[pos] = offset + s
end
end
carry += m - j + 1
end

# fill one index of the tail for each thread
@inbounds hrows[nnzh-n + thread_idx] = hrows[nnzh-n] + thread_idx
@inbounds hcols[nnzh-n + thread_idx] = hrows[nnzh-n] + thread_idx
hrows[nnzh_tri_obj + index] = nnzL + index
hcols[nnzh_tri_obj + index] = nnzL + index
nothing
end

end
end # end module
2 changes: 1 addition & 1 deletion src/VecchiaModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ function VecchiaModel(I::Vector{Int}, J::Vector{Int}, samples::Matrix{T};
nvar,
ncon = ncon,
x0 = x0,
name = "Vecchia_manual",
name = "nonparametric_vecchia_cpu",
nnzj = 2*cache.n,
nnzh = cache.nnzh_tri_lag,
y0 = y0,
Expand Down
58 changes: 27 additions & 31 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,8 @@ function vecchia_build_B!(B::Vector{Matrix{T}}, samples::Matrix{T}, lambda::T, r
vs = view(samples, :, rowsL[colptrL[j] + s - 1])
B[j][t, s] = dot(vt, vs)
# Ridge regularization
if (lambda != 0) && (s == t) && (s != 1)
# s == 1 means that we treat the variable related to the diagonal coefficient of column j
# of the sparse Cholesky factor.
#
# Only update diagonal coefficient of the Hessian that are
# related to the variables that represent the off-diagonal terms
# of the sparse Cholesky factor.
if (lambda != 0) && (s == t)
# Only update diagonal coefficient of the Hessian
B[j][t, s] += lambda
end
# Lower triangular part of the block Bⱼ
Expand All @@ -42,30 +37,31 @@ function vecchia_build_B!(B::Vector{Matrix{T}}, samples::Matrix{T}, lambda::T, r
return nothing
end

for INT in (:Int32, :Int64)
@eval begin
function vecchia_generate_hess_tri_structure!(nnzh::Int, n::Int, colptr_diff::Vector{Int},
hrows::Vector{$INT}, hcols::Vector{$INT})
carry = 1
idx = 1
for i in 1:n
m = colptr_diff[i]
for j in 1:m
view(hrows, (0:(m-j)).+carry) .= (j:m).+(idx-1)
fill!(view(hcols, carry:carry+m-j), idx + j - 1)
carry += m - j + 1
end
idx += m
function vecchia_generate_hess_tri_structure!(n::Int, m::Vector{Int}, nnzL::Int, nnzh_tri_obj::Int, offsets::Vector{Int}, hrows::Vector{T}, hcols::Vector{T}) where T <: Integer
pos = 0
offset = 0
for j in 1:n
for s in 1:m[j]
for t in 1:m[j]
if s ≤ t
pos = pos + 1
hrows[pos] = offset + t
hcols[pos] = offset + s
end
end
end
offset += m[j]
end

#Then need the diagonal tail
idx_to = idx + nnzh - carry
view(hrows, carry:nnzh) .= idx:idx_to
view(hcols, carry:nnzh) .= idx:idx_to
@assert pos == nnzh_tri_obj
@assert offset == nnzL

return hrows, hcols
end
for k = 1:n
hrows[nnzh_tri_obj+k] = nnzL + k
hcols[nnzh_tri_obj+k] = nnzL + k
end

return hrows, hcols
end

# The objective of the optimization problem.
Expand Down Expand Up @@ -105,7 +101,7 @@ function NLPModels.hess_structure!(nlp::VecchiaModel, hrows::AbstractVector, hco
@lencheck nlp.meta.nnzh hcols

# stored as lower triangular!
vecchia_generate_hess_tri_structure!(nlp.meta.nnzh, nlp.cache.n, nlp.cache.m, hrows, hcols)
vecchia_generate_hess_tri_structure!(nlp.cache.n, nlp.cache.m, nlp.cache.nnzL, nlp.cache.nnzh_tri_obj, nlp.cache.offsets, hrows, hcols)
return hrows, hcols
end

Expand Down Expand Up @@ -176,9 +172,9 @@ function NLPModels.jac_structure!(nlp::VecchiaModel, jrows::AbstractVector, jcol
@lencheck 2*nlp.cache.n jcols

copyto!(view(jcols, 1:nlp.cache.n), nlp.cache.diagL)
view(jcols, (1:nlp.cache.n).+nlp.cache.n) .= (1:nlp.cache.n).+nlp.cache.nnzL
view(jcols, nlp.cache.n+1:2*nlp.cache.n) .= (nlp.cache.nnzL+1:nlp.meta.nvar)
view(jrows, 1:nlp.cache.n) .= 1:nlp.cache.n
view(jrows, (1:nlp.cache.n).+nlp.cache.n) .= 1:nlp.cache.n
view(jrows, nlp.cache.n+1:2*nlp.cache.n) .= 1:nlp.cache.n
return jrows, jcols
end

Expand Down Expand Up @@ -207,7 +203,7 @@ function NLPModels.jprod!(nlp::VecchiaModel, x::AbstractVector, v::AbstractVecto
Jv, 1,
-view(v, nlp.cache.diagL)
.+ exp.(view(x, nlp.cache.nnzL+1:nlp.meta.nvar))
.* view(v, (1:nlp.cache.n).+nlp.cache.nnzL),
.* view(v, (nlp.cache.nnzL+1:nlp.meta.nvar)),
1, nlp.cache.n
)
return Jv
Expand Down
2 changes: 1 addition & 1 deletion test/Jump_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function obj_vecchia(w::AbstractVector, samples, lambda, cache::VecchiaCache)
for k in 1:cache.M
)

t3 = sum(w[i]^2 for i in 1:cache.nnzL if !(i in cache.diagL))
t3 = sum(w[i]^2 for i in 1:cache.nnzL)

return t1 + 0.5 * t2 + 0.5 * lambda * t3
end
Expand Down
2 changes: 1 addition & 1 deletion test/test_memory_allocation_gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ end
@test mems[:obj] == 16.0 # these allocations are related to allocations in "sum" and "dot"
@test mems[:grad!] == 0.0
@test_broken mems[:cons!] == 0.0
@test_broken mems[:hess_structure!] == 0.0
@test mems[:hess_structure!] == 0.0
@test mems[:jac_structure!] == 0.0
@test mems[:jac_coord!] == 0.0
@test mems[:hess_coord!] == 0.0
Expand Down
Loading