Skip to content

Commit fc867d7

Browse files
committed
fix #112
1 parent f219558 commit fc867d7

3 files changed

Lines changed: 88 additions & 0 deletions

File tree

src/optimize/optimize.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,19 @@ function optimize(
6565
status
6666
end
6767

68+
# Handle callable objects (non-Function types) with algorithm instances
69+
function optimize(
70+
f,
71+
search_space,
72+
method::AbstractAlgorithm;
73+
logger::Function = (status) -> nothing,
74+
)
75+
# Wrap callable object into a function
76+
wrapped_f = _wrap_objective_function(f)
77+
# Call the standard optimize function with the wrapped function
78+
optimize(wrapped_f, search_space, method; logger)
79+
end
80+
6881

6982
"""
7083
optimize!(f, search_space, method;logger)
@@ -133,6 +146,19 @@ function optimize!(
133146
return method
134147
end
135148

149+
# Handle callable objects (non-Function types) with optimize!
150+
function optimize!(
151+
f,
152+
search_space,
153+
method::AbstractAlgorithm;
154+
logger::Function = (status) -> nothing,
155+
)
156+
# Wrap callable object into a function
157+
wrapped_f = _wrap_objective_function(f)
158+
# Call the standard optimize! function with the wrapped function
159+
optimize!(wrapped_f, search_space, method; logger)
160+
end
161+
136162

137163
function optimize(
138164
f::Function,
@@ -149,3 +175,17 @@ function optimize(
149175
# call optimize api
150176
optimize(f, problem.search_space, algo; logger)
151177
end
178+
179+
function optimize(
180+
f,
181+
_search_space,
182+
::Type{T};
183+
logger::Function = (status) -> nothing,
184+
kargs...
185+
) where T <: AbstractParameters
186+
187+
# Wrap callable object into a function
188+
wrapped_f = _wrap_objective_function(f)
189+
# Call the standard optimize function with the wrapped function
190+
optimize(wrapped_f, _search_space, T; logger, kargs...)
191+
end

src/optimize/utils.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,18 @@ function show_status(status, parameters, options)
8383
@info msg
8484
display(status)
8585
end
86+
87+
88+
function _wrap_objective_function(objective)
89+
# Check if the object is callable with an AbstractVector argument
90+
# We check with Any as a representative type
91+
if !hasmethod(objective, (Any,))
92+
error("The objective function should be callable object.")
93+
end
94+
if objective isa Function
95+
return objective
96+
end
97+
98+
args -> objective(args)
99+
end
100+

test/optimize_api.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,39 @@
4949
optimize(f, [lb ub], ECA, iterations=10, verbose = true)
5050
end
5151
end
52+
53+
# test callable objects (issue #112)
54+
struct CostFunc{T}
55+
a::T
56+
b::T
57+
end
58+
59+
function (A::CostFunc)(x)
60+
return A.a*length(x) + sum( x.^2 - A.b*cos.(2π*x) )
61+
end
62+
function test_callable_objects()
63+
64+
callable_object = CostFunc(2.0, 2.0)
65+
66+
# limits/bounds
67+
bounds = BoxConstrainedSpace(lb = -5ones(2), ub = 5ones(2))
68+
69+
# information on the minimization problem
70+
information = Information(f_optimum = 0.0)
71+
72+
# generic settings
73+
options = Options(f_tol = 1e-5, seed = 5)
74+
75+
# metaheuristic used to optimize
76+
algorithm = ECA(information = information, options = options)
77+
78+
# start the minimization process
79+
result = optimize(callable_object, bounds, algorithm)
80+
81+
@test minimum(result) < 1e-4
82+
end
83+
5284
test_optimize()
5385
search_space_optimize()
86+
test_callable_objects()
5487
end

0 commit comments

Comments
 (0)