Skip to content

Commit 6af41fa

Browse files
committed
fusion ControlToolboxTools
1 parent ccda01f commit 6af41fa

13 files changed

+558
-18
lines changed

Project.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
name = "CTBase"
22
uuid = "54762871-cc72-4466-b8e8-f6c8b58076cd"
33
authors = ["Olivier Cots <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
7-
ControlToolboxTools = "3a0bcf43-9180-47f3-913c-e71e0d69f39f"
87
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
98
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
109
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"

src/CTBase.jl

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
module CTBase
22

3-
# this should be in ControlToolboxTools, which should be renamed CTBase
4-
53
# using
64
using ForwardDiff: jacobian, gradient, ForwardDiff # automatic differentiation
75
using Parameters # @with_kw: permit to have default values in struct
86
using Interpolations: linear_interpolation, Line, Interpolations # for default interpolation
9-
#import Base: show # to print an OptimalControlModel
107
using Printf # to print a OptimalControlModel
11-
using ControlToolboxTools # tools: callbacks, exceptions, functions and more
8+
import Base: \, Base
129

1310
# --------------------------------------------------------------------------------------------------
1411
# Aliases for types
@@ -29,8 +26,24 @@ const Adjoint = MyVector # todo: ajouter type adjoint pour faire par exemple p*f
2926
const Dimension = Integer
3027

3128
#
32-
types() = MyNumber, MyVector, Time, Times, TimesDisc, States, Adjoints, Controls, State, Adjoint, Dimension
29+
#num_types() = MyNumber, MyVector, Time, Times, TimesDisc, States, Adjoints, Controls, State, Adjoint, Dimension
30+
31+
# General abstract type for callbacks
32+
abstract type CTCallback end
33+
const CTCallbacks = Tuple{Vararg{CTCallback}}
34+
35+
# A desription is a tuple of symbols
36+
const DescVarArg = Vararg{Symbol} # or Symbol...
37+
const Description = Tuple{DescVarArg}
38+
39+
#tools_types() = CTCallbacks, Description
3340

41+
#
42+
include("exceptions.jl")
43+
include("descriptions.jl")
44+
include("callbacks.jl")
45+
include("macros.jl")
46+
include("functions.jl")
3447
#
3548
include("utils.jl")
3649
#include("algorithms.jl")
@@ -39,33 +52,43 @@ include("print.jl")
3952
include("solutions.jl")
4053
include("default.jl")
4154

42-
#function solve(ocp::OptimalControlModel, algo::AbstractControlAlgorithm, method::Description; kwargs...)
43-
# error("solve not implemented")
44-
#end
45-
4655
#
47-
# export only for users
56+
# Numeric types
57+
export MyNumber, MyVector, Time, Times, TimesDisc
58+
export States, Adjoints, Controls, State, Adjoint, Dimension
59+
60+
# callback
61+
export CTCallback, CTCallbacks, PrintCallback, StopCallback
62+
export get_priority_print_callbacks, get_priority_stop_callbacks
63+
64+
# exceptions
65+
export CTException, AmbiguousDescription, InconsistentArgument, IncorrectMethod
66+
67+
# description
68+
export Description, makeDescription, add, getFullDescription
4869

4970
# utils
5071
export Ad, Poisson
5172

5273
# model
74+
export AbstractOptimalControlModel, OptimalControlModel
5375
export Model, time!, constraint!, objective!, state!, control!, remove_constraint!, constraint
5476
export ismin, dynamics, lagrange, criterion, initial_time, final_time
5577
export control_dimension, state_dimension, constraints, initial_condition, final_constraint
5678

5779
# solution
80+
export AbstractOptimalControlSolution, DirectSolution, DirectShootingSolution
5881
export time_steps_length, state_dimension, control_dimension
5982
export time_steps, state, control, adjoint, objective
6083
export iterations, success, message, stopping
6184
export constraints_violation
6285

63-
# export structs
64-
export AbstractOptimalControlModel, OptimalControlModel
65-
export AbstractOptimalControlSolution, DirectSolution, DirectShootingSolution
66-
#export AbstractControlAlgorithm, DirectAlgorithm, DirectShootingAlgorithm
86+
# macros
87+
export @callable, @time_dependence_function
6788

68-
# solve
69-
#export solve
89+
# functions
90+
export Hamiltonian, HamiltonianVectorField, VectorField
91+
export LagrangeFunction, DynamicsFunction, ControlFunction, MultiplierFunction
92+
export StateConstraintFunction, ControlConstraintFunction, MixedConstraintFunction
7093

7194
end

src/callbacks.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# --------------------------------------------------------------------------------------------------
2+
# Print callback
3+
mutable struct PrintCallback <: CTCallback
4+
callback::Function
5+
priority::Integer
6+
function PrintCallback(cb::Function; priority::Integer=1)
7+
new(cb, priority)
8+
end
9+
end
10+
# todo: essayer de mettre args... pour éviter de fixer ici les arguments
11+
function (cb::PrintCallback)(i, sᵢ, dᵢ, xᵢ, gᵢ, fᵢ)
12+
return cb.callback(i, sᵢ, dᵢ, xᵢ, gᵢ, fᵢ)
13+
end
14+
const PrintCallbacks = Tuple{Vararg{PrintCallback}}
15+
16+
#
17+
function get_priority_print_callbacks(cbs::CTCallbacks)
18+
callbacks_print = ()
19+
priority = -Inf
20+
21+
# search highest priority
22+
for cb in cbs
23+
if typeof(cb) === PrintCallback && cb.priority priority
24+
priority = cb.priority
25+
end
26+
end
27+
28+
# add callbacks
29+
for cb in cbs
30+
if typeof(cb) === PrintCallback && cb.priority == priority
31+
callbacks_print = (callbacks_print..., cb)
32+
end
33+
end
34+
return callbacks_print
35+
end
36+
37+
# Stop callback
38+
mutable struct StopCallback <: CTCallback
39+
callback::Function
40+
priority::Integer
41+
function StopCallback(cb::Function; priority::Integer=1)
42+
new(cb, priority)
43+
end
44+
end
45+
function (cb::StopCallback)(i, sᵢ, dᵢ, xᵢ, gᵢ, fᵢ, ng₀, optimalityTolerance, absoluteTolerance,
46+
stagnationTolerance, iterations)
47+
return cb.callback(i, sᵢ, dᵢ, xᵢ, gᵢ, fᵢ, ng₀, optimalityTolerance, absoluteTolerance,
48+
stagnationTolerance, iterations)
49+
end
50+
const StopCallbacks = Tuple{Vararg{StopCallback}}
51+
52+
#
53+
function get_priority_stop_callbacks(cbs::CTCallbacks)
54+
callbacks_stop = ()
55+
priority = -Inf
56+
57+
# search highest priority
58+
for cb in cbs
59+
if typeof(cb) === StopCallback && cb.priority priority
60+
priority = cb.priority
61+
end
62+
end
63+
64+
# add callbacks
65+
for cb in cbs
66+
if typeof(cb) === StopCallback && cb.priority == priority
67+
callbacks_stop = (callbacks_stop..., cb)
68+
end
69+
end
70+
return callbacks_stop
71+
end

src/descriptions.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# --------------------------------------------------------------------------------------------------
2+
# the description may be given as a tuple or a list of symbols (Vararg{Symbol})
3+
makeDescription(desc::DescVarArg) = Tuple(desc) # create a description from Vararg{Symbol}
4+
makeDescription(desc::Description) = desc
5+
6+
# --------------------------------------------------------------------------------------------------
7+
# Possible algorithms
8+
add(x::Tuple{}, y::Description) = (y,)
9+
add(x::Tuple{Vararg{Description}}, y::Description) = (x..., y)
10+
11+
# this function transform an incomplete description to a complete one
12+
function getFullDescription(desc::Description, desc_list)::Description
13+
n = length(desc_list)
14+
table = zeros(Int8, n, 2)
15+
for i in range(1, n)
16+
table[i, 1] = length(desc desc_list[i])
17+
table[i, 2] = desc desc_list[i] ? 1 : 0
18+
end
19+
if maximum(table[:, 2]) == 0
20+
throw(AmbiguousDescription(desc))
21+
end
22+
# argmax : Return the index or key of the maximal element in a collection.
23+
# If there are multiple maximal elements, then the first one will be returned.
24+
# This means that the first has the priority
25+
return desc_list[argmax(table[:, 1])]
26+
end
27+
28+
diff(x::Description, y::Description) = Tuple(setdiff(x, y))
29+
\(x::Description, y::Description) = diff(x, y)

src/exceptions.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# --------------------------------------------------------------------------------------------------
2+
# General abstract type for exceptions
3+
abstract type CTException <: Exception end
4+
5+
# ambiguous description
6+
struct AmbiguousDescription <: CTException
7+
var::Description
8+
end
9+
10+
"""
11+
Base.showerror(io::IO, e::AmbiguousDescription)
12+
13+
TBW
14+
"""
15+
Base.showerror(io::IO, e::AmbiguousDescription) = print(io, "the description ", e.var, " is ambiguous / incorrect")
16+
17+
# inconsistent argument
18+
struct InconsistentArgument <: CTException
19+
var::String
20+
end
21+
22+
"""
23+
Base.showerror(io::IO, e::InconsistentArgument)
24+
25+
TBW
26+
"""
27+
Base.showerror(io::IO, e::InconsistentArgument) = print(io, e.var)
28+
29+
# incorrect method
30+
struct IncorrectMethod <: CTException
31+
var::Symbol
32+
end
33+
34+
"""
35+
Base.showerror(io::IO, e::IncorrectMethod)
36+
37+
TBW
38+
"""
39+
Base.showerror(io::IO, e::IncorrectMethod) = print(io, e.var, " is not an existing method")

src/functions.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
abstract type AbstractCTFunction{time_dependence} <: Function end
2+
3+
@time_dependence_function Hamiltonian, AbstractCTFunction
4+
@time_dependence_function HamiltonianVectorField, AbstractCTFunction
5+
@time_dependence_function VectorField, AbstractCTFunction
6+
@time_dependence_function LagrangeFunction, AbstractCTFunction
7+
@time_dependence_function DynamicsFunction, AbstractCTFunction
8+
@time_dependence_function StateConstraintFunction, AbstractCTFunction
9+
@time_dependence_function ControlConstraintFunction, AbstractCTFunction
10+
@time_dependence_function MixedConstraintFunction, AbstractCTFunction
11+
@time_dependence_function ControlFunction, AbstractCTFunction
12+
@time_dependence_function MultiplierFunction, AbstractCTFunction

src/macros.jl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# exception thrower
2+
function throw_callable_error( msg = nothing )
3+
if isnothing(msg)
4+
throw(error("@callable input is incorrect, expected a struct"))
5+
else
6+
throw(error("@callable input is incorrect: ", msg))
7+
end
8+
end
9+
10+
macro callable(expr)
11+
12+
#= @show(expr)
13+
println("---")
14+
println(expr)
15+
println("---")
16+
println(typeof(expr))
17+
println("---")
18+
println(Meta.show_sexpr(expr))
19+
println("---")
20+
println(expr.head)
21+
println("---")
22+
println(expr.args[1])
23+
println("---")
24+
println(expr.args[2])
25+
println("---")
26+
println(expr.args[3])
27+
println("==================================================") =#
28+
29+
#dump(expr)
30+
31+
# first elements must be :struct
32+
if !hasproperty(expr, :head) || expr.head != Symbol(:struct)
33+
return :(throw_callable_error())
34+
end
35+
36+
#
37+
corps = expr.args[3]
38+
39+
# parametric struct or not
40+
curly=false
41+
if hasproperty(expr.args[2], :head) && expr.args[2].head == Symbol(:curly)
42+
struct_name = expr.args[2].args[1]
43+
struct_params = expr.args[2].args[2:end]
44+
curly=true
45+
else
46+
struct_name = expr.args[2]
47+
struct_params = ""
48+
end
49+
#println(struct_name)
50+
#println(struct_params)
51+
52+
fun = gensym("fun")
53+
54+
if curly
55+
esc(quote
56+
struct $struct_name{$(struct_params...)}
57+
$fun::Function
58+
$corps
59+
#function $struct_name{$(struct_params...)}(caller::Function, args...) where {$(struct_params...)}
60+
# new{$(struct_params...)}(caller, args...)
61+
#end
62+
#function $struct_name{$(struct_params...)}(args...; caller::Function) where {$(struct_params...)}
63+
# new{$(struct_params...)}(caller, args...)
64+
#end
65+
end
66+
(s::$struct_name{$(struct_params...)})(args...; kwargs...) where {$(struct_params...)} = s.$fun(args...; kwargs...)
67+
end)
68+
else
69+
esc(quote
70+
struct $struct_name
71+
$fun::Function
72+
$corps
73+
#function $struct_name(caller::Function, args...)
74+
# new(caller, args...)
75+
#end
76+
#function $struct_name(args...; caller::Function)
77+
# new(caller, args...)
78+
#end
79+
end
80+
(s::$struct_name)(args...; kwargs...) = s.$fun(args...; kwargs...)
81+
end)
82+
end
83+
end
84+
85+
86+
# generate callable structure with the managing of autonomous vs nonautonomous cases
87+
macro time_dependence_function(expr)
88+
89+
#dump(expr)
90+
91+
if !hasproperty(expr, :head) || expr.head != Symbol(:tuple)
92+
return :(throw_callable_error())
93+
end
94+
95+
function_name = expr.args[1]
96+
abstract_name = Symbol(:Abstract, function_name)
97+
abstract_heritance = expr.args[2]
98+
99+
esc(quote
100+
struct $(function_name){time_dependence} <: $(abstract_heritance){time_dependence}
101+
f::Function
102+
function $(function_name){time_dependence}(f::Function) where {time_dependence}
103+
if !(time_dependence [:autonomous, :nonautonomous])
104+
error("the function must be :autonomous or :nonautonomous")
105+
else
106+
new{time_dependence}(f)
107+
end
108+
end
109+
$(function_name)(f::Function) = new{:autonomous}(f)
110+
end
111+
function (F::$(function_name){time_dependence})(t, args...; kwargs...) where {time_dependence}
112+
return time_dependence==:autonomous ? F.f(args...; kwargs...) : F.f(t, args...; kwargs...)
113+
end
114+
end)
115+
116+
end

0 commit comments

Comments
 (0)