Skip to content

Commit fd1d34f

Browse files
authored
Merge pull request #107 from control-toolbox/init_for_direct
Init for direct
2 parents 23ac5d0 + cbde146 commit fd1d34f

File tree

7 files changed

+87
-17
lines changed

7 files changed

+87
-17
lines changed

src/CTDirect/problem.jl

+65-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function ADNLProblem(ocp::OptimalControlModel, N::Integer)
1+
function ADNLProblem(ocp::OptimalControlModel, N::Integer, init=nothing)
22

33
# direct_infos
44
t0, tf_, n_x, m, f, ξ, ψ, ϕ, dim_ξ, dim_ψ, dim_ϕ,
@@ -94,7 +94,7 @@ function ADNLProblem(ocp::OptimalControlModel, N::Integer)
9494
end
9595

9696
# bounds for the constraints
97-
function ipopt_l_u_b()
97+
function constraints_bounds()
9898
lb = zeros(nc)
9999
ub = zeros(nc)
100100
index = 1 # counter for the constraints
@@ -134,17 +134,75 @@ function ADNLProblem(ocp::OptimalControlModel, N::Integer)
134134
return lb, ub
135135
end
136136

137-
# todo: init a changer
138-
xu0 = 1.1*ones(dim_xu)
137+
# todo: retrieve optional bounds from ocp parsed constraints
138+
function variables_bounds()
139+
# unbounded case
140+
l_var = -Inf*ones(dim_xu)
141+
u_var = Inf*ones(dim_xu)
142+
return l_var, u_var
143+
end
144+
145+
# generate initial guess
146+
function set_state_at_time_step!(x, i, dim_x, N, xu)
147+
if i > N
148+
error("trying to set x(t_i) for i > N")
149+
else
150+
xu[1+i*dim_x:(i+1)*dim_x] = x[1:dim_x]
151+
end
152+
end
153+
154+
function set_control_at_time_step!(u, i, dim_x, N, m, xu)
155+
if i > N
156+
error("trying to set (t_i) for i > N")
157+
else
158+
xu[1+(N+1)*dim_x+i*m:m+(N+1)*dim_x+i*m] = u[1:m]
159+
end
160+
end
161+
162+
function initial_guess()
163+
#println("Initialization: ", init)
164+
165+
if init === nothing
166+
# default initialization
167+
xu0 = 1.1*ones(dim_xu)
168+
else
169+
if length(init) != (n_x + m)
170+
error("vector for initialization should be of size n+m",n_x+m)
171+
end
172+
# split state / control
173+
x_init = zeros(dim_x)
174+
x_init[1:n_x] = init[1:n_x]
175+
u_init = zeros(m)
176+
u_init[1:m] = init[n_x+1:n_x+m]
177+
178+
# mayer -> lagrange additional state
179+
if hasLagrangeCost
180+
x_init[dim_x] = 0.1
181+
end
182+
183+
# constant initialization
184+
xu0 = zeros(dim_xu)
185+
for i in 0:N
186+
set_state_at_time_step!(x_init, i, dim_x, N, xu0)
187+
set_control_at_time_step!(u_init, i, dim_x, N, m, xu0)
188+
end
189+
end
190+
return xu0
191+
end
192+
193+
# variables bounds
194+
l_var, u_var = variables_bounds()
195+
196+
# initial guess
197+
xu0 = initial_guess()
139198

140-
l_var = -Inf*ones(dim_xu)
141-
u_var = Inf*ones(dim_xu)
199+
# free final time case
142200
if has_free_final_time
143201
xu0[end] = 1.0
144202
l_var[end] = 1.e-3
145203
end
146204

147-
lb, ub = ipopt_l_u_b()
205+
lb, ub = constraints_bounds()
148206

149207
nlp = ADNLPModel(ipopt_objective, xu0, l_var, u_var, ipopt_constraint, lb, ub)
150208

src/CTDirect/solve.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ function direct_solve(ocp::OptimalControlModel,
55
print_level::Integer=__print_level_ipopt(),
66
mu_strategy::String=__mu_strategy_ipopt(),
77
display::Bool=__display(),
8+
init=nothing, #NB. for now, can be nothing or (n+m) vector
89
kwargs...)
910
"""
1011
Solve the optimal control problem
@@ -23,7 +24,8 @@ function direct_solve(ocp::OptimalControlModel,
2324
print_level = display ? print_level : 0
2425

2526
# from OCP to NLP
26-
nlp = ADNLProblem(ocp, grid_size)
27+
nlp = ADNLProblem(ocp, grid_size, init)
28+
#println("nlp x0:", nlp.meta.x0)
2729

2830
# solve by IPOPT: more info at
2931
# https://github.com/JuliaSmoothOptimizers/NLPModelsIpopt.jl/blob/main/src/NLPModelsIpopt.jl#L119

src/CTDirect/utils.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function get_state_at_time_step(xu, i, dim_x, N)
44
x(t_i)
55
"""
66
if i > N
7-
error("trying to access at x(t_i) for i > N")
7+
error("trying to get x(t_i) for i > N")
88
end
99
return xu[1+i*dim_x:(i+1)*dim_x]
1010
end
@@ -15,7 +15,7 @@ function get_control_at_time_step(xu, i, dim_x, N, m)
1515
u(t_i)
1616
"""
1717
if i > N
18-
error("trying to access at (t_i) for i > N")
18+
error("trying to get (t_i) for i > N")
1919
end
2020
return xu[1+(N+1)*dim_x+i*m:m+(N+1)*dim_x+i*m]
2121
end

test/test_basic_manual.jl renamed to test/manual_test_basic.jl

+6-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,13 @@ B = [ 0.0
1919
constraint!(ocp, :dynamics, (x, u) -> A*x + B*u[1])
2020
objective!(ocp, :lagrange, (x, u) -> 0.5u[1]^2) # default is to minimise
2121

22+
# initial guess (constant state and control functions)
23+
init = [1., 0.5, 0.3]
24+
2225
# solve
23-
sol = solve(ocp, grid_size=30)
26+
#sol = solve(ocp, grid_size=10, print_level=5)
27+
sol = solve(ocp, grid_size=10, print_level=5, init=init)
28+
2429

2530
# plot
2631
plot(sol)

test/test_goddard_manual.jl renamed to test/manual_test_goddard.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ end
4747

4848
constraint!(ocp, :dynamics, f)
4949

50-
sol = solve(ocp, 20)
50+
# initial guess (constant state and control functions)
51+
init = [1.01, 0.25, 0.5, 0.4]
52+
53+
#sol = solve(ocp, grid_size=20, print_level=5)
54+
sol = solve(ocp, grid_size=20, print_level=5, init=init)
5155

5256
plot(sol)

test/runtests.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ const __display = OptimalControl.CTBase.__display
2121
#
2222
@testset verbose = true showtiming = true "Optimal control tests" begin
2323
for name in (
24-
#"utils",
25-
#"direct_shooting_CTOptimization", # unconstrained direct simple shooting
24+
"utils",
25+
"direct_shooting_CTOptimization", # unconstrained direct simple shooting
2626
"basic",
27-
#"goddard_direct",
28-
#"goddard_indirect",
27+
"goddard_direct",
28+
"goddard_indirect",
2929
)
3030
@testset "$name" begin
3131
include("test_$name.jl")

test/test_goddard_direct.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ end
4747

4848
constraint!(ocp, :dynamics, f)
4949

50-
sol = solve(ocp, grid_size=10, print_level=0)
50+
init = [1.01, 0.25, 0.5, 0.4]
51+
sol = solve(ocp, grid_size=10, print_level=0, init=init)
5152

5253
@test objective(sol) -1.0 atol=1e-1
5354
@test constraints_violation(sol) < 1e-6

0 commit comments

Comments
 (0)