-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathutils.jl
159 lines (140 loc) · 4.05 KB
/
utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Public names defined in this file: COO (triplet) sparse matrix-vector products,
# a macro forwarding counter functions to a wrapped model, and dimension/range
# checking helpers together with their error type.
export coo_prod!, coo_sym_prod!
export @default_counters
export DimensionError, @lencheck, @rangecheck
"""
    DimensionError <: Exception
    DimensionError(name, dim_expected, dim_found)

Exception raised when an input does not have the length it is required to have.

# Fields
- `name::Union{Symbol, String}`: how the offending input is referred to in the message.
- `dim_expected::Int`: the required length.
- `dim_found::Int`: the length that was actually observed.

Output: "DimensionError: Input `name` should have length `dim_expected` not `dim_found`"
"""
struct DimensionError <: Exception
    name::Union{Symbol, String}
    dim_expected::Int
    dim_found::Int
end

# Render the error as a single line naming the input and both lengths.
Base.showerror(io::IO, e::DimensionError) =
    print(io, "DimensionError: Input $(e.name) should have length $(e.dim_expected) not $(e.dim_found)")
# https://groups.google.com/forum/?fromgroups=#!topic/julia-users/b6RbQ2amKzg
"""
    @lencheck n x y z …

Check that arrays `x`, `y`, `z`, etc. have a prescribed length `n`.

The expression `n` is evaluated once, and each checked expression is evaluated
once, left to right. The first mismatch throws a `DimensionError` whose `name`
is the source text of the offending expression. The macro evaluates to `nothing`.
"""
macro lencheck(l, vars...)
    # Nothing to check: leave `l` unevaluated, as an empty check list always did.
    isempty(vars) && return :(nothing)
    checks = map(vars) do var
        varname = string(var)
        quote
            let found = length($(esc(var)))
                found == expected || throw(DimensionError($varname, expected, found))
            end
        end
    end
    # `expected` and `found` are not escaped, so macro hygiene gives them fresh
    # names: they cannot capture or shadow variables at the call site. Binding
    # them once avoids re-evaluating user expressions (and their side effects).
    return quote
        let expected = $(esc(l))
            $(checks...)
        end
        nothing
    end
end
"""
    @rangecheck ℓ u i j k …

Check that values `i`, `j`, `k`, etc. are in the range `[ℓ,u]`.

Each value may be a scalar or an array; the comparisons are broadcast, so `ℓ`
and `u` may also be arrays of compatible shape. Values of length zero are
skipped. The bounds are evaluated once and each checked expression once.
On the first violation an `ErrorException` is thrown whose message starts with
the source text of the offending expression. The macro evaluates to `nothing`.
"""
macro rangecheck(lo, hi, vars...)
    # Nothing to check: leave the bounds unevaluated, as an empty check list always did.
    isempty(vars) && return :(nothing)
    checks = map(vars) do var
        varname = string(var)
        quote
            let vals = $(esc(var))
                if length(vals) > 0 &&
                   (any(broadcast(<, vals, lower)) || any(broadcast(>, vals, upper)))
                    error(string($varname, " elements must be between ", lower, " and ", upper))
                end
            end
        end
    end
    # `lower`, `upper` and `vals` are not escaped, so macro hygiene renames them and
    # they cannot clash with call-site variables. Binding them once avoids
    # re-evaluating user expressions in every comparison and in the message.
    return quote
        let lower = $(esc(lo)), upper = $(esc(hi))
            $(checks...)
        end
        nothing
    end
end
const UnconstrainedErrorMessage = "Try to evaluate constraints, but the problem is unconstrained."

"""
    check_unconstrained(nlp)

Throw an `ErrorException` with message `UnconstrainedErrorMessage` when
`unconstrained(nlp)` is true; otherwise return `nothing`.
"""
function check_unconstrained(nlp)
    unconstrained(nlp) && throw(ErrorException(UnconstrainedErrorMessage))
    return nothing
end

const NonlinearUnconstrainedErrorMessage = "Try to evaluate nonlinear constraints, but the problem has none."

"""
    check_nonlinear_constraints(nlp)

Throw an `ErrorException` with message `NonlinearUnconstrainedErrorMessage` when
`nlp.meta.nnln == 0`; otherwise return `nothing`.
"""
function check_nonlinear_constraints(nlp)
    nlp.meta.nnln == 0 && throw(ErrorException(NonlinearUnconstrainedErrorMessage))
    return nothing
end

const LinearUnconstrainedErrorMessage = "Try to evaluate linear constraints, but the problem has none."

"""
    check_linear_constraints(nlp)

Throw an `ErrorException` with message `LinearUnconstrainedErrorMessage` when
`nlp.meta.nlin == 0`; otherwise return `nothing`.
"""
function check_linear_constraints(nlp)
    nlp.meta.nlin == 0 && throw(ErrorException(LinearUnconstrainedErrorMessage))
    return nothing
end
"""
    coo_prod!(rows, cols, vals, v, Av)

Compute the product of a matrix `A` given by `(rows, cols, vals)` and the vector `v`.
The result is stored in `Av`, which should have length equal to the number of rows of `A`.

Entry `k` of the triplets contributes `vals[k] * v[cols[k]]` to `Av[rows[k]]`, so
repeated `(row, col)` pairs are summed. `Av` is overwritten and returned.

Throws a `DimensionMismatch` if `rows`, `cols` and `vals` do not share the same
indices, and a `BoundsError` if a row index falls outside `Av` or a column index
falls outside `v`.
"""
function coo_prod!(
    rows::AbstractVector{<:Integer},
    cols::AbstractVector{<:Integer},
    vals::AbstractVector,
    v::AbstractVector,
    Av::AbstractVector,
)
    fill!(Av, zero(eltype(v)))
    # With several arrays, `eachindex` throws `DimensionMismatch` unless they all
    # share the same indices, which makes the `@inbounds` triplet reads safe.
    for k in eachindex(rows, cols, vals)
        i, j, a = @inbounds (rows[k], cols[k], vals[k])
        # `i` and `j` come from caller data, so these accesses keep bounds checks.
        Av[i] += a * v[j]
    end
    return Av
end
"""
    coo_sym_prod!(rows, cols, vals, v, Av)

Compute the product of a symmetric matrix `A` given by `(rows, cols, vals)` and the vector `v`.
The result is stored in `Av`, which should have length equal to the number of rows of `A`.
Only one triangle of `A` should be passed.

Each off-diagonal entry `(i, j, a)` is applied as both `A[i, j]` and `A[j, i]`;
diagonal entries are applied once. `Av` is overwritten and returned.

Throws a `DimensionMismatch` if `rows`, `cols` and `vals` do not share the same
indices, and a `BoundsError` if an index falls outside `Av` or `v`.
"""
function coo_sym_prod!(
    rows::AbstractVector{<:Integer},
    cols::AbstractVector{<:Integer},
    vals::AbstractVector,
    v::AbstractVector,
    Av::AbstractVector,
)
    fill!(Av, zero(eltype(v)))
    # With several arrays, `eachindex` throws `DimensionMismatch` unless they all
    # share the same indices, which makes the `@inbounds` triplet reads safe.
    for k in eachindex(rows, cols, vals)
        i, j, a = @inbounds (rows[k], cols[k], vals[k])
        # `i` and `j` come from caller data, so these accesses keep bounds checks.
        Av[i] += a * v[j]
        # Only one triangle is stored: mirror off-diagonal entries.
        if i != j
            Av[j] += a * v[i]
        end
    end
    return Av
end
"""
    @default_counters Model inner

Define functions relating counters of `Model` to counters of `Model.inner`.

For every field name of `Counters`, and additionally for `sum_counters`, this
defines a method `NLPModels.<name>(nlp::Model)` returning `<name>(nlp.inner)`.
It also defines:

- `NLPModels.reset!(nlp::Model)`: calls `reset!(nlp.inner)`, then `reset_data!(nlp)`;
- `NLPModels.increment!(nlp::Model, s::Symbol)`: forwards to `increment!(nlp.inner, s)`;
- `NLPModels.decrement!(nlp::Model, s::Symbol)`: forwards to `decrement!(nlp.inner, s)`.

`inner` is spliced in literally as a field name (`nlp.inner`), not evaluated.
"""
macro default_counters(Model, inner)
# Accumulates one method definition per push!; returned as a single block.
ex = Expr(:block)
# `fieldnames(Counters)` runs at macro-expansion time, so the counter list is
# whatever fields `Counters` has when the macro is expanded.
# Only `Model` is escaped (resolved at the call site); the unescaped names
# (`NLPModels`, the counter functions, `reset!`, `reset_data!`, `increment!`,
# `decrement!`) resolve in the module that defines this macro.
for foo in fieldnames(Counters) ∪ [:sum_counters]
push!(ex.args, :(NLPModels.$foo(nlp::$(esc(Model))) = $foo(nlp.$inner)))
end
# reset! first resets the inner model, then calls reset_data! on the wrapper
# itself (reset_data! is defined elsewhere — presumably clears wrapper-specific
# state; confirm).
push!(ex.args, :(NLPModels.reset!(nlp::$(esc(Model))) = begin
reset!(nlp.$inner)
reset_data!(nlp)
end))
# increment!/decrement! forward the counter Symbol unchanged to the inner model.
push!(ex.args, :(NLPModels.increment!(nlp::$(esc(Model)), s::Symbol) = increment!(nlp.$inner, s)))
push!(ex.args, :(NLPModels.decrement!(nlp::$(esc(Model)), s::Symbol) = decrement!(nlp.$inner, s)))
ex
end