Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/scf/scf_callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,45 @@ function (cb::ScfDefaultCallback)(info)
info
end

"""
Callback that saves more convergence history
- `use_default`: Whether or not to also use the default callback to print a convergence
Comment thread
mfherbst marked this conversation as resolved.
Outdated
table.
- `history_extra_functions`: Dictionary of functions `f(; info...)` whose results are
computed during the iterations and then saved in `scfres.history_extra`.
"""
struct ScfSaveHistory
default::Union{ScfDefaultCallback, Nothing}
history_extra_functions::Dict
history_extra::Dict
end
function ScfSaveHistory(; use_default=true, history_extra_functions=Dict(), default_kwargs...)
if isempty(history_extra_functions)
return ScfDefaultCallback(; default_kwargs...)
else
default = use_default ? ScfDefaultCallback(; default_kwargs...) : nothing
history_extra = Dict(key => [] for (key, fun) in history_extra_functions)
return ScfSaveHistory(default, history_extra_functions, history_extra)
end
end

function (cb::ScfSaveHistory)(info)
# Calling default (ScfDefaultCallback) to print a convergence table.
if !isnothing(cb.default)
cb.default(info)
end
if info.stage == :iterate
for (key, fun) in cb.history_extra_functions
extra = fun(; info...)
cb.history_extra[key] = push!(cb.history_extra[key], extra)
end
end
if info.stage == :finalize
info = merge(info, (; cb.history_extra))
Comment thread
mfherbst marked this conversation as resolved.
Outdated
end
info
end

#
# Convergence checks
#
Expand Down
17 changes: 10 additions & 7 deletions src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,19 +172,23 @@ Overview of parameters:

# Update info with results gathered so far
info_next = (; ham, basis, converged, stage=:iterate, algorithm="SCF",
ρin, α=damping, n_iter, nbandsalg.occupation_threshold,
ρin, ρout, α=damping, n_iter, nbandsalg.occupation_threshold,
runtime_ns=time_ns() - start_ns, nextstate...,
diagonalization=[nextstate.diagonalization])

# Compute the energy of the new state
if compute_consistent_energies
(; energies) = energy(basis, ψ, occupation; ρ=ρout, eigenvalues, εF)
end

# Update info with history
history_Etot = vcat(info.history_Etot, energies.total)
history_Δρ = vcat(info.history_Δρ, norm(Δρ) * sqrt(basis.dvol))
history_εF = vcat(info.history_εF, εF)
n_matvec = info.n_matvec + nextstate.n_matvec
info_next = merge(info_next, (; energies, history_Etot, history_Δρ, n_matvec))

info_next = merge(info_next, (; energies, history_Etot, history_Δρ, history_εF,
n_matvec))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the indention is weird, I'd do

Suggested change
info_next = merge(info_next, (; energies, history_Etot, history_Δρ, history_εF,
n_matvec))
info_next = merge(info_next, (; energies, history_Etot, history_Δρ, history_εF,
n_matvec))


# Apply mixing and pass it the full info as kwargs
ρnext = ρin .+ T(damping) .* mix_density(mixing, basis, Δρ; info_next...)

Expand All @@ -202,8 +206,7 @@ Overview of parameters:

info_init = (; ρin=ρ, ψ=ψ, occupation=nothing, eigenvalues=nothing, εF=nothing,
n_iter=0, n_matvec=0, timedout=false, converged=false,
history_Etot=T[], history_Δρ=T[])

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you're killing an empty line, which I think would help the structure of the code.

history_Etot=T[], history_Δρ=T[], history_εF=T[])
# Convergence is flagged by is_converged inside the fixpoint_map.
_, info = solver(fixpoint_map, ρ, info_init; maxiter)

Expand All @@ -217,8 +220,8 @@ Overview of parameters:
scfres = (; ham, basis, energies, converged, nbandsalg.occupation_threshold,
ρ=ρout, α=damping, eigenvalues, occupation, εF, info.n_bands_converge,
info.n_iter, info.n_matvec, ψ, info.diagonalization, stage=:finalize,
info.history_Δρ, info.history_Etot, info.timedout,
runtime_ns=time_ns() - start_ns, algorithm="SCF")
info.history_Δρ, info.history_Etot, info.history_εF,
info.timedout, runtime_ns=time_ns() - start_ns, algorithm="SCF")
callback(scfres)
scfres
end