Skip to content

Commit c4aa19b

Browse files
committed
Several changes
1. Fixing compact walker. 2. Updating tests
1 parent 411b3ac commit c4aa19b

10 files changed

+53
-57
lines changed

Diff for: lib/coek/coek/api/constraint_map.cpp

+1-16
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
#include "coek/api/expression.hpp"
2-
#ifdef COEK_WITH_COMPACT_MODEL
3-
# include "coek/api/constraint.hpp"
4-
# include "coek/api/objective.hpp"
5-
# include "coek/model/compact_model.hpp"
6-
#endif
72
#include "coek/util/io_utils.hpp"
83
#include "coek/api/constraint_map.hpp"
94
#include "coek/api/indexed_container.defs.hpp"
@@ -18,19 +13,9 @@ void Model::add_constraint(ConstraintMap& cons)
1813
{
1914
if (repn->name_generation_policy == Model::NameGeneration::eager)
2015
cons.generate_names();
21-
else if (repn->name_generation_policy == Model::NameGeneration::lazy)
22-
repn->constraint_maps.push_back(cons);
16+
repn->constraint_maps.push_back(cons);
2317
for (auto& con : cons.repn->value)
2418
add_constraint(con.second);
2519
}
2620

27-
#ifdef COEK_WITH_COMPACT_MODEL
28-
void CompactModel::add_constraint(ConstraintMap& cons)
29-
{
30-
// TODO - name management here
31-
for (auto& con : cons.repn->value)
32-
add_constraint(con.second);
33-
}
34-
#endif
35-
3621
} // namespace coek

Diff for: lib/coek/coek/ast/visitor_to_list.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,23 @@ void visit_VariableRefTerm(const expr_pointer_t& expr, std::list<std::string>& r
5353
auto tmp = safe_pointer_cast<VariableRefTerm>(expr);
5454
std::stringstream sstr;
5555
write_expr(tmp, sstr);
56-
repr.push_back(sstr.str());
56+
repr.push_back("~" + sstr.str());
5757
}
5858

5959
void visit_ParameterRefTerm(const expr_pointer_t& expr, std::list<std::string>& repr)
6060
{
6161
auto tmp = safe_pointer_cast<ParameterRefTerm>(expr);
6262
std::stringstream sstr;
6363
write_expr(tmp, sstr);
64-
repr.push_back(sstr.str());
64+
repr.push_back("~" + sstr.str());
6565
}
6666

6767
void visit_DataRefTerm(const expr_pointer_t& expr, std::list<std::string>& repr)
6868
{
6969
auto tmp = safe_pointer_cast<DataRefTerm>(expr);
7070
std::stringstream sstr;
7171
write_expr(tmp, sstr);
72-
repr.push_back(sstr.str());
72+
repr.push_back("~" + sstr.str());
7373
}
7474
#endif
7575

Diff for: lib/coek/coek/compact/visitor_exprtemplate.cpp

+16-15
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ expr_pointer_t visit_IndexParameterTerm(const expr_pointer_t& expr)
3636
expr_pointer_t visit_VariableRefTerm(const expr_pointer_t& expr)
3737
{
3838
auto tmp = safe_pointer_cast<VariableRefTerm>(expr);
39+
auto ans = tmp->get_concrete_variable();
3940
return tmp->get_concrete_variable();
4041
}
4142

@@ -55,7 +56,7 @@ expr_pointer_t visit_ObjectiveTerm(const expr_pointer_t& expr)
5556
{
5657
auto tmp = safe_pointer_cast<ObjectiveTerm>(expr);
5758
auto body = visit_expression(tmp->body);
58-
if (body->id() == tmp->id())
59+
if (body == tmp)
5960
return expr;
6061
return std::make_shared<ObjectiveTerm>(body, tmp->sense);
6162
}
@@ -66,8 +67,8 @@ expr_pointer_t visit_InequalityTerm(const expr_pointer_t& expr)
6667
auto lower = tmp->lower ? visit_expression(tmp->lower) : tmp->lower;
6768
auto body = visit_expression(tmp->body);
6869
auto upper = tmp->upper ? visit_expression(tmp->upper) : tmp->upper;
69-
if ((not tmp->lower or (lower->id() == tmp->lower->id())) and (body->id() == tmp->body->id())
70-
and (not tmp->upper or (upper->id() == tmp->upper->id())))
70+
if ((not tmp->lower or (lower == tmp->lower)) and (body == tmp->body)
71+
and (not tmp->upper or (upper == tmp->upper)))
7172
return expr;
7273
return std::make_shared<InequalityTerm>(lower, body, upper);
7374
}
@@ -78,8 +79,8 @@ expr_pointer_t visit_StrictInequalityTerm(const expr_pointer_t& expr)
7879
auto lower = tmp->lower ? visit_expression(tmp->lower) : tmp->lower;
7980
auto body = visit_expression(tmp->body);
8081
auto upper = tmp->upper ? visit_expression(tmp->upper) : tmp->upper;
81-
if ((not tmp->lower or (lower->id() == tmp->lower->id())) and (body->id() == tmp->body->id())
82-
and (not tmp->upper or (upper->id() == tmp->upper->id())))
82+
if ((not tmp->lower or (lower == tmp->lower)) and (body == tmp->body)
83+
and (not tmp->upper or (upper == tmp->upper)))
8384
return expr;
8485
return std::make_shared<StrictInequalityTerm>(lower, body, upper);
8586
}
@@ -101,7 +102,7 @@ expr_pointer_t visit_NegateTerm(const expr_pointer_t& expr)
101102
{
102103
auto tmp = safe_pointer_cast<NegateTerm>(expr);
103104
auto body = visit_expression(tmp->body);
104-
if (body->id() == tmp->body->id())
105+
if (body == tmp->body)
105106
return expr;
106107
return std::make_shared<NegateTerm>(body);
107108
}
@@ -114,18 +115,18 @@ expr_pointer_t visit_PlusTerm(const expr_pointer_t& expr)
114115
auto lhs = visit_expression(data[0]);
115116
auto curr = visit_expression(data[1]);
116117
if (tmp->n == 2) {
117-
if ((lhs->id() == data[0]->id()) and (curr->id() == data[1]->id()))
118+
if ((lhs == data[0]) and (curr == data[1]))
118119
return expr;
119120
return std::make_shared<PlusTerm>(lhs, curr, false);
120121
}
121122

122123
auto _curr = std::make_shared<PlusTerm>(lhs, curr, false);
123-
bool flag = (lhs->id() == data[0]->id()) and (curr->id() == data[1]->id());
124+
bool flag = (lhs == data[0]) and (curr == data[1]);
124125

125126
for (size_t i = 2; i < tmp->num_expressions(); i++) {
126127
auto curr = visit_expression(data[i]);
127128
_curr->push_back(curr);
128-
flag = flag and (curr->id() == data[i]->id());
129+
flag = flag and (curr == data[i]);
129130
}
130131
if (flag)
131132
return expr;
@@ -137,7 +138,7 @@ expr_pointer_t visit_TimesTerm(const expr_pointer_t& expr)
137138
auto tmp = safe_pointer_cast<TimesTerm>(expr);
138139
auto lhs = visit_expression(tmp->lhs);
139140
auto rhs = visit_expression(tmp->rhs);
140-
if ((lhs->id() == tmp->lhs->id()) and (rhs->id() == tmp->rhs->id()))
141+
if ((lhs == tmp->lhs) and (rhs == tmp->rhs))
141142
return expr;
142143
return std::make_shared<TimesTerm>(lhs, rhs);
143144
}
@@ -147,7 +148,7 @@ expr_pointer_t visit_DivideTerm(const expr_pointer_t& expr)
147148
auto tmp = safe_pointer_cast<DivideTerm>(expr);
148149
auto lhs = visit_expression(tmp->lhs);
149150
auto rhs = visit_expression(tmp->rhs);
150-
if ((lhs->id() == tmp->lhs->id()) and (rhs->id() == tmp->rhs->id()))
151+
if ((lhs == tmp->lhs) and (rhs == tmp->rhs))
151152
return expr;
152153
return std::make_shared<DivideTerm>(lhs, rhs);
153154
}
@@ -158,8 +159,8 @@ expr_pointer_t visit_IfThenElseTerm(const expr_pointer_t& expr)
158159
auto cond_expr = visit_expression(tmp->cond_expr);
159160
auto then_expr = visit_expression(tmp->then_expr);
160161
auto else_expr = visit_expression(tmp->else_expr);
161-
if ((cond_expr->id() == tmp->cond_expr->id()) and (then_expr->id() == tmp->then_expr->id())
162-
and (else_expr->id() == tmp->else_expr->id()))
162+
if ((cond_expr == tmp->cond_expr) and (then_expr == tmp->then_expr)
163+
and (else_expr == tmp->else_expr))
163164
return expr;
164165
return std::make_shared<IfThenElseTerm>(cond_expr, then_expr, else_expr);
165166
}
@@ -169,7 +170,7 @@ expr_pointer_t visit_IfThenElseTerm(const expr_pointer_t& expr)
169170
{ \
170171
auto tmp = safe_pointer_cast<TERM>(expr); \
171172
auto body = visit_expression(tmp->body); \
172-
if (body->id() == tmp->body->id()) \
173+
if (body == tmp->body) \
173174
return expr; \
174175
return std::make_shared<TERM>(body); \
175176
}
@@ -199,7 +200,7 @@ expr_pointer_t visit_PowTerm(const expr_pointer_t& expr)
199200
auto tmp = safe_pointer_cast<PowTerm>(expr);
200201
auto lhs = visit_expression(tmp->lhs);
201202
auto rhs = visit_expression(tmp->rhs);
202-
if ((lhs->id() == tmp->lhs->id()) and (rhs->id() == tmp->rhs->id()))
203+
if ((lhs == tmp->lhs) and (rhs == tmp->rhs))
203204
return expr;
204205
return std::make_shared<PowTerm>(lhs, rhs);
205206
}

Diff for: lib/coek/coek/model/compact_model.cpp

+15-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
#include "coek/util/string_utils.hpp"
3+
#include "coek/util/io_utils.hpp"
34
#include "coek/ast/varray.hpp"
45
#include "coek/api/constraint.hpp"
56
#include "coek/api/objective.hpp"
@@ -196,6 +197,8 @@ Objective& CompactModel::add(Objective&& obj)
196197
return obj;
197198
}
198199

200+
// Constraint
201+
199202
Constraint CompactModel::add_constraint(const Constraint& expr)
200203
{
201204
repn->constraints.push_back(expr);
@@ -221,9 +224,9 @@ Constraint& CompactModel::add(Constraint&& expr)
221224
return expr;
222225
}
223226

224-
# if __cpp_lib_variant
227+
void CompactModel::add_constraint(ConstraintMap& cons) { repn->constraints.push_back(cons); }
228+
225229
void CompactModel::add(ConstraintMap& expr) { add_constraint(expr); }
226-
# endif
227230

228231
void CompactModel::add_constraint(const Constraint& expr, const SequenceContext& context)
229232
{
@@ -242,6 +245,8 @@ void CompactModel::add(ConstraintSequence& seq) { repn->constraints.push_back(se
242245

243246
void CompactModel::add(ConstraintSequence&& seq) { repn->constraints.push_back(seq); }
244247

248+
// Expand
249+
245250
Model CompactModel::expand()
246251
{
247252
// std::cout << "CompactModel::expand()" << std::endl;
@@ -359,12 +364,17 @@ Model CompactModel::expand()
359364
Constraint c = cval->expand();
360365
model.repn->constraints.push_back(c);
361366
}
362-
else {
363-
auto& seq = std::get<ConstraintSequence>(val);
364-
for (auto jt = seq.begin(); jt != seq.end(); ++jt) {
367+
else if (auto seq = std::get_if<ConstraintSequence>(&val)) {
368+
// std::cout << "HERE " << std::endl;
369+
for (auto jt = seq->begin(); jt != seq->end(); ++jt) {
370+
// std::cout << *jt << std::endl;
371+
// std::cout << jt->to_list() << std::endl;
365372
model.repn->constraints.push_back(*jt);
366373
}
367374
}
375+
else if (auto eval = std::get_if<ConstraintMap>(&val)) {
376+
model.add_constraint(*eval);
377+
}
368378
}
369379
return model;
370380
}

Diff for: lib/coek/coek/model/model_repn.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class CompactModelRepn {
133133
public:
134134
std::string name;
135135
std::vector<std::variant<Objective, ObjectiveSequence>> objectives;
136-
std::vector<std::variant<Constraint, ConstraintSequence>> constraints;
136+
std::vector<std::variant<Constraint, ConstraintMap, ConstraintSequence>> constraints;
137137
std::vector<std::variant<Variable, VariableSequence, VariableMap, VariableArray>> variables;
138138
std::vector<std::variant<DataMap, DataArray>> data;
139139
std::vector<std::variant<Parameter, ParameterMap, ParameterArray>> parameters;

Diff for: lib/coek/coek/model/writer_nl.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,19 @@ inline void visit_IndexParameterTerm(const expr_pointer_t&, TYPE&)
124124
template <class TYPE>
125125
inline void visit_VariableRefTerm(const expr_pointer_t&, TYPE&)
126126
{
127-
throw std::runtime_error("Cannot write an NL file using an abstract expression!");
127+
throw std::runtime_error("Cannot write an NL file using an abstract variable expression!");
128128
}
129129

130130
template <class TYPE>
131131
inline void visit_ParameterRefTerm(const expr_pointer_t&, TYPE&)
132132
{
133-
throw std::runtime_error("Cannot write an NL file using an abstract expression!");
133+
throw std::runtime_error("Cannot write an NL file using an abstract parameter expression!");
134134
}
135135

136136
template <class TYPE>
137137
inline void visit_DataRefTerm(const expr_pointer_t&, TYPE&)
138138
{
139-
throw std::runtime_error("Cannot write an NL file using an abstract expression!");
139+
throw std::runtime_error("Cannot write an NL file using an abstract data expression!");
140140
}
141141
#endif
142142

Diff for: lib/coek/test/smoke/test_data.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,15 @@ TEST_CASE("1D_data_map", "[smoke]")
263263
{
264264
auto i = coek::set_element("i");
265265
auto e = dats(i);
266-
static std::list<std::string> baseline = {"dats[i]"};
266+
static std::list<std::string> baseline = {"~dats[i]"};
267267
REQUIRE(e.to_list() == baseline);
268268
}
269269

270270
WHEN("index4")
271271
{
272272
auto i = coek::set_element("i");
273273
auto e = dats(i + 1);
274-
static std::list<std::string> baseline = {"dats[i + 1]"};
274+
static std::list<std::string> baseline = {"~dats[i + 1]"};
275275
REQUIRE(e.to_list() == baseline);
276276
}
277277
}
@@ -524,7 +524,7 @@ TEST_CASE("2D_data_map", "[smoke]")
524524
auto i = coek::set_element("i");
525525
auto j = coek::set_element("j");
526526
auto e = dats(i, j);
527-
static std::list<std::string> baseline = {"dats[i,j]"};
527+
static std::list<std::string> baseline = {"~dats[i,j]"};
528528
REQUIRE(e.to_list() == baseline);
529529
}
530530

@@ -533,7 +533,7 @@ TEST_CASE("2D_data_map", "[smoke]")
533533
auto i = coek::set_element("i");
534534
auto j = coek::set_element("j");
535535
auto e = dats(i + 1, j - 1);
536-
static std::list<std::string> baseline = {"dats[i + 1,j + -1]"};
536+
static std::list<std::string> baseline = {"~dats[i + 1,j + -1]"};
537537
REQUIRE(e.to_list() == baseline);
538538
}
539539
}

Diff for: lib/coek/test/smoke/test_model.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ TEST_CASE("model_names", "[smoke]")
210210
REQUIRE(m.repn->parameter_maps.size() == 2);
211211
REQUIRE(m.repn->variable_maps.size() == 2);
212212
# endif
213-
REQUIRE(m.repn->constraint_maps.size() == 0);
213+
REQUIRE(m.repn->constraint_maps.size() == 4);
214214
#endif
215215

216216
REQUIRE(coek::starts_with(p.name(), "p"));
@@ -386,7 +386,7 @@ TEST_CASE("model_names", "[smoke]")
386386
REQUIRE(m.repn->parameter_maps.size() == 2);
387387
REQUIRE(m.repn->variable_maps.size() == 2);
388388
# endif
389-
REQUIRE(m.repn->constraint_maps.size() == 0);
389+
REQUIRE(m.repn->constraint_maps.size() == 4);
390390
#endif
391391

392392
REQUIRE(p.name() == "p");

Diff for: lib/coek/test/smoke/test_param.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,15 @@ TEST_CASE("1D_param_map", "[smoke]")
263263
{
264264
auto i = coek::set_element("i");
265265
auto e = params(i);
266-
static std::list<std::string> baseline = {"params[i]"};
266+
static std::list<std::string> baseline = {"~params[i]"};
267267
REQUIRE(e.to_list() == baseline);
268268
}
269269

270270
WHEN("index4")
271271
{
272272
auto i = coek::set_element("i");
273273
auto e = params(i + 1);
274-
static std::list<std::string> baseline = {"params[i + 1]"};
274+
static std::list<std::string> baseline = {"~params[i + 1]"};
275275
REQUIRE(e.to_list() == baseline);
276276
}
277277
}
@@ -515,7 +515,7 @@ TEST_CASE("2D_param_map", "[smoke]")
515515
auto i = coek::set_element("i");
516516
auto j = coek::set_element("j");
517517
auto e = params(i, j);
518-
static std::list<std::string> baseline = {"params[i,j]"};
518+
static std::list<std::string> baseline = {"~params[i,j]"};
519519
REQUIRE(e.to_list() == baseline);
520520
}
521521

@@ -524,7 +524,7 @@ TEST_CASE("2D_param_map", "[smoke]")
524524
auto i = coek::set_element("i");
525525
auto j = coek::set_element("j");
526526
auto e = params(i + 1, j - 1);
527-
static std::list<std::string> baseline = {"params[i + 1,j + -1]"};
527+
static std::list<std::string> baseline = {"~params[i + 1,j + -1]"};
528528
REQUIRE(e.to_list() == baseline);
529529
}
530530
}

Diff for: lib/coek/test/smoke/test_var.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,15 @@ TEST_CASE("1D_var_map", "[smoke]")
316316
{
317317
auto i = coek::set_element("i");
318318
auto e = vars(i);
319-
static std::list<std::string> baseline = {"vars[i]"};
319+
static std::list<std::string> baseline = {"~vars[i]"};
320320
REQUIRE(e.to_list() == baseline);
321321
}
322322

323323
WHEN("index4")
324324
{
325325
auto i = coek::set_element("i");
326326
auto e = vars(i + 1);
327-
static std::list<std::string> baseline = {"vars[i + 1]"};
327+
static std::list<std::string> baseline = {"~vars[i + 1]"};
328328
REQUIRE(e.to_list() == baseline);
329329
}
330330
}
@@ -639,7 +639,7 @@ TEST_CASE("2D_var_map", "[smoke]")
639639
auto i = coek::set_element("i");
640640
auto j = coek::set_element("j");
641641
auto e = vars(i, j);
642-
static std::list<std::string> baseline = {"vars[i,j]"};
642+
static std::list<std::string> baseline = {"~vars[i,j]"};
643643
REQUIRE(e.to_list() == baseline);
644644
}
645645

@@ -648,7 +648,7 @@ TEST_CASE("2D_var_map", "[smoke]")
648648
auto i = coek::set_element("i");
649649
auto j = coek::set_element("j");
650650
auto e = vars(i + 1, j - 1);
651-
static std::list<std::string> baseline = {"vars[i + 1,j + -1]"};
651+
static std::list<std::string> baseline = {"~vars[i + 1,j + -1]"};
652652
REQUIRE(e.to_list() == baseline);
653653
}
654654
}

0 commit comments

Comments
 (0)