Skip to content

Commit fcaa957

Browse files
committed
Refactoring assoc_array logic
Using a base class to avoid cross-includes between the API and AST directories.
1 parent f08f88d commit fcaa957

20 files changed

+450
-461
lines changed

Diff for: lib/coek/coek/CMakeLists.txt

+35-12
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,19 @@ SET(sources
66
util/option_cache.cpp
77
util/tictoc.cpp
88
util/string_utils.cpp
9+
)
10+
if (CMAKE_CXX_STANDARD GREATER_EQUAL 17)
11+
list(APPEND sources
12+
util/DataPortal.cpp
13+
)
14+
endif()
15+
16+
list(APPEND sources
917
ast/base_terms.cpp
1018
ast/constraint_terms.cpp
1119
ast/value_terms.cpp
1220
ast/expr_terms.cpp
21+
ast/visitor_expand.cpp
1322
ast/visitor_to_list.cpp
1423
ast/visitor_write_expr.cpp
1524
ast/visitor_to_MutableNLPExpr.cpp
@@ -20,18 +29,44 @@ SET(sources
2029
ast/visitor_simplify.cpp
2130
ast/visitor_eval.cpp
2231
#ast/varray.cpp
32+
)
33+
if (CMAKE_CXX_STANDARD GREATER_EQUAL 17)
34+
list(APPEND sources
35+
ast/compact_terms.cpp
36+
)
37+
endif()
38+
39+
list(APPEND sources
2340
api/constants.cpp
2441
api/expression.cpp
2542
api/expression_visitor.cpp
2643
api/objective.cpp
2744
api/constraint.cpp
2845
api/intrinsic_fn.cpp
46+
)
47+
if (CMAKE_CXX_STANDARD GREATER_EQUAL 17)
48+
list(APPEND sources
49+
api/parameter_assoc_array_repn.cpp
50+
api/parameter_assoc_array.cpp
51+
api/parameter_array.cpp
52+
api/variable_assoc_array_repn.cpp
53+
api/variable_assoc_array.cpp
54+
api/variable_array.cpp
55+
api/constraint_map.cpp
56+
api/subexpression_map.cpp
57+
)
58+
endif()
59+
60+
list(APPEND sources
2961
model/model.cpp
3062
model/compact_model.cpp
3163
model/nlp_model.cpp
3264
model/writer_lp.cpp
3365
model/writer_nl.cpp
3466
model/reader_jpof.cpp
67+
)
68+
69+
list(APPEND sources
3570
solvers/solver_results.cpp
3671
solvers/solver.cpp
3772
solvers/solver_repn.cpp
@@ -40,18 +75,6 @@ SET(sources
4075
autograd/autograd.cpp
4176
abstract/expr_rule.cpp
4277
)
43-
if (CMAKE_CXX_STANDARD GREATER_EQUAL 17)
44-
list(APPEND sources
45-
ast/compact_terms.cpp
46-
api/parameter_assoc_array.cpp
47-
api/parameter_array.cpp
48-
api/variable_assoc_array.cpp
49-
api/variable_array.cpp
50-
api/constraint_map.cpp
51-
api/subexpression_map.cpp
52-
util/DataPortal.cpp
53-
)
54-
endif()
5578

5679
if(with_compact)
5780
list(APPEND sources

Diff for: lib/coek/coek/api/expression.cpp

+1-9
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88

99
namespace coek {
1010

11-
#ifdef COEK_WITH_COMPACT_MODEL
12-
expr_pointer_t convert_expr_template(const expr_pointer_t& expr);
13-
#endif
14-
1511
//
1612
// Parameter
1713
//
@@ -371,11 +367,7 @@ Expression Expression::diff(const Variable& var) const
371367

372368
Expression Expression::expand()
373369
{
374-
#ifdef COEK_WITH_COMPACT_MODEL
375-
return convert_expr_template(repn);
376-
#else
377-
return *this;
378-
#endif
370+
return expand_expr(repn);
379371
}
380372

381373
std::ostream& operator<<(std::ostream& ostr, const Expression& arg)

Diff for: lib/coek/coek/api/parameter_array.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
#include "coek/ast/value_terms.hpp"
12
#include "coek/api/parameter_array.hpp"
2-
33
#include "coek/api/parameter_assoc_array_repn.hpp"
44
#include "coek/model/model.hpp"
55
#include "coek/model/model_repn.hpp"
@@ -39,7 +39,7 @@ class ParameterArrayRepn : public ParameterAssocArrayRepn {
3939

4040
virtual ~ParameterArrayRepn() {}
4141

42-
Parameter index(const IndexVector& args);
42+
std::shared_ptr<ParameterTerm> index(const IndexVector& args);
4343

4444
size_t dim() { return shape.size(); }
4545

@@ -79,7 +79,7 @@ void ParameterArrayRepn::generate_names()
7979
// If no name has been provided to this array object,
8080
// then we do not try to generate names. The default/simple
8181
// parameter names will be used.
82-
std::string name = parameter_template.name();
82+
std::string name = value_template.name();
8383
if (name.size() == 0)
8484
return;
8585

@@ -117,7 +117,7 @@ std::shared_ptr<ParameterAssocArrayRepn> ParameterArray::get_repn() { return rep
117117
Parameter ParameterArray::index(const IndexVector& args)
118118
{ return repn->index(args); }
119119

120-
Parameter ParameterArrayRepn::index(const IndexVector& args)
120+
std::shared_ptr<ParameterTerm> ParameterArrayRepn::index(const IndexVector& args)
121121
{
122122
//auto _repn = repn.get();
123123
//auto& shape = _repn->shape;
@@ -134,7 +134,7 @@ Parameter ParameterArrayRepn::index(const IndexVector& args)
134134
if (ndx > size()) {
135135
// TODO - Can't we do better than this check? Do we check if each index is in the correct
136136
// range?
137-
std::string err = "Unknown index value: " + parameter_template.name() + "[";
137+
std::string err = "Unknown index value: " + value_template.name() + "[";
138138
for (size_t i = 0; i < args.size(); i++) {
139139
if (i > 0)
140140
err += ",";
@@ -144,13 +144,13 @@ Parameter ParameterArrayRepn::index(const IndexVector& args)
144144
throw std::runtime_error(err);
145145
}
146146

147-
return values[ndx];
147+
return values[ndx].repn;
148148
}
149149

150150
void ParameterArray::index_error(size_t i)
151151
{
152152
auto _repn = repn.get();
153-
std::string err = "Unexpected index value: " + _repn->parameter_template.name() + " is an "
153+
std::string err = "Unexpected index value: " + _repn->value_template.name() + " is an "
154154
+ std::to_string(tmp.size()) + "-D parameter array but is being indexed with "
155155
+ std::to_string(i) + " indices.";
156156
throw std::runtime_error(err);

Diff for: lib/coek/coek/api/parameter_assoc_array.cpp

+4-111
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,10 @@
1-
#include "coek/api/parameter_assoc_array.hpp"
2-
31
#include <cmath>
42
#include <variant>
5-
6-
#include "coek/ast/compact_terms.hpp"
3+
#include "coek/ast/value_terms.hpp"
4+
#include "coek/api/parameter_assoc_array.hpp"
75

86
namespace coek {
97

10-
//
11-
// ParameterAssocArrayRepn
12-
//
13-
14-
ParameterAssocArrayRepn::ParameterAssocArrayRepn() { parameter_template.name("p"); }
15-
16-
void ParameterAssocArrayRepn::resize_index_vectors(IndexVector& tmp_,
17-
std::vector<refarg_types>& reftmp_)
18-
{
19-
auto dim_ = dim();
20-
#ifdef CUSTOM_INDEXVECTOR
21-
tmp = cache.alloc(dim_);
22-
tmp_ = cache.alloc(dim_);
23-
#else
24-
tmp.resize(dim_);
25-
tmp_.resize(dim_);
26-
#endif
27-
reftmp.resize(dim_);
28-
reftmp_.resize(dim_);
29-
}
30-
31-
void ParameterAssocArrayRepn::expand()
32-
{
33-
if (first_expand) {
34-
auto value = std::make_shared<ConstantTerm>(
35-
parameter_template.value_expression().expand().value());
36-
for (size_t i = 0; i < size(); i++) {
37-
values.emplace_back(CREATE_POINTER(ParameterTerm, value));
38-
}
39-
first_expand = false;
40-
}
41-
}
42-
43-
void ParameterAssocArrayRepn::value(double value)
44-
{
45-
parameter_template.value(value);
46-
if (values.size() > 0) {
47-
Expression e(value);
48-
for (auto& var : values)
49-
var.value(e);
50-
}
51-
}
52-
53-
void ParameterAssocArrayRepn::value(const Expression& value)
54-
{
55-
parameter_template.value(value);
56-
if (values.size() > 0) {
57-
for (auto& var : values)
58-
var.value(value);
59-
}
60-
}
61-
62-
void ParameterAssocArrayRepn::name(const std::string& name)
63-
{
64-
parameter_template.name(name);
65-
if (values.size() > 0) {
66-
// If the string is empty, then we reset the names of all variables
67-
if (name.size() == 0) {
68-
for (auto& var : values)
69-
var.name(name);
70-
}
71-
// Otherwise, we re-generate the names
72-
else
73-
generate_names();
74-
}
75-
}
76-
77-
//
78-
// ParameterAssocArray
79-
//
80-
818
size_t ParameterAssocArray::size() { return get_repn()->size(); }
829

8310
size_t ParameterAssocArray::dim() { return get_repn()->dim(); }
@@ -88,46 +15,12 @@ std::vector<Parameter>::iterator ParameterAssocArray::end() { return get_repn()-
8815

8916
#ifdef COEK_WITH_COMPACT_MODEL
9017
expr_pointer_t create_paramref(const std::vector<refarg_types>& indices, const std::string& name,
91-
std::shared_ptr<ParameterAssocArrayRepn>& var);
18+
std::shared_ptr<AssocArrayBase<ParameterTerm>> var);
9219

9320
Expression ParameterAssocArray::create_paramref(const std::vector<refarg_types>& args)
9421
{
9522
auto repn = get_repn();
96-
return coek::create_paramref(args, repn->parameter_template.name(), repn);
97-
}
98-
#endif
99-
100-
//
101-
// OTHER
102-
//
103-
104-
#ifdef COEK_WITH_COMPACT_MODEL
105-
expr_pointer_t get_concrete_param(ParameterRefTerm& paramref)
106-
{
107-
//* param = static_cast<ParameterAssocArray*>(paramref.param);
108-
auto param = std::dynamic_pointer_cast<ParameterAssocArrayRepn>(paramref.param);
109-
110-
std::vector<int> index;
111-
//for (auto it = paramref.indices.begin(); it != paramref.indices.end(); ++it) {
112-
// refarg_types& reftmp = *it;
113-
for (auto& reftmp : paramref.indices) {
114-
if (auto ival = std::get_if<int>(&reftmp))
115-
index.push_back(*ival);
116-
else {
117-
expr_pointer_t eval = std::get<expr_pointer_t>(reftmp);
118-
double vald = eval->eval();
119-
long int vali = std::lround(vald);
120-
assert(fabs(vald - vali) < 1e-7);
121-
index.push_back(static_cast<int>(vali));
122-
}
123-
}
124-
125-
IndexVector& tmp = param->tmp;
126-
for (size_t i = 0; i < index.size(); i++)
127-
tmp[i] = index[i];
128-
129-
Expression e = param->index(tmp);
130-
return e.repn;
23+
return coek::create_paramref(args, repn->value_template.name(), repn);
13124
}
13225
#endif
13326

Diff for: lib/coek/coek/api/parameter_assoc_array_repn.cpp

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "coek/ast/value_terms.hpp"
2+
#include "coek/api/parameter_assoc_array_repn.hpp"
3+
4+
namespace coek {
5+
6+
ParameterAssocArrayRepn::ParameterAssocArrayRepn()
7+
{
8+
value_template.name("p");
9+
}
10+
11+
void ParameterAssocArrayRepn::expand()
12+
{
13+
if (first_expand) {
14+
auto value = std::make_shared<ConstantTerm>( value_template.value_expression().expand().value() );
15+
//expand_expr(value_template.->value)->eval());
16+
for (size_t i = 0; i < size(); i++) {
17+
values.emplace_back(CREATE_POINTER(ParameterTerm, value));
18+
}
19+
first_expand = false;
20+
}
21+
}
22+
23+
void ParameterAssocArrayRepn::value(double value)
24+
{
25+
value_template.value(value);
26+
if (values.size() > 0) {
27+
Expression e(value);
28+
for (auto& var : values)
29+
var.value(e.repn);
30+
}
31+
}
32+
33+
void ParameterAssocArrayRepn::value(const Expression& value)
34+
{
35+
value_template.value(value);
36+
if (values.size() > 0) {
37+
for (auto& var : values)
38+
var.value(value);
39+
}
40+
}
41+
42+
void ParameterAssocArrayRepn::name(const std::string& name)
43+
{
44+
value_template.name(name);
45+
if (values.size() > 0) {
46+
// If the string is empty, then we reset the names of all variables
47+
if (name.size() == 0) {
48+
for (auto& var : values)
49+
var.name(name);
50+
}
51+
// Otherwise, we re-generate the names
52+
else
53+
generate_names();
54+
}
55+
}
56+
57+
}

0 commit comments

Comments
 (0)