Commit 4d00737

Merge pull request #833 from SciML/auglaggeneric
Make auglag generic and reusable with all solvers
2 parents: f6a7301 + 9086bfd

File tree: 9 files changed (+280 -62 lines changed)
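The headline change is the new `src/auglag.jl` (shown in full below), which pulls the augmented Lagrangian outer loop out of the `LBFGS` wrapper into a standalone `AugLag` method whose smooth subproblems can be delegated to any `inner` solver. As a rough orientation before the per-file diffs, here is a minimal usage sketch under stated assumptions: the toy problem, AD backend, and `maxiters` value are invented for illustration, and `AugLag`/`LBFGS` are referenced by qualified names since the diff does not export them.

```julia
using Optimization, ForwardDiff

# Toy problem: minimize (x₁ - p₁)² + x₂² subject to x₁ + x₂ = 1.
# Setting lcons == ucons marks the constraint as an equality.
objective(x, p) = (x[1] - p[1])^2 + x[2]^2
cons(res, x, p) = (res .= [x[1] + x[2]])

optf = OptimizationFunction(objective, Optimization.AutoForwardDiff(); cons = cons)
prob = OptimizationProblem(optf, zeros(2), [2.0]; lcons = [1.0], ucons = [1.0])

# AugLag's outer loop handles the constraints and multiplier updates; the smooth
# subproblems go to whichever `inner` solver is supplied (here the built-in L-BFGS).
sol = solve(prob, Optimization.AugLag(; inner = Optimization.LBFGS()); maxiters = 500)
```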

docs/src/index.md  (+11 -11)

````diff
@@ -88,7 +88,7 @@ to add the specific wrapper packages.
       - Second order
       - Zeroth order
       - Box Constraints
-      - Constrained 🟡
+      - Constrained
   - <strong>Global Methods</strong>
       - Zeroth order
       - Unconstrained
@@ -126,21 +126,21 @@ to add the specific wrapper packages.
       - Zeroth order
       - Unconstrained
       - Box Constraints
-      - Constrained 🟡
+      - Constrained
 </details>
 <details>
 <summary><strong>NLopt</strong></summary>
   - <strong>Local Methods</strong>
       - First order
       - Zeroth order
-      - Second order 🟡
+      - Second order
       - Box Constraints
-      - Local Constrained 🟡
+      - Local Constrained
   - <strong>Global Methods</strong>
       - Zeroth order
       - First order
       - Unconstrained
-      - Constrained 🟡
+      - Constrained
 </details>
 <details>
 <summary><strong>Optim</strong></summary>
@@ -158,21 +158,21 @@ to add the specific wrapper packages.
 <details>
 <summary><strong>PRIMA</strong></summary>
   - <strong>Local Methods</strong>
-      - Derivative-Free:
+      - Derivative-Free:
   - **Constraints**
-      - Box Constraints:
-      - Local Constrained:
+      - Box Constraints:
+      - Local Constrained:
 </details>
 <details>
 <summary><strong>QuadDIRECT</strong></summary>
   - **Constraints**
-      - Box Constraints:
+      - Box Constraints:
   - <strong>Global Methods</strong>
-      - Unconstrained:
+      - Unconstrained:
 </details>
 ```
 
-🟡 = supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome
+= supported in downstream library but not yet implemented in `Optimization.jl`; PR to add this functionality are welcome
 
 ## Citation
````

docs/src/optimization_packages/optimization.md  (+20 -20)

The `-` and `+` sides of this hunk carry identical text (the changes appear to be whitespace-only), so the content is shown once:

@@ -4,28 +4,28 @@ There are some solvers that are available in the Optimization.jl package directly

## Methods

  - `LBFGS`: The popular quasi-Newton method that leverages a limited-memory BFGS approximation of the inverse of the Hessian, through a wrapper over the [L-BFGS-B](https://users.iems.northwestern.edu/%7Enocedal/lbfgsb.html) Fortran routine accessed from the [LBFGSB.jl](https://github.com/Gnimuc/LBFGSB.jl/) package. It directly supports box constraints.

    It can also handle arbitrary nonlinear constraints through an augmented Lagrangian method with bound constraints, described in section 17.4 of Numerical Optimization by Nocedal and Wright, thus serving as a general-purpose nonlinear optimization solver available directly in Optimization.jl.

  - `Sophia`: Based on the recent paper https://arxiv.org/abs/2305.14342. It incorporates second-order information in the form of the diagonal of the Hessian matrix, avoiding the need to compute the full Hessian. It has been shown to converge faster than first-order methods such as Adam and SGD.

      + `solve(problem, Sophia(; η, βs, ϵ, λ, k, ρ))`

      + `η` is the learning rate
      + `βs` are the decay rates of the momentum estimates
      + `ϵ` is the epsilon value
      + `λ` is the weight decay parameter
      + `k` is the number of iterations between recomputations of the Hessian diagonal
      + `ρ` is the momentum
      + Defaults:

          * `η = 0.001`
          * `βs = (0.9, 0.999)`
          * `ϵ = 1e-8`
          * `λ = 0.1`
          * `k = 10`
          * `ρ = 0.04`
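Taking the documented signature and defaults above at face value, a minimal `Sophia` call might look like the following sketch. The Rosenbrock setup, the AD backend choice, and `maxiters` are illustrative assumptions, not part of the commit; `Sophia` is referenced by its qualified name in case it is not exported.

```julia
using Optimization, ForwardDiff

rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# All keywords spelled out at their documented defaults, so this should be
# equivalent to solve(prob, Optimization.Sophia(); maxiters = 1000).
sol = solve(prob, Optimization.Sophia(; η = 0.001, βs = (0.9, 0.999), ϵ = 1e-8,
        λ = 0.1, k = 10, ρ = 0.04); maxiters = 1000)
```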

lib/OptimizationOptimJL/src/OptimizationOptimJL.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,12 @@ function __map_optimizer_args(cache::OptimizationCache,
3838
abstol::Union{Number, Nothing} = nothing,
3939
reltol::Union{Number, Nothing} = nothing,
4040
kwargs...)
41-
4241
mapped_args = (; extended_trace = true, kwargs...)
4342

4443
if !isnothing(abstol)
4544
mapped_args = (; mapped_args..., f_abstol = abstol)
4645
end
47-
46+
4847
if !isnothing(callback)
4948
mapped_args = (; mapped_args..., callback = callback)
5049
end

src/Optimization.jl  (+1)

```diff
@@ -24,6 +24,7 @@ include("utils.jl")
 include("state.jl")
 include("lbfgsb.jl")
 include("sophia.jl")
+include("auglag.jl")
 
 export solve
```

src/auglag.jl  (new file, +181)

```julia
@kwdef struct AugLag
    inner::Any
    τ = 0.5
    γ = 10.0
    λmin = -1e20
    λmax = 1e20
    μmin = 0.0
    μmax = 1e20
    ϵ = 1e-8
end

SciMLBase.supports_opt_cache_interface(::AugLag) = true
SciMLBase.allowsbounds(::AugLag) = true
SciMLBase.requiresgradient(::AugLag) = true
SciMLBase.allowsconstraints(::AugLag) = true
SciMLBase.requiresconsjac(::AugLag) = true

function __map_optimizer_args(cache::Optimization.OptimizationCache, opt::AugLag;
        callback = nothing,
        maxiters::Union{Number, Nothing} = nothing,
        maxtime::Union{Number, Nothing} = nothing,
        abstol::Union{Number, Nothing} = nothing,
        reltol::Union{Number, Nothing} = nothing,
        verbose::Bool = false,
        kwargs...)
    if !isnothing(abstol)
        @warn "common abstol is currently not used by $(opt)"
    end
    if !isnothing(maxtime)
        @warn "common maxtime is currently not used by $(opt)"
    end

    mapped_args = (;)

    if cache.lb !== nothing && cache.ub !== nothing
        mapped_args = (; mapped_args..., lb = cache.lb, ub = cache.ub)
    end

    if !isnothing(maxiters)
        mapped_args = (; mapped_args..., maxiter = maxiters)
    end

    if !isnothing(reltol)
        mapped_args = (; mapped_args..., pgtol = reltol)
    end

    return mapped_args
end

function SciMLBase.__solve(cache::OptimizationCache{
        F, RC, LB, UB, LC, UC, S, O, D, P, C
}) where {F, RC, LB, UB, LC, UC, S, O <: AugLag, D, P, C}
    maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)

    local x

    solver_kwargs = __map_optimizer_args(cache, cache.opt; maxiters, cache.solver_args...)

    if !isnothing(cache.f.cons)
        # Constraints with equal lower and upper bounds are equalities;
        # the rest are treated as inequalities.
        eq_inds = [cache.lcons[i] == cache.ucons[i] for i in eachindex(cache.lcons)]
        ineq_inds = (!).(eq_inds)

        τ = cache.opt.τ
        γ = cache.opt.γ
        λmin = cache.opt.λmin
        λmax = cache.opt.λmax
        μmin = cache.opt.μmin
        μmax = cache.opt.μmax
        ϵ = cache.opt.ϵ

        # Lagrange multipliers: λ for equalities, μ for inequalities.
        λ = zeros(eltype(cache.u0), sum(eq_inds))
        μ = zeros(eltype(cache.u0), sum(ineq_inds))

        cons_tmp = zeros(eltype(cache.u0), length(cache.lcons))
        cache.f.cons(cons_tmp, cache.u0)
        ρ = max(1e-6,
            min(10, 2 * (abs(cache.f(cache.u0, iterate(cache.p)[1]))) / norm(cons_tmp)))

        # Augmented Lagrangian objective for the inner subproblem.
        _loss = function (θ, p = cache.p)
            x = cache.f(θ, p)
            cons_tmp .= zero(eltype(θ))
            cache.f.cons(cons_tmp, θ)
            cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
            cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
            opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
            if cache.callback(opt_state, x...)
                error("Optimization halted by callback.")
            end
            return x[1] + sum(@. λ * cons_tmp[eq_inds] + ρ / 2 * (cons_tmp[eq_inds] .^ 2)) +
                   1 / (2 * ρ) * sum((max.(Ref(0.0), μ .+ (ρ .* cons_tmp[ineq_inds]))) .^ 2)
        end

        prev_eqcons = zero(λ)
        θ = cache.u0
        β = max.(cons_tmp[ineq_inds], Ref(0.0))
        prevβ = zero(β)
        eqidxs = [eq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        ineqidxs = [ineq_inds[i] > 0 ? i : nothing for i in eachindex(ineq_inds)]
        eqidxs = eqidxs[eqidxs .!= nothing]
        ineqidxs = ineqidxs[ineqidxs .!= nothing]

        # Gradient of the augmented Lagrangian: the objective gradient plus the
        # constraint-Jacobian contributions of the multiplier and penalty terms.
        function aug_grad(G, θ, p)
            cache.f.grad(G, θ, p)
            if !isnothing(cache.f.cons_jac_prototype)
                J = Float64.(cache.f.cons_jac_prototype)
            else
                J = zeros((length(cache.lcons), length(θ)))
            end
            cache.f.cons_j(J, θ)
            __tmp = zero(cons_tmp)
            cache.f.cons(__tmp, θ)
            __tmp[eq_inds] .= __tmp[eq_inds] .- cache.lcons[eq_inds]
            __tmp[ineq_inds] .= __tmp[ineq_inds] .- cache.ucons[ineq_inds]
            G .+= sum(
                λ[i] .* J[idx, :] + ρ * (__tmp[idx] .* J[idx, :])
                for (i, idx) in enumerate(eqidxs);
                init = zero(G)) #should be jvp
            G .+= sum(
                1 / ρ * (max.(Ref(0.0), μ[i] .+ (ρ .* __tmp[idx])) .* J[idx, :])
                for (i, idx) in enumerate(ineqidxs);
                init = zero(G)) #should be jvp
        end

        opt_ret = ReturnCode.MaxIters
        n = length(cache.u0)

        augprob = OptimizationProblem(
            OptimizationFunction(_loss; grad = aug_grad), cache.u0, cache.p)

        solver_kwargs = Base.structdiff(solver_kwargs, (; lb = nothing, ub = nothing))

        # Outer loop: solve the subproblem with the inner solver, then update
        # the multipliers and, if progress stalls, grow the penalty ρ.
        for i in 1:(maxiters / 10)
            prev_eqcons .= cons_tmp[eq_inds] .- cache.lcons[eq_inds]
            prevβ .= copy(β)
            res = solve(augprob, cache.opt.inner, maxiters = maxiters / 10)
            θ = res.u
            cons_tmp .= 0.0
            cache.f.cons(cons_tmp, θ)
            λ = max.(min.(λmax, λ .+ ρ * (cons_tmp[eq_inds] .- cache.lcons[eq_inds])), λmin)
            β = max.(cons_tmp[ineq_inds], -1 .* μ ./ ρ)
            μ = min.(μmax, max.(μ .+ ρ * cons_tmp[ineq_inds], μmin))
            if max(norm(cons_tmp[eq_inds] .- cache.lcons[eq_inds], Inf), norm(β, Inf)) >
               τ * max(norm(prev_eqcons, Inf), norm(prevβ, Inf))
                ρ = γ * ρ
            end
            if norm(
                (cons_tmp[eq_inds] .- cache.lcons[eq_inds]) ./ cons_tmp[eq_inds], Inf) <
               ϵ && norm(β, Inf) < ϵ
                opt_ret = ReturnCode.Success
                break
            end
        end

        stats = Optimization.OptimizationStats(; iterations = maxiters,
            time = 0.0, fevals = maxiters, gevals = maxiters)
        return SciMLBase.build_solution(
            cache, cache.opt, θ, x,
            stats = stats, retcode = opt_ret)
    end
end
```
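For reference, the `_loss` closure above implements, up to a constant in `μ`, the bound-constrained augmented Lagrangian from section 17.4 of Nocedal and Wright that the docs change cites, with the residuals shifted so that equalities satisfy c_i(θ) = 0 and inequalities c_i(θ) ≤ 0:

```math
L_ρ(θ, λ, μ) = f(θ)
    + \sum_{i ∈ \mathcal{E}} \left( λ_i \, c_i(θ) + \frac{ρ}{2} \, c_i(θ)^2 \right)
    + \frac{1}{2ρ} \sum_{i ∈ \mathcal{I}} \max\bigl(0, \; μ_i + ρ \, c_i(θ)\bigr)^2
```

After each inner `solve`, the multipliers receive the usual first-order updates of the form λ ← λ + ρc, clamped to `[λmin, λmax]` and `[μmin, μmax]`, and ρ is multiplied by γ whenever the worst-case constraint violation fails to shrink by at least the factor τ.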

test/diffeqfluxtests.jl  (+2 -1)

```diff
@@ -70,7 +70,8 @@ ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
 dudt2 = Lux.Chain(x -> x .^ 3,
     Lux.Dense(2, 50, tanh),
     Lux.Dense(50, 2))
-prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
+prob_neuralode = NeuralODE(
+    dudt2, tspan, Tsit5(), saveat = tsteps, abstol = 1e-8, reltol = 1e-8)
 pp, st = Lux.setup(rng, dudt2)
 pp = ComponentArray(pp)
```

test/lbfgsb.jl  (-27)

This file was deleted.
