diff --git a/Project.toml b/Project.toml index 443c7e241f..a22441d3f1 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c" Optim = "429524aa-4258-5aef-a3af-852621145aeb" PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884" PkgVersion = "eebad327-c553-4316-9ea0-9fa01ccd7688" @@ -110,7 +111,8 @@ LinearMaps = "3" Logging = "1" MPI = "0.20.22" Markdown = "1" -Optim = "1" +NLSolversBase = "8" +Optim = "2" PeriodicTable = "1" PkgVersion = "0.3" Plots = "1" diff --git a/src/scf/direct_minimization.jl b/src/scf/direct_minimization.jl index a48e622c63..b1eafcc7b1 100644 --- a/src/scf/direct_minimization.jl +++ b/src/scf/direct_minimization.jl @@ -1,6 +1,7 @@ # Direct minimization of the energy using Optim +using NLSolversBase: only_fg! using LineSearches # This is all a bit annoying because our ψ is represented as ψ[k][G,n], and Optim accepts @@ -111,31 +112,26 @@ function direct_minimization(basis::PlaneWaveBasis{T}; history_Etot = T[] history_Δρ = T[] - # Will be later overwritten by the Optim-internal state, which we need in the - # callback to access certain quantities for convergence control. - optim_state = nothing + function optim_callback(optim_state) + optim_state.pseudo_iteration < 1 && return false + converged && return true - function compute_ρout(ψ, optim_state) # This is the current preconditioned, but unscaled gradient, which implies that # the next step would be ρout - ρ. We thus record convergence, but let Optim do # one more step. δψ = unsafe_unpack(optim_state.s) # TODO This looks weird ... should there not be a retraction ? ψ_next = [ortho_qr(ψ[ik] - δψ[ik]) for ik in 1:Nk] - compute_density(basis, ψ_next, occupation) - end + ρout = compute_density(basis, ψ_next, occupation) - function optim_callback(ts) - ts.iteration < 1 && return false - converged && return true - ρout = compute_ρout(ψ, optim_state) Δρ = ρout - ρ push!(history_Δρ, norm(Δρ) * sqrt(basis.dvol)) push!(history_Etot, energies.total) info = (; ham, basis, energies, occupation, ρout, ρin=ρ, ψ, runtime_ns=time_ns() - start_ns, history_Δρ, history_Etot, - stage=:iterate, algorithm="DM", n_iter=ts.iteration, optim_state) + stage=:iterate, algorithm="DM", + n_iter=optim_state.pseudo_iteration, optim_state) converged = is_converged(info) info = callback(info) @@ -167,13 +163,13 @@ function direct_minimization(basis::PlaneWaveBasis{T}; optim_options = Optim.Options(; allow_f_increases=true, callback=optim_callback, # Disable convergence control by Optim - x_tol=-1, f_tol=-1, g_tol=-1, + x_abstol=NaN, f_abstol=NaN, g_abstol=NaN, iterations=maxiter, kwargs...) optim_solver = optim_method(; P, precondprep=precondprep!, manifold, linesearch, alphaguess) ψ_packed = pack(ψ) - objective = OnceDifferentiable(Optim.only_fg!(fg!), ψ_packed, zero(T); inplace=true) - optim_state = Optim.initial_state(optim_solver, optim_options, objective, ψ_packed) - res = Optim.optimize(objective, ψ_packed, optim_solver, optim_options, optim_state) + objective = OnceDifferentiable(only_fg!(fg!), ψ_packed, zero(T); inplace=true) + res = Optim.optimize(objective, ψ_packed, optim_solver, optim_options) + ψ = unpack(Optim.minimizer(res)) # Final Rayleigh-Ritz (not strictly necessary, but sometimes useful)