@@ -269,48 +269,58 @@ function apply_symop(symop::SymOp, basis, ρin; kwargs...)
269269end
270270
# Accumulates the symmetrized versions of the density ρin into ρaccu (in Fourier space).
# No normalization is performed. This function is optimized for CPU and GPU.
#
# For every G vector of the basis, the contributions of all `symmetries` are summed and
# added on top of the value already stored in ρaccu, i.e. ρaccu is *not* reset here.
# Callers that want the plain symmetrized sum must pass a zero-initialised ρaccu.
# Returns ρaccu.
function accumulate_over_symmetries!(ρaccu, ρin, basis::PlaneWaveBasis{T}, symmetries) where {T}
    # For each G vector and symmetry S:
    # Transform ρin -> to the partial density at S * k.
    #
    # Since the eigenfunctions of the Hamiltonian at k and Sk satisfy
    #      u_{Sk}(x) = u_{k}(S^{-1} (x - τ))
    # with Fourier transform
    #      û_{Sk}(G) = e^{-i G \cdot τ} û_k(S^{-1} G)
    # equivalently
    #      ρ̂_{Sk}(G) = e^{-i G \cdot τ} ρ̂_k(S^{-1} G)
    Gs = reshape(G_vectors(basis), size(ρaccu))
    fft_size = basis.fft_size

    # Need to transfer symmetry data to the GPU as isbits data
    symm_invS = to_device(basis.architecture,
                          [Mat3{Int}(inv(symop.S)) for symop in symmetries])
    symm_τ = to_device(basis.architecture, [symop.τ for symop in symmetries])
    n_symm = length(symmetries)

    # Looping over symmetries inside of map! on G vectors allows for a single GPU kernel
    # launch. ρaccu is passed both as destination and as input, so that the previous
    # content of each entry is the starting value of the sum (true accumulation).
    map!(ρaccu, ρaccu, Gs) do ρ_prev, G
        acc = ρ_prev
        # Explicit loop over indices because AMDGPU does not support zip() in map!
        for i_symm in 1:n_symm
            invS = symm_invS[i_symm]
            τ = symm_τ[i_symm]
            idx = index_G_vectors(fft_size, invS * G)
            # index_G_vectors yields `nothing` when invS * G has no index on the grid;
            # such terms contribute zero.
            acc += isnothing(idx) ? zero(complex(T)) : cis2pi(-T(dot(G, τ))) * ρin[idx]
        end
        acc
    end
    ρaccu
end
304305
# Low-pass filters ρ (in Fourier) so that symmetry operations acting on it stay in the grid.
# This function is optimized for CPU and GPU.
#
# An entry of ρ is set to zero as soon as, for at least one symmetry operation, the
# rotated vector S * G has no index on the FFT grid (index_G_vectors returns `nothing`).
# All other entries are left untouched. ρ is modified in place and returned.
function lowpass_for_symmetry!(ρ::AbstractArray, basis; symmetries=basis.symmetries)
    # Nothing to filter if every symmetry operation is the identity
    all(isone, symmetries) && return ρ

    Gs = reshape(G_vectors(basis), size(ρ))
    fft_size = basis.fft_size

    symm_S = to_device(basis.architecture, [symop.S for symop in symmetries])

    # Loop structure optimized for both CPU and GPU
    map!(ρ, ρ, Gs) do ρ_i, G
        acc = ρ_i
        for S in symm_S
            idx = index_G_vectors(fft_size, S * G)
            # Select an explicit zero instead of multiplying by 0: the product would
            # keep NaN/Inf entries alive and could yield signed zeros.
            acc = isnothing(idx) ? zero(ρ_i) : acc
        end
        acc
    end
    ρ
end
@@ -320,15 +330,15 @@ Symmetrize a density by applying all the basis (by default) symmetries and formi
320330"""
@views @timing function symmetrize_ρ(basis, ρ::AbstractArray{T};
                                     symmetries=basis.symmetries,
                                     do_lowpass=true) where {T}
    # Symmetrization is carried out in Fourier space, one spin component at a time.
    ρ_fourier = fft(basis, ρ)
    ρ_symmetrized = zero(ρ_fourier)  # zero-initialised accumulator
    for σ = 1:size(ρ, 4)
        ρσ_out = ρ_symmetrized[:, :, :, σ]  # a view, thanks to the enclosing @views
        accumulate_over_symmetries!(ρσ_out, ρ_fourier[:, :, :, σ], basis, symmetries)
        if do_lowpass
            lowpass_for_symmetry!(ρσ_out, basis; symmetries)
        end
    end

    # Average over the symmetry operations, then go back to real space with the
    # inverse transform matching the element type of the input density.
    ρ_mean = ρ_symmetrized ./ length(symmetries)
    if T <: Real
        irfft(basis, ρ_mean)
    else
        ifft(basis, ρ_mean)
    end
end
333343
334344"""
0 commit comments