Skip to content

Commit 359996b

Browse files
authored
GPU optimization of LOBPCG (#1068)
1 parent 3c34d38 commit 359996b

7 files changed

Lines changed: 97 additions & 37 deletions

File tree

src/DFTK.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ export PreconditionerNone
138138
export lobpcg_hyper
139139
export diag_full
140140
export diagonalize_all_kblocks
141+
include("eigen/linalg.jl")
141142
include("eigen/preconditioners.jl")
142143
include("eigen/diag.jl")
143144

@@ -231,7 +232,10 @@ include("postprocess/refine.jl")
231232
# Workarounds
232233
include("workarounds/dummy_inplace_fft.jl")
233234
include("workarounds/forwarddiff_rules.jl")
234-
include("workarounds/gpu_arrays.jl")
235+
236+
# Optimized generic GPU functions and GPU workarounds
237+
include("gpu/linalg.jl")
238+
include("gpu/gpu_arrays.jl")
235239

236240
# Precompilation block with a basic workflow
237241

src/eigen/linalg.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Calculate the norms of the columns of an array
2+
function columnwise_norms(X::AbstractArray{T}) where{T}
3+
vec(sqrt.(sum(abs2, X; dims=1)))
4+
end
5+
6+
# Returns a vector of dot(A[:, i], B[:, i]), for all columns of A, B
7+
@views function columnwise_dots(A::AbstractArray{T}, B::AbstractArray{T}) where {T}
8+
[real(dot(A[:, i], B[:, i])) for i = 1:size(A, 2)]
9+
end
10+
11+
# Returns a vector of real(dot(A[:, i], M, B[:, i])), for all columns of
12+
# A, B, and matrix M
13+
@views function columnwise_dots(A::AbstractArray{T}, M, B::AbstractArray{T}) where {T}
14+
[real(dot(A[:, i], M, B[:, i])) for i = 1:size(A, 2)]
15+
end

src/eigen/lobpcg_hyper_impl.jl

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333
# other eigenvectors (which is not the case in many - all ? - other
3434
# implementations)
3535

36+
# - Some functions are reimplemented in a GPU optimized way as part of
37+
# the DFTK CUDA Extension (ext/DFTKCUDAExt/lobpcg.jl).
38+
3639

3740
## TODO micro-optimization of buffer reuse
3841
## TODO write a version that doesn't assume that B is well-conditioned, and doesn't reuse B applications at all
@@ -170,7 +173,7 @@ function B_ortho!(X, BX)
170173
rdiv!(BX, U)
171174
end
172175

173-
normest(M) = maximum(abs.(diag(M))) + norm(M - Diagonal(diag(M)))
176+
normest(M) = maximum(abs, diag(M)) + norm(M - Diagonal(diag(M)))
174177
# Orthogonalizes X to tol
175178
# Returns the new X, the number of Cholesky factorizations algorithm, and the
176179
# growth factor by which small perturbations of X can have been magnified
@@ -252,24 +255,15 @@ end
252255

253256
# Randomize the columns of X if the norm is below tol
254257
function drop_small!(X::AbstractArray{T}; tol=2eps(real(T))) where {T}
255-
dropped = Int[]
256-
for i=1:size(X,2)
257-
n = norm(@views X[:,i])
258-
if n <= tol
259-
X[:,i] = randn(T, size(X,1))
260-
push!(dropped, i)
261-
end
262-
end
258+
dropped = findall(n -> n <= tol, columnwise_norms(X))
259+
@views randn!(TaskLocalRNG(), X[:, dropped])
263260
dropped
264261
end
265262

266263
# Find X that is orthogonal, and B-orthogonal to Y, up to a tolerance tol.
267264
@timing "ortho! X vs Y" function ortho!(X::AbstractArray{T}, Y, BY; tol=2eps(real(T))) where {T}
268265
# normalize to try to cheaply improve conditioning
269-
parallel_loop_over_range(1:size(X, 2)) do i
270-
n = norm(@views X[:,i])
271-
@views X[:,i] ./= n
272-
end
266+
X ./= columnwise_norms(X)'
273267

274268
niter = 1
275269
ninners = zeros(Int, 0)
@@ -322,7 +316,7 @@ end
322316
end
323317

324318
function final_retval(X, AX, BX, λ, resid_history, niter, n_matvec)
325-
λ_host = oftype(ones(eltype(λ), 1), λ) # Copy to CPU for element-wise access
319+
λ_host = to_cpu(λ) # Copy to CPU for element-wise access
326320
if !issorted(λ_host)
327321
p = sortperm(λ_host)
328322
λ_host = λ_host[p]
@@ -336,6 +330,12 @@ function final_retval(X, AX, BX, λ, resid_history, niter, n_matvec)
336330
residual_history=resid_history[:, 1:niter+1], n_matvec)
337331
end
338332

333+
# Computes λ = real((X' * AX) / (X' *BX)), for each column of X
334+
function compute_λ(X, AX, BX)
335+
λs = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:size(X, 2)]
336+
oftype(real(X[:, 1]), λs) # Offload to GPU if needed
337+
end
338+
339339
### The algorithm is Xn+1 = rayleigh_ritz(hcat(Xn, A*Xn, Xn-Xn-1))
340340
### We follow the strategy of Hetmaniuk and Lehoucq, and maintain a B-orthonormal basis Y = (X,R,P)
341341
### After each rayleigh_ritz step, the B-orthonormal X and P are deduced by an orthogonal rotation from Y
@@ -389,8 +389,7 @@ end
389389
end
390390
nlocked = 0
391391
niter = 0 # the first iteration is fake
392-
λs = @views [real((X[:, n]'*AX[:, n]) / (X[:, n]'BX[:, n])) for n=1:M]
393-
λs = oftype(real(X[:, 1]), λs) # Offload to GPU if needed
392+
λs = compute_λ(X, AX, BX)
394393
new_X = X
395394
new_AX = AX
396395
new_BX = BX
@@ -431,9 +430,8 @@ end
431430
### Compute new residuals
432431
@timing "Update residuals" begin
433432
new_R = new_AX .- new_BX .* λs'
434-
@views for i = 1:size(X, 2)
435-
resid_history[i + nlocked, niter+1] = norm(new_R[:, i])
436-
end
433+
norms = to_cpu(columnwise_norms(new_R))
434+
@views resid_history[1 + nlocked: size(new_R, 2) + nlocked, niter+1] .= norms[:]
437435
end
438436
@debug niter resid_history[:, niter+1]
439437

@@ -512,10 +510,9 @@ end
512510
end
513511

514512
# Quick sanity check
515-
for i = 1:size(X, 2)
516-
@views if abs(BX[:, i]'X[:, i] - 1) >= sqrt(eps(real(eltype(X))))
517-
error("LOBPCG is badly failing to keep the vectors normalized; this should never happen")
518-
end
513+
diffs = abs.(columnwise_dots(BX, X) .-1)
514+
if any(diffs .>= sqrt(eps(real(eltype(X)))))
515+
error("LOBPCG is badly failing to keep the vectors normalized; this should never happen")
519516
end
520517

521518
# Restrict all views to active

src/eigen/preconditioners.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ PreconditionerNone(::HamiltonianBlock) = I
2424
(Simplified version of)
2525
[Tetter-Payne-Allan preconditioning](https://doi.org/10.1103/physrevb.40.12255).
2626
"""
27-
mutable struct PreconditionerTPA{T <: Real}
27+
mutable struct PreconditionerTPA{T <: Real, Tkin <: AbstractVector{T}}
2828
basis::PlaneWaveBasis
2929
kpt::Kpoint
30-
kin::AbstractVector{T} # kinetic energy of every G
31-
mean_kin::Union{Nothing, Vector{T}} # mean kinetic energy of every band
30+
kin::Tkin # kinetic energy of every G
31+
mean_kin::Union{Nothing, Tkin} # mean kinetic energy of every band
3232
default_shift::T # if mean_kin is not set by `precondprep!`, this will be used for the shift
3333
end
3434

@@ -40,7 +40,7 @@ function PreconditionerTPA(basis::PlaneWaveBasis{T}, kpt::Kpoint; default_shift=
4040
# it's better to pass a HamiltonianBlock directly and read the computed values.
4141
kinetic_term = only(kinetic_term)
4242
kin = kinetic_energy(kinetic_term, basis.Ecut, Gplusk_vectors_cart(basis, kpt))
43-
PreconditionerTPA{T}(basis, kpt, kin, nothing, default_shift)
43+
PreconditionerTPA{T, typeof(kin)}(basis, kpt, kin, nothing, default_shift)
4444
end
4545
function PreconditionerTPA(ham::HamiltonianBlock; kwargs...)
4646
PreconditionerTPA(ham.basis, ham.kpoint; kwargs...)
@@ -50,7 +50,7 @@ end
5050
if P.mean_kin === nothing
5151
ldiv!(Y, Diagonal(P.kin .+ P.default_shift), R)
5252
else
53-
parallel_loop_over_range(1:size(Y, 2)) do n
53+
parallel_loop_over_range(1:size(Y, 2)) do n
5454
Y[:, n] .= P.mean_kin[n] ./ (P.mean_kin[n] .+ P.kin) .* R[:, n]
5555
end
5656
end
@@ -73,7 +73,7 @@ end
7373
(Base.:*)(P::PreconditionerTPA, R) = mul!(copy(R), P, R)
7474

7575
function precondprep!(P::PreconditionerTPA, X::AbstractArray)
76-
P.mean_kin = [real(dot(x, Diagonal(P.kin), x)) for x in eachcol(X)]
76+
P.mean_kin = vec(real(columnwise_dots(X, Diagonal(P.kin), X)))
7777
end
7878
precondprep!(P::PreconditionerTPA, ::Nothing) = 1 # fallback for edge cases
7979

src/gpu/linalg.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
### GPU-specific implementations of functions called during LOBPCG
2+
# The massive parallelism of the GPU can only be fully exploited when
3+
# operating on whole arrays. For performance reasons, one should avoid
4+
# explicitly looping over columns or elements. This approach is not
5+
# necessarily the most performant on CPU, as the allocation of large
6+
# temporary arrays hurts cache locality. It is also harder to read.
7+
8+
using LinearAlgebra
9+
using GPUArraysCore
10+
11+
function compute_λ(X::AbstractGPUArray{T}, AX::AbstractGPUArray{T}, BX::AbstractGPUArray{T}) where {T}
12+
num = sum(conj(X) .* AX, dims=1)
13+
den = sum(conj(X) .* BX, dims=1)
14+
vec(real.(num ./ den))
15+
end
16+
17+
function columnwise_dots(A::AbstractGPUArray{T}, B::AbstractGPUArray{T}) where {T}
18+
sum(conj(A) .* B; dims=1)
19+
end
20+
21+
function columnwise_dots(A::AbstractGPUArray{T}, M, B::AbstractGPUArray{T}) where {T}
22+
sum(conj(A) .* (M * B); dims=1)
23+
end
24+
25+
function columnwise_dots(A::AbstractGPUArray{T}, D::Diagonal, B::AbstractGPUArray{T}) where {T}
26+
sum(conj(A) .* (D.diag .* B); dims=1)
27+
end
28+
29+
function ldiv!(Y::AbstractGPUArray{T}, P::PreconditionerTPA, R::AbstractGPUArray{T}) where {T}
30+
if P.mean_kin === nothing
31+
ldiv!(Y, Diagonal(P.kin .+ P.default_shift), R)
32+
else
33+
Y .= (P.mean_kin' ./ (P.mean_kin' .+ P.kin)) .* R
34+
end
35+
Y
36+
end
37+
38+
function mul!(Y::AbstractGPUArray{T}, P::PreconditionerTPA, R::AbstractGPUArray{T}) where {T}
39+
if P.mean_kin === nothing
40+
mul!(Y, Diagonal(P.kin .+ P.default_shift), R)
41+
else
42+
Y .= ((P.mean_kin' .+ P.kin) ./ P.mean_kin') .* R
43+
end
44+
Y
45+
end

src/interpolation.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,11 @@ function interpolate_kpoint(data_in::AbstractVecOrMat,
105105
n_bands = size(data_in, 2)
106106
n_Gk_out = length(G_vectors(basis_out, kpoint_out))
107107
data_out = similar(data_in, n_Gk_out, n_bands) .= 0
108-
# TODO: use a map, or this will not be GPU compatible (scalar indexing)
109-
for iin = 1:size(data_in, 1)
110-
idx_fft = kpoint_in.mapping[iin]
111-
idx_fft in keys(kpoint_out.mapping_inv) || continue
112-
iout = kpoint_out.mapping_inv[idx_fft]
113-
data_out[iout, :] = data_in[iin, :]
114-
end
108+
109+
max_nG = max(length(G_vectors(basis_in)), length(G_vectors(basis_out)))
110+
tmp = similar(data_in, max_nG, n_bands) .= 0
111+
112+
tmp[kpoint_in.mapping, :] .= data_in
113+
data_out .= @view tmp[kpoint_out.mapping, :]
115114
ortho_qr(data_out) # Re-orthogonalize and renormalize
116115
end

0 commit comments

Comments
 (0)