3333# other eigenvectors (which is not the case in many - all ? - other
3434# implementations)
3535
36+ # - Some functions are reimplemented in a GPU optimized way as part of
37+ # the DFTK CUDA Extension (ext/DFTKCUDAExt/lobpcg.jl).
38+
3639
3740# # TODO micro-optimization of buffer reuse
3841# # TODO write a version that doesn't assume that B is well-conditioned, and doesn't reuse B applications at all
@@ -170,7 +173,7 @@ function B_ortho!(X, BX)
170173 rdiv! (BX, U)
171174end
172175
173- normest (M) = maximum (abs .( diag (M) )) + norm (M - Diagonal (diag (M)))
176+ normest (M) = maximum (abs, diag (M)) + norm (M - Diagonal (diag (M)))
174177# Orthogonalizes X to tol
175178# Returns the new X, the number of Cholesky factorizations algorithm, and the
176179# growth factor by which small perturbations of X can have been magnified
@@ -252,24 +255,15 @@ end
252255
253256# Randomize the columns of X if the norm is below tol
254257function drop_small! (X:: AbstractArray{T} ; tol= 2 eps (real (T))) where {T}
255- dropped = Int[]
256- for i= 1 : size (X,2 )
257- n = norm (@views X[:,i])
258- if n <= tol
259- X[:,i] = randn (T, size (X,1 ))
260- push! (dropped, i)
261- end
262- end
258+ dropped = findall (n -> n <= tol, columnwise_norms (X))
259+ @views randn! (TaskLocalRNG (), X[:, dropped])
263260 dropped
264261end
265262
266263# Find X that is orthogonal, and B-orthogonal to Y, up to a tolerance tol.
267264@timing " ortho! X vs Y" function ortho! (X:: AbstractArray{T} , Y, BY; tol= 2 eps (real (T))) where {T}
268265 # normalize to try to cheaply improve conditioning
269- parallel_loop_over_range (1 : size (X, 2 )) do i
270- n = norm (@views X[:,i])
271- @views X[:,i] ./= n
272- end
266+ X ./= columnwise_norms (X)'
273267
274268 niter = 1
275269 ninners = zeros (Int, 0 )
322316end
323317
324318function final_retval (X, AX, BX, λ, resid_history, niter, n_matvec)
325- λ_host = oftype ( ones ( eltype (λ), 1 ), λ) # Copy to CPU for element-wise access
319+ λ_host = to_cpu ( λ) # Copy to CPU for element-wise access
326320 if ! issorted (λ_host)
327321 p = sortperm (λ_host)
328322 λ_host = λ_host[p]
@@ -336,6 +330,12 @@ function final_retval(X, AX, BX, λ, resid_history, niter, n_matvec)
336330 residual_history= resid_history[:, 1 : niter+ 1 ], n_matvec)
337331end
338332
333+ # Computes λ = real((X' * AX) / (X' *BX)), for each column of X
334+ function compute_λ (X, AX, BX)
335+ λs = @views [real ((X[:, n]' * AX[:, n]) / (X[:, n]' BX[:, n])) for n= 1 : size (X, 2 )]
336+ oftype (real (X[:, 1 ]), λs) # Offload to GPU if needed
337+ end
338+
339339# ## The algorithm is Xn+1 = rayleigh_ritz(hcat(Xn, A*Xn, Xn-Xn-1))
340340# ## We follow the strategy of Hetmaniuk and Lehoucq, and maintain a B-orthonormal basis Y = (X,R,P)
341341# ## After each rayleigh_ritz step, the B-orthonormal X and P are deduced by an orthogonal rotation from Y
389389 end
390390 nlocked = 0
391391 niter = 0 # the first iteration is fake
392- λs = @views [real ((X[:, n]' * AX[:, n]) / (X[:, n]' BX[:, n])) for n= 1 : M]
393- λs = oftype (real (X[:, 1 ]), λs) # Offload to GPU if needed
392+ λs = compute_λ (X, AX, BX)
394393 new_X = X
395394 new_AX = AX
396395 new_BX = BX
431430 # ## Compute new residuals
432431 @timing " Update residuals" begin
433432 new_R = new_AX .- new_BX .* λs'
434- @views for i = 1 : size (X, 2 )
435- resid_history[i + nlocked, niter+ 1 ] = norm (new_R[:, i])
436- end
433+ norms = to_cpu (columnwise_norms (new_R))
434+ @views resid_history[1 + nlocked: size (new_R, 2 ) + nlocked, niter+ 1 ] .= norms[:]
437435 end
438436 @debug niter resid_history[:, niter+ 1 ]
439437
512510 end
513511
514512 # Quick sanity check
515- for i = 1 : size (X, 2 )
516- @views if abs (BX[:, i]' X[:, i] - 1 ) >= sqrt (eps (real (eltype (X))))
517- error (" LOBPCG is badly failing to keep the vectors normalized; this should never happen" )
518- end
513+ diffs = abs .(columnwise_dots (BX, X) .- 1 )
514+ if any (diffs .>= sqrt (eps (real (eltype (X)))))
515+ error (" LOBPCG is badly failing to keep the vectors normalized; this should never happen" )
519516 end
520517
521518 # Restrict all views to active
0 commit comments